diff --git a/.claude/feat/swarm_mode/task.json b/.claude/feat/swarm_mode/task.json new file mode 100644 index 000000000..f6c36e380 --- /dev/null +++ b/.claude/feat/swarm_mode/task.json @@ -0,0 +1,588 @@ +{ + "project": "PicoClaw - Swarm Mode 多实例协同架构", + "repository": "https://github.com/sipeed/picoclaw", + "version": "1.0.0", + "sprint": "00.sprint - 0213~0228", + "created_at": "2026-02-16", + "updated_at": "2026-02-20", + "tasks": [ + { + "id": "PHASE0-DOCS-001", + "milestone": "Phase 0 - 身份与权限", + "subject": "完成架构文档与场景定义", + "status": "completed", + "priority": "P0", + "assignee": "Claude", + "description": "完成 Swarm Mode 的架构设计文档,定义参考场景和关键技术栈", + "deliverables": [ + "00.架构设计.md / 00.Architecture.md", + "01.记忆权限模型设计.md / 01.Memory-Permissions.md", + "02.技术介绍与概念梳理.md / 02.Technical-Guide.md", + "03.Scenario-Local-Distributed-Swarm.md" + ], + "links": [ + "https://github.com/sipeed/picoclaw/discussions/119" + ], + "estimated_weeks": 0 + }, + { + "id": "PHASE0-ID-001", + "milestone": "Phase 0 - 身份与权限", + "subject": "实现 Identity 模型(H-id / S-id)", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "实现两级身份模型,支持 CLI 参数、环境变量、配置文件、自动生成四种来源", + "deliverables": [ + "pkg/identity/identity.go - Identity 结构体", + "pkg/identity/loader.go - 身份加载逻辑", + "pkg/identity/generator.go - 身份自动生成", + "cmd/picoclaw/identity flags 支持" + ], + "blocked_by": [], + "dependencies": [], + "estimated_weeks": 0.5 + }, + { + "id": "PHASE0-MEM-001", + "milestone": "Phase 0 - 身份与权限", + "subject": "实现 Memory Permissions 模型", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "实现三层记忆权限模型(Private/Shared/Public)与 MemoryItem 结构", + "deliverables": [ + "pkg/memory/types.go - MemoryScope, MemoryType 枚举", + "pkg/memory/item.go - MemoryItem 结构", + "pkg/memory/permission.go - 权限检查逻辑", + "pkg/memory/acl.go - AllowList/DenyList 支持" + ], + "blocked_by": ["PHASE0-ID-001"], + "dependencies": ["PHASE0-ID-001"], + "estimated_weeks": 1 + }, + { + "id": 
"PHASE0-REL-001", + "milestone": "Phase 0 - 身份与权限", + "subject": "实现 Resource + Relation 权限模型", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "实现基于 Resource + Relation 的细粒度访问控制,支持跨 H-id 授权", + "deliverables": [ + "pkg/relation/resource.go - Resource 结构", + "pkg/relation/relation.go - Relation 结构", + "pkg/relation/registry.go - Relation Registry", + "pkg/relation/policy.go - Policy 定义与检查", + "pkg/relation/authorizer.go - 授权逻辑" + ], + "blocked_by": ["PHASE0-ID-001", "PHASE0-MEM-001"], + "dependencies": ["PHASE0-ID-001", "PHASE0-MEM-001"], + "estimated_weeks": 1.5 + }, + { + "id": "PHASE0-TEST-001", + "milestone": "Phase 0 - 身份与权限", + "subject": "Phase 0 单元测试", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "为 Identity、Memory Permissions、Relation 模型编写单元测试", + "deliverables": [ + "pkg/identity/*_test.go", + "pkg/memory/*_test.go", + "pkg/relation/*_test.go", + "测试覆盖率 > 80%" + ], + "blocked_by": ["PHASE0-REL-001"], + "dependencies": ["PHASE0-REL-001"], + "estimated_weeks": 0.5 + }, + { + "id": "PHASE1-NATS-001", + "milestone": "Phase 1 - NATS 基础设施", + "subject": "实现 NATS Bridge 基础实现", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "实现 NATS 连接管理、重连逻辑、基本发布订阅功能", + "deliverables": [ + "pkg/nats/conn.go - 连接管理", + "pkg/nats/options.go - 连接选项配置", + "pkg/nats/publisher.go - 发布者封装", + "pkg/nats/subscriber.go - 订阅者封装", + "pkg/nats/reconnect.go - 重连逻辑" + ], + "blocked_by": ["PHASE0-TEST-001"], + "dependencies": ["PHASE0-TEST-001"], + "estimated_weeks": 1 + }, + { + "id": "PHASE1-NATS-002", + "milestone": "Phase 1 - NATS 基础设施", + "subject": "实现按 H-id 分区", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "实现基于 H-id 的主题分区,确保租户隔离", + "deliverables": [ + "pkg/nats/subject.go - Subject 构建器", + "pkg/nats/partition.go - H-id 分区逻辑", + "主题命名规范实现: picoclaw.{domain}.{hid}.*" + ], + "blocked_by": ["PHASE1-NATS-001"], + "dependencies": ["PHASE1-NATS-001"], 
+ "estimated_weeks": 0.5 + }, + { + "id": "PHASE1-HB-001", + "milestone": "Phase 1 - NATS 基础设施", + "subject": "实现心跳检测与存活判定", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "实现节点心跳发布(10s)、可疑判定(30s)、离线判定(60s)机制", + "deliverables": [ + "pkg/swarm/heartbeat.go - 心跳发布器", + "pkg/swarm/monitor.go - 心跳监控器", + "pkg/swarm/registry.go - 节点注册表", + "HeartbeatConfig 配置支持", + "NodeStatus 状态机: online/busy/draining/offline" + ], + "blocked_by": ["PHASE1-NATS-002"], + "dependencies": ["PHASE1-NATS-002"], + "estimated_weeks": 1 + }, + { + "id": "PHASE1-DISC-001", + "milestone": "Phase 1 - NATS 基础设施", + "subject": "实现节点发现机制", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "实现基于 NATS 的节点发现,支持能力声明与查询", + "deliverables": [ + "pkg/swarm/discovery.go - 节点发现", + "NodeCapabilities 结构体", + "discovery.announce 主题", + "discovery.query 主题", + "本地节点缓存维护" + ], + "blocked_by": ["PHASE1-HB-001"], + "dependencies": ["PHASE1-HB-001"], + "estimated_weeks": 1 + }, + { + "id": "PHASE1-XHID-001", + "milestone": "Phase 1 - NATS 基础设施", + "subject": "实现跨 H-id 通信授权", + "status": "completed", + "priority": "P1", + "assignee": null, + "description": "实现跨租户通信机制,基于 Relation 授权检查", + "deliverables": [ + "pkg/nats/cross_hid.go - 跨 H-id 通信", + "picoclaw.x.{from_hid}.{to_hid} 主题支持", + "Relation 授权集成", + "Export/Import 配置支持" + ], + "blocked_by": ["PHASE1-DISC-001", "PHASE0-REL-001"], + "dependencies": ["PHASE1-DISC-001", "PHASE0-REL-001"], + "estimated_weeks": 1.5 + }, + { + "id": "PHASE2-COORD-001", + "milestone": "Phase 2 - 多节点协作", + "subject": "实现 Coordinator 角色基础", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "实现 Coordinator 节点角色,负责接收请求、分解任务、编排工作流", + "deliverables": [ + "pkg/swarm/coordinator.go - Coordinator 核心", + "pkg/swarm/scheduler.go - 任务调度器", + "task.assign.{sid} 主题支持", + "task.broadcast.{capability} 主题支持" + ], + "blocked_by": ["PHASE1-XHID-001"], + "dependencies": ["PHASE1-XHID-001"], + 
"estimated_weeks": 1.5 + }, + { + "id": "PHASE2-WORK-001", + "milestone": "Phase 2 - 多节点协作", + "subject": "实现 Worker 角色基础", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "实现 Worker 节点角色,负责执行任务、报告结果", + "deliverables": [ + "pkg/swarm/worker.go - Worker 核心", + "pkg/swarm/executor.go - 任务执行��", + "task.result.{task_id} 主题支持", + "task.progress.{task_id} 主题支持" + ], + "blocked_by": ["PHASE1-XHID-001"], + "dependencies": ["PHASE1-XHID-001"], + "estimated_weeks": 1.5 + }, + { + "id": "PHASE2-QUEUE-001", + "milestone": "Phase 2 - 多节点协作", + "subject": "实现 NATS Queue Group 负载均衡", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "使用 NATS Queue Group 实现任务在多个 Worker 间的负载均衡", + "deliverables": [ + "pkg/swarm/queue_group.go - Queue Group 封装", + "Worker 自动注册到 Queue Group", + "任务自动分配逻辑", + "Worker 下线自动重分配" + ], + "blocked_by": ["PHASE2-WORK-001"], + "dependencies": ["PHASE2-WORK-001"], + "estimated_weeks": 1 + }, + { + "id": "PHASE3-TEMP-001", + "milestone": "Phase 3 - Temporal 工作流", + "subject": "集成 Temporal SDK", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "集成 Temporal Go SDK,建立与 Temporal Server 的连接", + "deliverables": [ + "pkg/temporal/client.go - Temporal 客户端", + "pkg/temporal/options.go - 连接选项", + "Task Queue 配置", + "Worker 注册机制" + ], + "blocked_by": ["PHASE2-QUEUE-001"], + "dependencies": ["PHASE2-QUEUE-001"], + "estimated_weeks": 1 + }, + { + "id": "PHASE3-WF-001", + "milestone": "Phase 3 - Temporal 工作流", + "subject": "实现核心 Workflow", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "实现长运行任务的 Workflow 编排,支持任务分解、子任务协调、结果聚合", + "deliverables": [ + "pkg/temporal/workflow/task.go - 任务 Workflow", + "pkg/temporal/workflow/parallel.go - 并行任务 Workflow", + "Workflow 定义接口", + "Activity 调用封装" + ], + "blocked_by": ["PHASE3-TEMP-001"], + "dependencies": ["PHASE3-TEMP-001"], + "estimated_weeks": 2 + }, + { + "id": "PHASE3-ACT-001", + "milestone": 
"Phase 3 - Temporal 工作流", + "subject": "实现 Activity 与测试", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "实现具体的 Activity(LLM 调用、Memory 读写等)和测试", + "deliverables": [ + "pkg/temporal/activity/llm.go - LLM Activity", + "pkg/temporal/activity/memory.go - Memory Activity", + "pkg/temporal/activity/tool.go - Tool Activity", + "Activity 单元测试", + "Workflow 集成测试" + ], + "blocked_by": ["PHASE3-WF-001"], + "dependencies": ["PHASE3-WF-001"], + "estimated_weeks": 1.5 + }, + { + "id": "PHASE4-LIFECYCLE-001", + "milestone": "Phase 4 - 故障迁移", + "subject": "实现任务生命周期存储", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "基于 JetStream 实现任务生命周期持久化,支持状态转换记录", + "deliverables": [ + "pkg/swarm/lifecycle/store.go - 生命周期存储", + "pkg/swarm/lifecycle/state.go - 状态机", + "task.lifecycle.{task_id} 主题", + "TaskStatus 状态转换验证", + "状态转换历史查询" + ], + "blocked_by": ["PHASE3-ACT-001"], + "dependencies": ["PHASE3-ACT-001"], + "estimated_weeks": 1 + }, + { + "id": "PHASE4-CHECKPOINT-001", + "milestone": "Phase 4 - 故障迁移", + "subject": "实现任务检查点机制", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "实现任务检查点创建、存储、加载,支持故障恢复", + "deliverables": [ + "pkg/swarm/checkpoint/checkpoint.go - 检查点结构", + "pkg/swarm/checkpoint/writer.go - 检查点写入器", + "pkg/swarm/checkpoint/reader.go - 检查点读取器", + "task.checkpoint.{task_id} 主题", + "CheckpointConfig 配置" + ], + "blocked_by": ["PHASE4-LIFECYCLE-001"], + "dependencies": ["PHASE4-LIFECYCLE-001"], + "estimated_weeks": 1 + }, + { + "id": "PHASE4-FAILOVER-001", + "milestone": "Phase 4 - 故障迁移", + "subject": "实现故障迁移机制", + "status": "completed", + "priority": "P0", + "assignee": null, + "description": "实现故障检测、任务声明、检查点恢复的完整迁移流程", + "deliverables": [ + "pkg/swarm/failure/migrator.go - 故障迁移器", + "pkg/swarm/failure/claimer.go - 任务声明逻辑", + "pkg/swarm/failure/restorer.go - 检查点恢复", + "脑裂防护(JetStream 精确一次语义)", + "故障迁移测试" + ], + "blocked_by": ["PHASE4-CHECKPOINT-001"], + "dependencies": 
["PHASE4-CHECKPOINT-001"], + "estimated_weeks": 1.5 + }, + { + "id": "PHASE5-PARALLEL-001", + "milestone": "Phase 5 - 并行执行", + "subject": "实现并行任务执行", + "status": "completed", + "priority": "P1", + "assignee": null, + "description": "实现父任务分解、子任务并行分配、结果聚合的并行执行模型", + "deliverables": [ + "pkg/swarm/parallel/executor.go - 并行执行器", + "pkg/swarm/parallel/subtask.go - 子任务定义", + "pkg/swarm/parallel/coordinator.go - 子任务协调", + "WaitMode: All/Any/Quorum/None", + "DAG 依赖关系支持" + ], + "blocked_by": ["PHASE4-FAILOVER-001"], + "dependencies": ["PHASE4-FAILOVER-001"], + "estimated_weeks": 1.5 + }, + { + "id": "PHASE6-SPEC-001", + "milestone": "Phase 6 - Specialist 角色", + "subject": "实现 Specialist 角色与能力注册", + "status": "completed", + "priority": "P1", + "assignee": null, + "description": "实现专业化节点角色,支持特定领域能力声明与任务路由", + "deliverables": [ + "pkg/swarm/specialist/role.go - Specialist 角色", + "pkg/swarm/capability/registry.go - 能力注册表", + "基于能力的任务路由", + "Specialist 配置支持" + ], + "blocked_by": ["PHASE5-PARALLEL-001"], + "dependencies": ["PHASE5-PARALLEL-001"], + "estimated_weeks": 1 + }, + { + "id": "PHASE7-EDGE-001", + "milestone": "Phase 7 - 边缘部署", + "subject": "边缘部署支持", + "status": "completed", + "priority": "P1", + "assignee": null, + "description": "支持在资源受限的边缘设备(RISC-V、移动设备)上部署 Worker 节点", + "deliverables": [ + "交叉编译配置", + "轻量级 Worker 模式", + "资源限制配置", + "边缘设备部署文档" + ], + "blocked_by": ["PHASE6-SPEC-001"], + "dependencies": ["PHASE6-SPEC-001"], + "estimated_weeks": 1 + }, + { + "id": "PHASE8-DOCS-001", + "milestone": "Phase 8 - 文档与示例", + "subject": "完善文档与示例", + "status": "completed", + "priority": "P1", + "assignee": null, + "description": "完善用户文档、API 文档、部署示例", + "deliverables": [ + "用户部署指南", + "API 参考文档", + "配置说明", + "示例部署配置(Docker Compose/K8s)", + "故障排查指南" + ], + "blocked_by": ["PHASE7-EDGE-001"], + "dependencies": ["PHASE7-EDGE-001"], + "estimated_weeks": 0.5 + }, + { + "id": "PHASE8-JETSTREAM-001", + "milestone": "Phase 1 - NATS 基础设施", + "subject": "实现 JetStream Memory Store", + 
"status": "completed", + "priority": "P0", + "assignee": null, + "description": "基于 NATS JetStream 实现热存储,支持记忆的 CRUD、查询、Watch", + "deliverables": [ + "pkg/memory/jetstream/store.go - JetStream Memory Store", + "pkg/memory/jetstream/stream.go - Stream 管理", + "pkg/memory/jetstream/consumer.go - Consumer 管理", + "KV Store 支持", + "Query 与 Watch API" + ], + "blocked_by": ["PHASE1-NATS-002"], + "dependencies": ["PHASE1-NATS-002"], + "estimated_weeks": 1.5 + }, + { + "id": "PHASE8-OSS-001", + "milestone": "Phase 8 - 冷热存储", + "subject": "实现 OSS 冷存储", + "status": "pending", + "priority": "P1", + "assignee": null, + "description": "实现 S3/MinIO 冷存储,支持记忆归档与提升", + "deliverables": [ + "pkg/memory/oss/backend.go - OSS 后端接口", + "pkg/memory/oss/s3.go - S3 实现", + "pkg/memory/oss/minio.go - MinIO 实现", + "自动归档任务", + "冷热透明访问" + ], + "blocked_by": ["PHASE8-JETSTREAM-001"], + "dependencies": ["PHASE8-JETSTREAM-001"], + "estimated_weeks": 1.5 + }, + { + "id": "PHASE9-DASH-001", + "milestone": "Phase 9 - Dashboard 与监控", + "subject": "实现 Swarm Dashboard", + "status": "completed", + "priority": "P2", + "assignee": "Claude", + "description": "实现基于终端的 Dashboard,用于监控 swarm 状态", + "deliverables": [ + "pkg/swarm/dashboard.go - Dashboard 核心", + "pkg/swarm/dashboard_test.go - 测试", + "实时状态显示(节点、连接、选举)", + "紧凑单行状态输出" + ], + "blocked_by": [], + "dependencies": [], + "estimated_weeks": 0.5 + } + ], + "summary": { + "total": 26, + "by_status": { + "completed": 24, + "pending": 2, + "in_progress": 0, + "blocked": 0 + }, + "by_priority": { + "P0": 20, + "P1": 5, + "P2": 1 + }, + "by_milestone": { + "Phase 0 - 身份与权限": 5, + "Phase 1 - NATS 基础设施": 6, + "Phase 2 - 多节点协作": 3, + "Phase 3 - Temporal 工作流": 3, + "Phase 4 - 故障迁移": 3, + "Phase 5 - 并行执行": 1, + "Phase 6 - Specialist 角色": 1, + "Phase 7 - 边缘部署": 1, + "Phase 8 - 文档与示例": 2 + }, + "estimated_weeks": 25 + }, + "milestones": [ + { + "name": "Phase 0 - 身份与权限", + "description": "建立身份模型(H-id/S-id)和权限控制(Memory Permissions + Resource/Relation)", + "priority": "P0", 
+ "weeks": 4, + "tasks": ["PHASE0-ID-001", "PHASE0-MEM-001", "PHASE0-REL-001", "PHASE0-TEST-001"] + }, + { + "name": "Phase 1 - NATS 基础设施", + "description": "实现 NATS 连接、H-id 分区、心跳检测、节点发现、跨 H-id 通信", + "priority": "P0", + "weeks": 6, + "tasks": ["PHASE1-NATS-001", "PHASE1-NATS-002", "PHASE1-HB-001", "PHASE1-DISC-001", "PHASE1-XHID-001", "PHASE8-JETSTREAM-001"] + }, + { + "name": "Phase 2 - 多节点协作", + "description": "实现 Coordinator 和 Worker 角色,NATS Queue Group 负载均衡", + "priority": "P0", + "weeks": 4, + "tasks": ["PHASE2-COORD-001", "PHASE2-WORK-001", "PHASE2-QUEUE-001"] + }, + { + "name": "Phase 3 - Temporal 工作流", + "description": "集成 Temporal SDK,实现核心 Workflow 和 Activity", + "priority": "P0", + "weeks": 4.5, + "tasks": ["PHASE3-TEMP-001", "PHASE3-WF-001", "PHASE3-ACT-001"] + }, + { + "name": "Phase 4 - 故障迁移", + "description": "实现任务生命周期存储、检查点机制、故障迁移", + "priority": "P0", + "weeks": 3.5, + "tasks": ["PHASE4-LIFECYCLE-001", "PHASE4-CHECKPOINT-001", "PHASE4-FAILOVER-001"] + }, + { + "name": "Phase 5 - 并行执行", + "description": "实现并行任务执行、DAG 依赖关系、子任务协调", + "priority": "P1", + "weeks": 1.5, + "tasks": ["PHASE5-PARALLEL-001"] + }, + { + "name": "Phase 6 - Specialist 角色", + "description": "实现专业化节点角色与能力注册", + "priority": "P1", + "weeks": 1, + "tasks": ["PHASE6-SPEC-001"] + }, + { + "name": "Phase 7 - 边缘部署", + "description": "支持边缘设备部署", + "priority": "P1", + "weeks": 1, + "tasks": ["PHASE7-EDGE-001"] + }, + { + "name": "Phase 8 - 冷热存储", + "description": "实现 OSS 冷存储与自动归档", + "priority": "P1", + "weeks": 1.5, + "tasks": ["PHASE8-OSS-001"] + }, + { + "name": "文档与示例", + "description": "完善文档与示例", + "priority": "P1", + "weeks": 0.5, + "tasks": ["PHASE8-DOCS-001"] + } + ] +} diff --git a/.env.example b/.env.example index c450b6e8c..06d43070c 100644 --- a/.env.example +++ b/.env.example @@ -5,10 +5,13 @@ # ANTHROPIC_API_KEY=sk-ant-xxx # OPENAI_API_KEY=sk-xxx # GEMINI_API_KEY=xxx +# CEREBRAS_API_KEY=xxx # ── Chat Channel ────────────────────────── # TELEGRAM_BOT_TOKEN=123456:ABC... 
# DISCORD_BOT_TOKEN=xxx +# LINE_CHANNEL_SECRET=xxx +# LINE_CHANNEL_ACCESS_TOKEN=xxx # ── Web Search (optional) ──────────────── # BRAVE_SEARCH_API_KEY=BSA... diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 000000000..4be385b22 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,28 @@ +--- +name: Bug report +about: Report a bug or unexpected behavior +title: "[BUG]" +labels: bug +assignees: '' + +--- + +## Quick Summary + +## Environment & Tools +- **PicoClaw Version:** (e.g., v0.1.2 or commit hash) +- **Go Version:** (e.g., go 1.22) +- **AI Model & Provider:** (e.g., GPT-4o via OpenAI / DeepSeek via SiliconFlow) +- **Operating System:** (e.g., Ubuntu 22.04 / macOS / Android Termux) +- **Channels:** (e.g., Discord, Telegram, Feishu, ...) + +## 📸 Steps to Reproduce +1. +2. +3. + +## ❌ Actual Behavior + +## ✅ Expected Behavior + +## 💬 Additional Context diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 000000000..d3df0e79c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,23 @@ +--- +name: Feature request +about: Suggest a new idea or improvement +title: "[Feature]" +labels: enhancement +assignees: '' + +--- + +## 🎯 The Goal / Use Case + +## 💡 Proposed Solution + +## 🛠 Potential Implementation (Optional) + +## 🚦 Impact & Roadmap Alignment +- [ ] This is a Core Feature +- [ ] This is a Nice-to-Have / Enhancement +- [ ] This aligns with the current Roadmap + +## 🔄 Alternatives Considered + +## 💬 Additional Context diff --git a/.github/ISSUE_TEMPLATE/general-task---todo.md b/.github/ISSUE_TEMPLATE/general-task---todo.md new file mode 100644 index 000000000..eab70c030 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/general-task---todo.md @@ -0,0 +1,26 @@ +--- +name: General Task / Todo +about: A specific piece of work like doc, refactoring, or maintenance. 
+title: "[Task]" +labels: '' +assignees: '' + +--- + +## 📝 Objective + +## 📋 To-Do List +- [ ] Step 1 +- [ ] Step 2 +- [ ] Step 3 + +## 🎯 Definition of Done (Acceptance Criteria) +- [ ] Documentation is updated in the README/docs folder. +- [ ] Code follows project linting standards. +- [ ] (If applicable) Basic tests pass. + +## 💡 Context / Motivation + +## 🔗 Related Issues / PRs +- Fixes # +- Relates to # diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 000000000..c96b7da12 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,43 @@ +## 📝 Description + + + +## 🗣️ Type of Change +- [ ] 🐞 Bug fix (non-breaking change which fixes an issue) +- [ ] ✨ New feature (non-breaking change which adds functionality) +- [ ] 📖 Documentation update +- [ ] ⚡ Code refactoring (no functional changes, no api changes) + +## 🤖 AI Code Generation +- [ ] 🤖 Fully AI-generated (100% AI, 0% Human) +- [ ] 🛠️ Mostly AI-generated (AI draft, Human verified/modified) +- [ ] 👨‍💻 Mostly Human-written (Human lead, AI assisted or none) + + +## 🔗 Related Issue + + + +## 📚 Technical Context (Skip for Docs) +- **Reference URL:** +- **Reasoning:** + +## 🧪 Test Environment +- **Hardware:** +- **OS:** +- **Model/Provider:** +- **Channels:** + + +## 📸 Evidence (Optional) +
+Click to view Logs/Screenshots + + + +
+ +## ☑️ Checklist +- [ ] My code/docs follow the style of this project. +- [ ] I have performed a self-review of my own changes. +- [ ] I have updated the documentation accordingly. \ No newline at end of file diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index aad0f3262..9b89b69ae 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -2,18 +2,17 @@ name: build on: push: - branches: ["main"] - pull_request: + branches: [ "main" ] jobs: build: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version-file: go.mod diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 5f8bbd303..dadbed212 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -1,13 +1,18 @@ name: 🐳 Build & Push Docker Image on: - push: - branches: [main] - tags: ["v*"] + workflow_call: + inputs: + tag: + description: "Release tag" + required: true + type: string env: - REGISTRY: ghcr.io - IMAGE_NAME: ${{ github.repository_owner }}/picoclaw + GHCR_REGISTRY: ghcr.io + GHCR_IMAGE_NAME: ${{ github.repository_owner }}/picoclaw + DOCKERHUB_REGISTRY: docker.io + DOCKERHUB_IMAGE_NAME: ${{ vars.DOCKERHUB_REPOSITORY }} jobs: build: @@ -20,7 +25,9 @@ jobs: steps: # ── Checkout ────────────────────────────── - name: 📥 Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 + with: + ref: ${{ inputs.tag }} # ── Docker Buildx ───────────────────────── - name: 🔧 Set up Docker Buildx @@ -28,36 +35,42 @@ jobs: # ── Login to GHCR ───────────────────────── - name: 🔑 Login to GitHub Container Registry - if: github.event_name != 'pull_request' uses: docker/login-action@v3 with: - registry: ${{ env.REGISTRY }} + registry: ${{ env.GHCR_REGISTRY }} username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - # ── Metadata (tags & labels) 
────────────── - - name: 🏷️ Extract Docker metadata - id: meta - uses: docker/metadata-action@v5 + # ── Login to Docker Hub ──────────────────── + - name: 🔑 Login to Docker Hub + uses: docker/login-action@v3 with: - images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} - tags: | - type=ref,event=branch - type=ref,event=pr - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=sha,prefix= - type=raw,value=latest,enable={{is_default_branch}} - type=raw,value={{date 'YYYYMMDD-HHmmss'}},enable={{is_default_branch}} + registry: ${{ env.DOCKERHUB_REGISTRY }} + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + # ── Metadata (tags & labels) ────────────── + - name: 🏷️ Prepare image tags + id: tags + shell: bash + run: | + tag="${{ inputs.tag }}" + echo "ghcr_tag=${{ env.GHCR_REGISTRY }}/${{ env.GHCR_IMAGE_NAME }}:${tag}" >> "$GITHUB_OUTPUT" + echo "ghcr_latest=${{ env.GHCR_REGISTRY }}/${{ env.GHCR_IMAGE_NAME }}:latest" >> "$GITHUB_OUTPUT" + echo "dockerhub_tag=${{ env.DOCKERHUB_REGISTRY }}/${{ env.DOCKERHUB_IMAGE_NAME }}:${tag}" >> "$GITHUB_OUTPUT" + echo "dockerhub_latest=${{ env.DOCKERHUB_REGISTRY }}/${{ env.DOCKERHUB_IMAGE_NAME }}:latest" >> "$GITHUB_OUTPUT" # ── Build & Push ────────────────────────── - name: 🚀 Build and push Docker image uses: docker/build-push-action@v6 with: context: . 
- push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} + push: true + tags: | + ${{ steps.tags.outputs.ghcr_tag }} + ${{ steps.tags.outputs.ghcr_latest }} + ${{ steps.tags.outputs.dockerhub_tag }} + ${{ steps.tags.outputs.dockerhub_latest }} cache-from: type=gha cache-to: type=gha,mode=max - platforms: linux/amd64,linux/arm64 + platforms: linux/amd64,linux/arm64,linux/riscv64 diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml new file mode 100644 index 000000000..be1c10c52 --- /dev/null +++ b/.github/workflows/pr.yml @@ -0,0 +1,43 @@ +name: PR + +on: + pull_request: { } + +jobs: + lint: + name: Linter + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Setup Go + uses: actions/setup-go@v6 + with: + go-version-file: go.mod + + - name: Run go generate + run: go generate ./... + + - name: Golangci Lint + uses: golangci/golangci-lint-action@v9 + with: + version: v2.10.1 + + test: + name: Tests + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Setup Go + uses: actions/setup-go@v6 + with: + go-version-file: go.mod + + - name: Run go generate + run: go generate ./... + + - name: Run go test + run: go test ./... 
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 59cc6caeb..786c893ef 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -26,74 +26,77 @@ jobs: contents: write steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 - name: Create and push tag shell: bash + env: + RELEASE_TAG: ${{ inputs.tag }} run: | git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" - git tag -a "${{ inputs.tag }}" -m "Release ${{ inputs.tag }}" - git push origin "${{ inputs.tag }}" + git tag -a "$RELEASE_TAG" -m "Release $RELEASE_TAG" + git push origin "$RELEASE_TAG" - build-binaries: - name: Build Release Binaries + release: + name: GoReleaser Release needs: create-tag runs-on: ubuntu-latest + permissions: + contents: write + packages: write steps: - name: Checkout tag - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: + fetch-depth: 0 ref: ${{ inputs.tag }} - name: Setup Go from go.mod - uses: actions/setup-go@v5 + id: setup-go + uses: actions/setup-go@v6 with: go-version-file: go.mod - - name: Build all binaries - run: make build-all + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 - - name: Generate checksums - shell: bash - run: | - shasum -a 256 build/picoclaw-* > build/sha256sums.txt + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 - - name: Upload release binaries artifact - uses: actions/upload-artifact@v4 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 with: - name: picoclaw-binaries - path: | - build/picoclaw-* - build/sha256sums.txt - if-no-files-found: error + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} - create-release: - name: Create GitHub Release - needs: [create-tag, build-binaries] - runs-on: ubuntu-latest - permissions: - contents: write - steps: - - name: Download all artifacts - uses: 
actions/download-artifact@v4 + - name: Login to Docker Hub + uses: docker/login-action@v3 with: - path: release-artifacts + registry: docker.io + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Show downloaded files - run: ls -R release-artifacts - - - name: Create release - uses: softprops/action-gh-release@v2 + - name: Run GoReleaser + uses: goreleaser/goreleaser-action@v6 with: - tag_name: ${{ inputs.tag }} - name: ${{ inputs.tag }} - draft: ${{ inputs.draft }} - prerelease: ${{ inputs.prerelease }} - files: | - release-artifacts/**/* - generate_release_notes: true + distribution: goreleaser + version: ~> v2 + args: release --clean env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }} + DOCKERHUB_IMAGE_NAME: ${{ vars.DOCKERHUB_REPOSITORY }} + GOVERSION: ${{ steps.setup-go.outputs.go-version }} + + - name: Apply release flags + shell: bash + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + gh release edit "${{ inputs.tag }}" \ + --draft=${{ inputs.draft }} \ + --prerelease=${{ inputs.prerelease }} diff --git a/.gitignore b/.gitignore index 7163f5fdf..ce30d749e 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ build/ *.out /picoclaw /picoclaw-test +cmd/picoclaw/workspace # Picoclaw specific @@ -34,4 +35,12 @@ coverage.html # Ralph workspace ralph/ -.ralph/ \ No newline at end of file +.ralph/ +tasks/ + +# Editors +.vscode/ +.idea/ + +# Added by goreleaser init: +dist/ diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 000000000..d45d69e67 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,182 @@ +version: "2" + +linters: + default: all + disable: + # TODO: Tweak for current project needs + - containedctx + - cyclop + - depguard + - dupl + - dupword + - err113 + - exhaustruct + - funcorder + - gochecknoglobals + - godot + - intrange + - ireturn + - nlreturn + - noctx + - noinlineerr + - nonamedreturns + - tagliatelle + - testpackage + 
- varnamelen + - wrapcheck + - wsl + - wsl_v5 + + # TODO: Disabled, because they are failing at the moment, we should fix them and enable (step by step) + - bodyclose + - contextcheck + - dogsled + - embeddedstructfieldcheck + - errcheck + - errchkjson + - errorlint + - exhaustive + - forbidigo + - forcetypeassert + - funlen + - gochecknoinits + - gocognit + - goconst + - gocritic + - gocyclo + - godox + - goprintffuncname + - gosec + - ineffassign + - lll + - maintidx + - misspell + - mnd + - modernize + - nakedret + - nestif + - nilnil + - paralleltest + - perfsprint + - prealloc + - predeclared + - revive + - staticcheck + - tagalign + - testifylint + - thelper + - unparam + - unused + - usestdlibvars + - usetesting + - wastedassign + - whitespace + settings: + errcheck: + check-type-assertions: true + check-blank: true + exhaustive: + default-signifies-exhaustive: true + funlen: + lines: 120 + statements: 40 + gocognit: + min-complexity: 25 + gocyclo: + min-complexity: 20 + govet: + enable-all: true + disable: + - fieldalignment + lll: + line-length: 120 + tab-width: 4 + misspell: + locale: US + mnd: + checks: + - argument + - assign + - case + - condition + - operation + - return + nakedret: + max-func-lines: 3 + revive: + enable-all-rules: true + rules: + - name: add-constant + disabled: true + - name: argument-limit + arguments: + - 7 + severity: warning + - name: banned-characters + disabled: true + - name: cognitive-complexity + disabled: true + - name: comment-spacings + arguments: + - nolint + severity: warning + - name: cyclomatic + disabled: true + - name: file-header + disabled: true + - name: function-result-limit + arguments: + - 3 + severity: warning + - name: function-length + disabled: true + - name: line-length-limit + disabled: true + - name: max-public-structs + disabled: true + - name: modifies-value-receiver + disabled: true + - name: package-comments + disabled: true + - name: unused-receiver + disabled: true + exclusions: + generated: lax 
+ rules: + - linters: + - lll + source: '^//go:generate ' + - linters: + - funlen + - maintidx + - gocognit + - gocyclo + path: _test\.go$ + +issues: + max-issues-per-linter: 0 + max-same-issues: 0 + +formatters: + enable: + - gci + - gofmt + - gofumpt + - goimports + - golines + settings: + gci: + sections: + - standard + - default + - localmodule + custom-order: true + gofmt: + simplify: true + rewrite-rules: + - pattern: "interface{}" + replacement: "any" + - pattern: "a[b:len(a)]" + replacement: "a[b:]" + golines: + max-len: 120 diff --git a/.goreleaser.yaml b/.goreleaser.yaml new file mode 100644 index 000000000..2c47f7d86 --- /dev/null +++ b/.goreleaser.yaml @@ -0,0 +1,87 @@ +# yaml-language-server: $schema=https://goreleaser.com/static/schema.json +# vim: set ts=2 sw=2 tw=0 fo=cnqoj +version: 2 + +before: + hooks: + - go mod tidy + - go generate ./cmd/picoclaw + +builds: + - id: picoclaw + env: + - CGO_ENABLED=0 + tags: + - stdjson + ldflags: + - -s -w + - -X main.version={{ .Version }} + - -X main.gitCommit={{ .ShortCommit }} + - -X main.buildTime={{ .Date }} + - -X main.goVersion={{ .Env.GOVERSION }} + goos: + - linux + - windows + - darwin + - freebsd + goarch: + - amd64 + - arm64 + - riscv64 + - s390x + - mips64 + - arm + main: ./cmd/picoclaw + ignore: + - goos: windows + goarch: arm + +dockers_v2: + - id: picoclaw + dockerfile: Dockerfile.goreleaser + ids: + - picoclaw + images: + - "ghcr.io/{{ .Env.GITHUB_REPOSITORY_OWNER }}/picoclaw" + - "docker.io/{{ .Env.DOCKERHUB_IMAGE_NAME }}" + tags: + - "{{ .Tag }}" + - "latest" + platforms: + - linux/amd64 + - linux/arm64 + - linux/riscv64 + +archives: + - formats: [tar.gz] + # this name template makes the OS and Arch compatible with the results of `uname`. 
+ name_template: >- + {{ .ProjectName }}_ + {{- title .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + # use zip for windows archives + format_overrides: + - goos: windows + formats: [zip] + +changelog: + sort: asc + filters: + exclude: + - "^docs:" + - "^test:" + +# upx: +# - enabled: true +# compress: best +# lzma: true + +release: + footer: >- + + --- + + Released by [GoReleaser](https://github.com/goreleaser/goreleaser). diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..88227f493 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,302 @@ +# Contributing to PicoClaw + +Thank you for your interest in contributing to PicoClaw! This project is a community-driven effort to build the lightweight and versatile personal AI assistant. We welcome contributions of all kinds: bug fixes, features, documentation, translations, and testing. + +PicoClaw itself was substantially developed with AI assistance — we embrace this approach and have built our contribution process around it. + +## Table of Contents + +- [Code of Conduct](#code-of-conduct) +- [Ways to Contribute](#ways-to-contribute) +- [Getting Started](#getting-started) +- [Development Setup](#development-setup) +- [Making Changes](#making-changes) +- [AI-Assisted Contributions](#ai-assisted-contributions) +- [Pull Request Process](#pull-request-process) +- [Branch Strategy](#branch-strategy) +- [Code Review](#code-review) +- [Communication](#communication) + +--- + +## Code of Conduct + +We are committed to maintaining a welcoming and respectful community. Be kind, constructive, and assume good faith. Harassment or discrimination of any kind will not be tolerated. + +--- + +## Ways to Contribute + +- **Bug reports** — Open an issue using the bug report template. +- **Feature requests** — Open an issue using the feature request template; discuss before implementing. 
+   git clone https://github.com/<your-username>/picoclaw.git
+   cd picoclaw
+- Refer to https://www.conventionalcommits.org/en/v1.0.0/
+ +### AI-Generated Code Quality Standards + +AI-generated contributions are held to the **same quality bar** as human-written code: + +- It must pass all CI checks (`make check`). +- It must be idiomatic Go and consistent with the existing codebase style. +- It must not introduce unnecessary abstractions, dead code, or over-engineering. +- It must include or update tests where appropriate. + +### Security Review + +AI-generated code requires extra security scrutiny. Pay special attention to: + +- File path handling and sandbox escapes (see commit `244eb0b` for a real example) +- External input validation in channel handlers and tool implementations +- Credential or secret handling +- Command execution (`exec.Command`, shell invocations) + +If you are unsure whether a piece of AI-generated code is safe, say so in the PR — reviewers will help. + +--- + +## Pull Request Process + +### Before Opening a PR + +- [ ] Run `make check` and ensure it passes locally. +- [ ] Fill in the PR template completely, including the AI disclosure section. +- [ ] Link any related issue(s) in the PR description. +- [ ] Keep the PR focused. Avoid bundling unrelated changes together. + +### PR Template Sections + +The PR template asks for: + +- **Description** — What does this change do and why? +- **Type of Change** — Bug fix, feature, docs, or refactor. +- **AI Code Generation** — Disclosure of AI involvement (required). +- **Related Issue** — Link to the issue this addresses. +- **Technical Context** — Reference URLs and reasoning (skip for pure docs PRs). +- **Test Environment** — Hardware, OS, model/provider, and channels used for testing. +- **Evidence** — Optional logs or screenshots demonstrating the change works. +- **Checklist** — Self-review confirmation. + +### PR Size + +Prefer small, reviewable PRs. A PR that changes 200 lines across 5 files is much easier to review than one that changes 2000 lines across 30 files. 
If your feature is large, consider splitting it into a series of smaller, logically complete PRs. + +--- + +## Branch Strategy + +### Long-Lived Branches + +- **`main`** — the active development branch. All feature PRs target `main`. The branch is protected: direct pushes are not permitted, and at least one maintainer approval is required before merging. +- **`release/x.y`** — stable release branches, cut from `main` when a version is ready to ship. These branches are more strictly protected than `main`. + +### Requirements to Merge into `main` + +A PR can only be merged when all of the following are satisfied: + +1. **CI passes** — All GitHub Actions workflows (lint, test, build) must be green. +2. **Reviewer approval** — At least one maintainer has approved the PR. +3. **No unresolved review comments** — All review threads must be resolved. +4. **PR template is complete** — Including AI disclosure and test environment. + +### Who Can Merge + +Only maintainers can merge PRs. Contributors cannot merge their own PRs, even if they have write access. + +### Merge Strategy + +We use **squash merge** for most PRs to keep the `main` history clean and readable. Each merged PR becomes a single commit referencing the PR number, e.g.: + +``` +feat: Add Ollama provider support (#491) +``` + +If a PR consists of multiple independent, well-separated commits that tell a clear story, a regular merge may be used at the maintainer's discretion. + +### Release Branches + +When a version is ready, maintainers cut a `release/x.y` branch from `main`. After that point: + +- **New features are not backported.** The release branch receives no new functionality after it is cut. +- **Security fixes and critical bug fixes are cherry-picked.** If a fix in `main` qualifies (security vulnerability, data loss, crash), maintainers will cherry-pick the relevant commit(s) onto the affected `release/x.y` branch and issue a patch release. 
+|Skill ||
+- **WeChat & Discord** — We will invite you once you have at least one merged PR.
添加上游远程仓库: + ```bash + git remote add upstream https://github.com/sipeed/picoclaw.git + ``` + +--- + +## 开发环境配置 + +### 前置依赖 + +- Go 1.25 或更高版本 +- `make` + +### 构建 + +```bash +make build # 构建二进制文件(会先执行 go generate) +make generate # 仅执行 go generate +make check # 完整的提交前检查:deps + fmt + vet + test +``` + +### 运行测试 + +```bash +make test # 运行所有测试 +go test -run TestName -v ./pkg/session/ # 运行单个测试 +go test -bench=. -benchmem -run='^$' ./... # 运行基准测试 +``` + +### 代码风格 + +```bash +make fmt # 格式化代码 +make vet # 静态分析 +make lint # 完整的 lint 检查 +``` + +所有 CI 检查通过后 PR 才能被合并。推送代码前请先在本地运行 `make check`,提前发现问题。 + +--- + +## 提交修改 + +### 分支管理 + +始终从 `main` 分支切出,并在 PR 中以 `main` 为目标分支。不要直接向 `main` 或任何 `release/*` 分支推送代码: + +```bash +git checkout main +git pull upstream main +git checkout -b 你的功能分支名 +``` + +请使用描述性的分支名,例如:`fix/telegram-timeout`、`feat/ollama-provider`、`docs/contributing-guide`。 + +### Commit 规范 + +- 使用英文撰写清晰、简洁的 commit 信息。 +- 使用祈使句:写 "Add retry logic",而不是 "Added retry logic"。 +- 有关联 Issue 时请引用:`Fix session leak (#123)`。 +- 保持 commit 专注,每个 commit 只做一件事。 +- 对于小的清理或拼写修正,提 PR 前请将其合并为一个 commit。 +- 按照 https://www.conventionalcommits.org/zh-hans/v1.0.0/ 规范来撰写 + +### 保持与上游同步 + +提 PR 前,请将你的分支变基到上游 `main`: + +```bash +git fetch upstream +git rebase upstream/main +``` + +--- + +## AI 辅助贡献 + +PicoClaw 在很大程度上借助 AI 辅助开发,我们完全拥抱这种开发方式。但贡献者必须清楚地了解自己在使用 AI 工具时所承担的责任。 + +### 必须披露 AI 使用情况 + +每个 PR 都必须通过 PR 模板中的 **🤖 AI 代码生成** 部分披露 AI 参与情况,共分三个级别: + +| 级别 | 说明 | +|---|---| +| 🤖 完全由 AI 生成 | AI 编写代码,贡献者负责审查和验证 | +| 🛠️ 主要由 AI 生成 | AI 起草,贡献者做了较大修改 | +| 👨‍💻 主要由人工编写 | 贡献者主导,AI 仅提供辅助或未使用 AI | + +我们期望你诚实填写。三种级别均可接受,没有任何歧视——重要的是贡献的质量。 + +### 你对提交的代码负全责 + +使用 AI 生成代码并不能减轻你作为贡献者的责任。在提交含有 AI 生成代码的 PR 之前,你必须: + +- **逐行阅读并理解**生成的代码。 +- **在真实环境中测试**(参见 PR 模板中的测试环境部分)。 +- **检查安全问题** — AI 模型可能生成存在安全隐患的代码(如路径穿越、注入攻击、凭据泄露等),请仔细审查。 +- **验证正确性** — AI 生成的逻辑可能听起来合理但实际上是错误的,请验证行为,而不仅仅是语法。 + +如果明显可以看出贡献者没有阅读或测试 AI 生成的代码,该 PR 将被直接关闭,不予审查。 + +### AI 生成代码的质量标准 + +AI 生成的代码与人工编写的代码遵循**相同的质量要求**: + +- 必须通过所有 CI 检查(`make check`)。 
+- 必须符合 Go 惯用写法,并与现有代码库的风格保持一致。 +- 不得引入不必要的抽象、死代码或过度设计。 +- 须在适当的地方包含或更新测试。 + +### 安全审查 + +AI 生成的代码需要格外仔细的安全审查。请特别关注以下方面: + +- 文件路径处理与沙箱逃逸(项目历史中的 commit `244eb0b` 就是真实案例) +- channel 处理器和 tool 实现中的外部输入校验 +- 凭据或密钥的处理 +- 命令执行(`exec.Command`、shell 调用等) + +如果你不确定某段 AI 生成代码是否安全,请在 PR 中说明——审查者会帮助判断。 + +--- + +## Pull Request 流程 + +### 提 PR 前的检查 + +- [ ] 在本地运行 `make check` 并确认通过。 +- [ ] 完整填写 PR 模板,包括 AI 披露部分。 +- [ ] 在 PR 描述中关联相关 Issue。 +- [ ] 保持 PR 专注,避免将不相关的修改混在一起。 + +### PR 模板各部分说明 + +PR 模板要求填写: + +- **描述** — 这个改动做了什么,为什么要做? +- **变更类型** — Bug 修复、新功能、文档或重构。 +- **AI 代码生成** — AI 参与情况披露(必填)。 +- **关联 Issue** — 此 PR 解决的 Issue 链接。 +- **技术背景** — 参考链接和设计理由(纯文档类 PR 可跳过)。 +- **测试环境** — 用于测试的硬件、操作系统、模型/提供商和渠道。 +- **验证证据** — 可选的日志或截图,用于证明改动有效。 +- **检查清单** — 自我审查确认。 + +### PR 规模 + +请尽量提交小而易于审查的 PR。一个涉及 5 个文件共 200 行改动的 PR,远比涉及 30 个文件共 2000 行改动的 PR 容易审查。如果你的功能较大,可以考虑将其拆分为一系列逻辑完整的小 PR。 + +--- + +## 分支策略 + +### 长期分支 + +- **`main`** — 活跃开发分支。所有功能 PR 均以 `main` 为目标。该分支受保护:禁止直接推送,合并前必须获得至少一名维护者的批准。 +- **`release/x.y`** — 稳定发布分支,在某个版本准备发布时从 `main` 切出。这些分支的保护级别高于 `main`。 + +### 合并到 `main` 的前提条件 + +PR 必须同时满足以下所有条件,才能被合并: + +1. **CI 全部通过** — 所有 GitHub Actions 工作流(lint、test、build)均为绿色。 +2. **获得审查者批准** — 至少一名维护者已批准该 PR。 +3. **无未解决的审查意见** — 所有审查讨论线程均已关闭。 +4. 
+|Skill ||
diff --git a/DISCUSSION-swarm-features.md b/DISCUSSION-swarm-features.md new file mode 100644 index 000000000..fcaab7a83 --- /dev/null +++ b/DISCUSSION-swarm-features.md @@ -0,0 +1,147 @@ +# PicoClaw Swarm: What Should a Fleet of Shrimp Actually Do? + +> `pkg/swarm/` | Open Discussion | 2026-02-13 + +--- + +## The Pitch + +So we can now spin up multiple PicoClaws and they find each other over NATS, exchange heartbeats, hand off tasks, and report results. Cool. But right now they're basically just vibing in a group chat — lots of pinging, not much doing. + +This discussion is about **what we actually want a swarm of shrimp to pull off**, and in what order we should build it. + +--- + +## P0 — Without These, the Swarm Is Just a Chat Room + +### 1. Task Decomposition + +`DecomposeTaskActivity` currently returns nil. Every request goes to one shrimp, no matter how big it is. That's like asking one guy to research Rust, Go, *and* Zig embedded dev and write the comparison report — while three other shrimp sit there doing nothing. + +The coordinator should be able to look at a request and go: "okay, this breaks into three parallel research jobs and one synthesis job." Three shrimp research simultaneously, a fourth writes the report. Done in a third of the time. + +**Open questions:** +- LLM-driven decomposition, rule-based heuristics, or a mix? +- How do we template the decomposition prompt without it getting brittle? +- Max decomposition depth? We don't want recursive splitting all the way to heat death. + +### 2. Result Synthesis + +`SynthesizeResultsActivity` literally concatenates sub-results with `=== Result 1 ===` headers. It's `fmt.Sprintf` cosplaying as intelligence. + +When multiple shrimp contribute partial answers, the coordinator should use the LLM to merge them into one coherent response — not a mechanical paste job. + +**Open questions:** +- Keep sub-task metadata in the final output (who ran it, how long it took)? 
+- If one sub-task failed, do we skip it, flag it, or retry before synthesis? +- Context budget — what happens when sub-results are collectively too long? + +### 3. Smarter Capability Routing + +Capabilities are currently static strings in config (`capabilities: ["code", "research"]`). That's fine for now, but it's pretty rigid. + +Where this should go: +- **Dynamic registration**: shrimp installs a new Skill, its capability list updates automatically. +- **Fuzzy matching**: task needs "code" but shrimp only advertises "golang" and "python" — close enough? +- **Weighted capabilities**: two shrimp both claim "code", but one is clearly better at it. Let the routing reflect that. + +--- + +## P1 — Real Collaboration, Not Just Delegation + +### 4. Inter-Shrimp Messaging + +Right now shrimp can only talk through task assignments and result callbacks. They can't have a side conversation. If the Rust-research shrimp stumbles on a great Go article, there's no way to toss it to the Go-research shrimp mid-task. + +Something like `picoclaw.swarm.chat.{from}.{to}` for point-to-point, maybe a shared topic channel for task-scoped broadcasts. + +**Open questions:** +- Trust model — do we blindly trust all shrimp in the same swarm, or add permissions? +- Persist chat history or treat it as ephemeral? + +### 5. Shared Context + +Each shrimp has its own `MEMORY.md`. Memories are fully isolated. When five shrimp collaborate on one job, they're all working with different context — that's a recipe for contradictory outputs. + +We need a task-scoped shared context pool. The coordinator seeds it with background info, workers push intermediate findings to it, everyone reads from it. + +Possible backends: +- **NATS KV Store** — lightweight, already in our dependency tree +- **Shared filesystem** — dead simple, doesn't scale +- **Redis** — proven, but adds a dependency (violates the "10MB shrimp" spirit) + +### 6. 
Dynamic Role Switching + +Roles are hardcoded at startup (`--role coordinator`). If the coordinator dies, the whole swarm goes headless. + +Ideas: +- **Leader election**: if the coordinator disappears, a worker gets promoted. NATS JetStream could handle this, or we go simple with "first to claim wins." +- **Boss does manual labor too**: if all workers are slammed, the coordinator should be able to pick up a task itself. +- **Role fluidity**: a worker could become a specialist on demand, or vice versa, based on what the swarm needs right now. + +**Open questions:** +- Election mechanism — JetStream advisory, Raft, or just a NATS-based lock? +- What triggers a role switch? Pure load? Manual? LLM decides? + +--- + +## P2 — Nice to Have + +### 7. Priority Queues and Preemption + +The `Priority` field on tasks (0=low to 3=critical) is currently decorative. Everything is FIFO. A `priority=3` alert should jump the queue — and maybe interrupt a shrimp that's leisurely composing a blog post. + +### 8. DAG Execution + +Sub-tasks are all fired in parallel right now with no dependency awareness. We should support: + +``` +Research Rust ──┐ +Research Go ──┼──> Write Comparison ──> Format & Polish +Research Zig ──┘ +``` + +The report waits until all three research tasks finish. Format waits for the report. Parallel where possible, sequential where necessary. + +### 9. Swarm Dashboard + +A live view of the fleet: +- Which shrimp are online, their roles, capabilities +- Per-node load and active tasks +- Task flow in real time +- Historical success rates and latencies + +Could be a TUI with `bubbletea` (no browser needed, stays true to the terminal-native ethos) or a minimal web page bolted onto the gateway. Or both. + +### 10. 
Swarm Security + +- **NATS TLS + auth**: keep rogue shrimp out +- **Task signing**: make sure tasks actually come from a trusted coordinator +- **Audit trail**: who ran what, when, and what happened + +Not urgent for dev, but unavoidable before anyone else runs this. + +--- + +## Design Principles to Keep in Mind + +1. **Stay light.** 10MB RAM budget is real. New feature? Sure. New dependency? Think twice. +2. **Degrade gracefully.** No Temporal? Skip workflows. No remote workers? Run locally. Nothing should be a hard requirement except NATS. +3. **Single-shrimp must stay simple.** Swarm code should be invisible when you're running one instance. Zero overhead, zero config needed. +4. **Ship it, then polish it.** Working beats perfect. Get the ugly version running first. + +--- + +## How to Weigh In + +Reply with: + +``` +## Re: [Feature Name] + +**Vote**: must-have / nice-to-have / skip / rethink +**Why**: ... +**Alternative idea**: ... +``` + +Or propose something not on this list. The shrimp are listening. 
+HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+    CMD curl -fsS http://localhost:18790/health || exit 1
VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") +GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev") BUILD_TIME=$(shell date +%FT%T%z) GO_VERSION=$(shell $(GO) version | awk '{print $$3}') -LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)" +LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.gitCommit=$(GIT_COMMIT) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION) -s -w" # Go variables GO?=go -GOFLAGS?=-v +GOFLAGS?=-v -tags stdjson + +# Golangci-lint +GOLANGCI_LINT?=golangci-lint # Installation INSTALL_PREFIX?=$(HOME)/.local @@ -38,6 +42,8 @@ ifeq ($(UNAME_S),Linux) ARCH=amd64 else ifeq ($(UNAME_M),aarch64) ARCH=arm64 + else ifeq ($(UNAME_M),loongarch64) + ARCH=loong64 else ifeq ($(UNAME_M),riscv64) ARCH=riscv64 else @@ -62,20 +68,28 @@ BINARY_PATH=$(BUILD_DIR)/$(BINARY_NAME)-$(PLATFORM)-$(ARCH) # Default target all: build +## generate: Run generate +generate: + @echo "Run generate..." + @rm -r ./$(CMD_DIR)/workspace 2>/dev/null || true + @$(GO) generate ./... + @echo "Run generate complete" + ## build: Build the picoclaw binary for current platform -build: +build: generate @echo "Building $(BINARY_NAME) for $(PLATFORM)/$(ARCH)..." @mkdir -p $(BUILD_DIR) - $(GO) build $(GOFLAGS) $(LDFLAGS) -o $(BINARY_PATH) ./$(CMD_DIR) + @$(GO) build $(GOFLAGS) $(LDFLAGS) -o $(BINARY_PATH) ./$(CMD_DIR) @echo "Build complete: $(BINARY_PATH)" @ln -sf $(BINARY_NAME)-$(PLATFORM)-$(ARCH) $(BUILD_DIR)/$(BINARY_NAME) ## build-all: Build picoclaw for all platforms -build-all: +build-all: generate @echo "Building for multiple platforms..." 
@mkdir -p $(BUILD_DIR) GOOS=linux GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR) GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) + GOOS=linux GOARCH=loong64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR) GOOS=linux GOARCH=riscv64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR) GOOS=darwin GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR) GOOS=windows GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR) @@ -88,35 +102,8 @@ install: build @cp $(BUILD_DIR)/$(BINARY_NAME) $(INSTALL_BIN_DIR)/$(BINARY_NAME) @chmod +x $(INSTALL_BIN_DIR)/$(BINARY_NAME) @echo "Installed binary to $(INSTALL_BIN_DIR)/$(BINARY_NAME)" - @echo "Installing builtin skills to $(WORKSPACE_SKILLS_DIR)..." - @mkdir -p $(WORKSPACE_SKILLS_DIR) - @for skill in $(BUILTIN_SKILLS_DIR)/*/; do \ - if [ -d "$$skill" ]; then \ - skill_name=$$(basename "$$skill"); \ - if [ -f "$$skill/SKILL.md" ]; then \ - cp -r "$$skill" $(WORKSPACE_SKILLS_DIR); \ - echo " ✓ Installed skill: $$skill_name"; \ - fi; \ - fi; \ - done @echo "Installation complete!" -## install-skills: Install builtin skills to workspace -install-skills: - @echo "Installing builtin skills to $(WORKSPACE_SKILLS_DIR)..." - @mkdir -p $(WORKSPACE_SKILLS_DIR) - @for skill in $(BUILTIN_SKILLS_DIR)/*/; do \ - if [ -d "$$skill" ]; then \ - skill_name=$$(basename "$$skill"); \ - if [ -f "$$skill/SKILL.md" ]; then \ - mkdir -p $(WORKSPACE_SKILLS_DIR)/$$skill_name; \ - cp -r "$$skill" $(WORKSPACE_SKILLS_DIR); \ - echo " ✓ Installed skill: $$skill_name"; \ - fi; \ - fi; \ - done - @echo "Skills installation complete!" - ## uninstall: Remove picoclaw from system uninstall: @echo "Uninstalling $(BINARY_NAME)..." 
+## check: Full pre-commit check: deps + fmt + vet + test
+ PicoClaw + +

PicoClaw : Assistant IA Ultra-Efficace en Go

+ +

Matériel à 10$ · 10 Mo de RAM · Démarrage en 1s · 皮皮虾,我们走!

+ +

+ Go + Hardware + License +
+ Website + Twitter +

+ + [中文](README.zh.md) | [日本語](README.ja.md) | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [English](README.md) | **Français** +
+ +--- + +🦐 **PicoClaw** est un assistant personnel IA ultra-léger inspiré de [nanobot](https://github.com/HKUDS/nanobot), entièrement réécrit en **Go** via un processus d'auto-amorçage (self-bootstrapping) — où l'agent IA lui-même a piloté l'intégralité de la migration architecturale et de l'optimisation du code. + +⚡️ **Extrêmement léger :** Fonctionne sur du matériel à seulement **10$** avec **<10 Mo** de RAM. C'est 99% de mémoire en moins qu'OpenClaw et 98% moins cher qu'un Mac mini ! + + + + + + +
+

+ +

+
+

+ +

+
+ +> [!CAUTION] +> **🚨 SÉCURITÉ & CANAUX OFFICIELS** +> +> * **PAS DE CRYPTO :** PicoClaw n'a **AUCUN** token/jeton officiel. Toute annonce sur `pump.fun` ou d'autres plateformes de trading est une **ARNAQUE**. +> * **DOMAINE OFFICIEL :** Le **SEUL** site officiel est **[picoclaw.io](https://picoclaw.io)**, et le site de l'entreprise est **[sipeed.com](https://sipeed.com)**. +> * **Attention :** De nombreux domaines `.ai/.org/.com/.net/...` sont enregistrés par des tiers et ne nous appartiennent pas. +> * **Attention :** PicoClaw est en phase de développement précoce et peut présenter des problèmes de sécurité réseau non résolus. Ne déployez pas en environnement de production avant la version v1.0. +> * **Note :** PicoClaw a récemment fusionné de nombreuses PR, ce qui peut entraîner une empreinte mémoire plus importante (10–20 Mo) dans les dernières versions. Nous prévoyons de prioriser l'optimisation des ressources dès que l'ensemble des fonctionnalités sera stabilisé. + + +## 📢 Actualités + +2026-02-16 🎉 PicoClaw a atteint 12K étoiles en une semaine ! Merci à tous pour votre soutien ! PicoClaw grandit plus vite que nous ne l'avions jamais imaginé. Vu le volume élevé de PR, nous avons un besoin urgent de mainteneurs communautaires. Nos rôles de bénévoles et notre feuille de route sont officiellement publiés [ici](docs/picoclaw_community_roadmap_260216.md) — nous avons hâte de vous accueillir ! + +2026-02-13 🎉 PicoClaw a atteint 5000 étoiles en 4 jours ! Merci à la communauté ! Nous finalisons la **Feuille de Route du Projet** et mettons en place le **Groupe de Développeurs** pour accélérer le développement de PicoClaw. +🚀 **Appel à l'action :** Soumettez vos demandes de fonctionnalités dans les GitHub Discussions. Nous les examinerons et les prioriserons lors de notre prochaine réunion hebdomadaire. + +2026-02-09 🎉 PicoClaw est lancé ! Construit en 1 jour pour apporter les Agents IA au matériel à 10$ avec <10 Mo de RAM. 🦐 PicoClaw, c'est parti ! 
+ +## ✨ Fonctionnalités + +🪶 **Ultra-Léger** : Empreinte mémoire <10 Mo — 99% plus petit que Clawdbot pour les fonctionnalités essentielles. + +💰 **Coût Minimal** : Suffisamment efficace pour fonctionner sur du matériel à 10$ — 98% moins cher qu'un Mac mini. + +⚡️ **Démarrage Éclair** : Temps de démarrage 400X plus rapide, boot en 1 seconde même sur un cœur unique à 0,6 GHz. + +🌍 **Véritable Portabilité** : Un seul binaire autonome pour RISC-V, ARM et x86. Un clic et c'est parti ! + +🤖 **Auto-Construit par l'IA** : Implémentation native en Go de manière autonome — 95% du cœur généré par l'Agent avec affinement humain dans la boucle. + +| | OpenClaw | NanoBot | **PicoClaw** | +| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- | +| **Langage** | TypeScript | Python | **Go** | +| **RAM** | >1 Go | >100 Mo | **< 10 Mo** | +| **Démarrage**
(cœur 0,8 GHz) | >500s | >30s | **<1s** | +| **Coût** | Mac Mini 599$ | La plupart des SBC Linux
~50$ | **N'importe quelle carte Linux**
**À partir de 10$** | + +PicoClaw + +## 🦾 Démonstration + +### 🛠️ Flux de Travail Standard de l'Assistant + + + + + + + + + + + + + + + + + +

🧩 Ingénieur Full-Stack

🗂️ Gestion des Logs & Planification

🔎 Recherche Web & Apprentissage

Développer • Déployer • Mettre à l'échellePlanifier • Automatiser • MémoriserDécouvrir • Analyser • Tendances
+ +### 📱 Utiliser sur d'anciens téléphones Android + +Donnez une seconde vie à votre téléphone d'il y a dix ans ! Transformez-le en assistant IA intelligent avec PicoClaw. Démarrage rapide : + +1. **Installez Termux** (disponible sur F-Droid ou Google Play). +2. **Exécutez les commandes** + +```bash +# Note : Remplacez v0.1.1 par la dernière version depuis la page des Releases +wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64 +chmod +x picoclaw-linux-arm64 +pkg install proot +termux-chroot ./picoclaw-linux-arm64 onboard +``` + +Puis suivez les instructions de la section « Démarrage Rapide » pour terminer la configuration ! + +PicoClaw + +### 🐜 Déploiement Innovant à Faible Empreinte + +PicoClaw peut être déployé sur pratiquement n'importe quel appareil Linux ! + +- 9,9$ [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) version E (Ethernet) ou W (WiFi6), pour un Assistant Domotique Minimaliste +- 30~50$ [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), ou 100$ [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) pour la Maintenance Automatisée de Serveurs +- 50$ [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) ou 100$ [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) pour la Surveillance Intelligente + + + +🌟 Encore plus de scénarios de déploiement vous attendent ! + +## 📦 Installation + +### Installer avec un binaire précompilé + +Téléchargez le binaire pour votre plateforme depuis la page des [releases](https://github.com/sipeed/picoclaw/releases). 
+ +### Installer depuis les sources (dernières fonctionnalités, recommandé pour le développement) + +```bash +git clone https://github.com/sipeed/picoclaw.git + +cd picoclaw +make deps + +# Compiler, pas besoin d'installer +make build + +# Compiler pour plusieurs plateformes +make build-all + +# Compiler et Installer +make install +``` + +## 🐳 Docker Compose + +Vous pouvez également exécuter PicoClaw avec Docker Compose sans rien installer localement. + +```bash +# 1. Clonez ce dépôt +git clone https://github.com/sipeed/picoclaw.git +cd picoclaw + +# 2. Configurez vos clés API +cp config/config.example.json config/config.json +vim config/config.json # Configurez DISCORD_BOT_TOKEN, clés API, etc. + +# 3. Compiler & Démarrer +docker compose --profile gateway up -d + +# 4. Voir les logs +docker compose logs -f picoclaw-gateway + +# 5. Arrêter +docker compose --profile gateway down +``` + +### Mode Agent (exécution unique) + +```bash +# Poser une question +docker compose run --rm picoclaw-agent -m "Combien font 2+2 ?" + +# Mode interactif +docker compose run --rm picoclaw-agent +``` + +### Recompiler + +```bash +docker compose --profile gateway build --no-cache +docker compose --profile gateway up -d +``` + +### 🚀 Démarrage Rapide + +> [!TIP] +> Configurez votre clé API dans `~/.picoclaw/config.json`. +> Obtenir des clés API : [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM) +> La recherche web est **optionnelle** — obtenez gratuitement l'[API Brave Search](https://brave.com/search/api) (2000 requêtes gratuites/mois) ou utilisez le repli automatique intégré. + +**1. Initialiser** + +```bash +picoclaw onboard +``` + +**2. 
Configurer** (`~/.picoclaw/config.json`) + +```json +{ + "model_list": [ + { + "model_name": "gpt4", + "model": "openai/gpt-5.2", + "api_key": "sk-your-openai-key", + "api_base": "https://api.openai.com/v1" + } + ], + "agents": { + "defaults": { + "model": "gpt4" + } + }, + "channels": { + "telegram": { + "enabled": true, + "token": "VOTRE_TOKEN_BOT", + "allow_from": ["VOTRE_USER_ID"] + } + }, + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "VOTRE_CLE_API_BRAVE", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + } +} +``` + +**3. Obtenir des Clés API** + +* **Fournisseur LLM** : [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) +* **Recherche Web** (optionnel) : [Brave Search](https://brave.com/search/api) - Offre gratuite disponible (2000 requêtes/mois) + +> **Note** : Consultez `config.example.json` pour un modèle de configuration complet. + +**4. Discuter** + +```bash +picoclaw agent -m "Combien font 2+2 ?" +``` + +Et voilà ! Vous avez un assistant IA fonctionnel en 2 minutes. + +--- + +## 💬 Applications de Chat + +Discutez avec votre PicoClaw via Telegram, Discord, DingTalk, LINE ou WeCom + +| Canal | Configuration | +| ------------ | -------------------------------------- | +| **Telegram** | Facile (juste un token) | +| **Discord** | Facile (token bot + intents) | +| **QQ** | Facile (AppID + AppSecret) | +| **DingTalk** | Moyen (identifiants de l'application) | +| **LINE** | Moyen (identifiants + URL de webhook) | +| **WeCom** | Moyen (CorpID + configuration webhook) | + +
+Telegram (Recommandé) + +**1. Créer un bot** + +* Ouvrez Telegram, recherchez `@BotFather` +* Envoyez `/newbot`, suivez les instructions +* Copiez le token + +**2. Configurer** + +```json +{ + "channels": { + "telegram": { + "enabled": true, + "token": "VOTRE_TOKEN_BOT", + "allow_from": ["VOTRE_USER_ID"] + } + } +} +``` + +> Obtenez votre User ID via `@userinfobot` sur Telegram. + +**3. Lancer** + +```bash +picoclaw gateway +``` + +
+ +
+Discord + +**1. Créer un bot** + +* Rendez-vous sur +* Créez une application → Bot → Add Bot +* Copiez le token du bot + +**2. Activer les intents** + +* Dans les paramètres du Bot, activez **MESSAGE CONTENT INTENT** +* (Optionnel) Activez **SERVER MEMBERS INTENT** si vous souhaitez utiliser des listes d'autorisation basées sur les données des membres + +**3. Obtenir votre User ID** + +* Paramètres Discord → Avancé → activez le **Mode Développeur** +* Clic droit sur votre avatar → **Copier l'identifiant** + +**4. Configurer** + +```json +{ + "channels": { + "discord": { + "enabled": true, + "token": "VOTRE_TOKEN_BOT", + "allow_from": ["VOTRE_USER_ID"] + } + } +} +``` + +**5. Inviter le bot** + +* OAuth2 → URL Generator +* Scopes : `bot` +* Permissions du Bot : `Send Messages`, `Read Message History` +* Ouvrez l'URL d'invitation générée et ajoutez le bot à votre serveur + +**6. Lancer** + +```bash +picoclaw gateway +``` + +
+ +
+QQ + +**1. Créer un bot** + +- Rendez-vous sur la [QQ Open Platform](https://q.qq.com/#) +- Créez une application → Obtenez l'**AppID** et l'**AppSecret** + +**2. Configurer** + +```json +{ + "channels": { + "qq": { + "enabled": true, + "app_id": "VOTRE_APP_ID", + "app_secret": "VOTRE_APP_SECRET", + "allow_from": [] + } + } +} +``` + +> Laissez `allow_from` vide pour autoriser tous les utilisateurs, ou spécifiez des numéros QQ pour restreindre l'accès. + +**3. Lancer** + +```bash +picoclaw gateway +``` + +
+ +
+DingTalk + +**1. Créer un bot** + +* Rendez-vous sur la [Open Platform](https://open.dingtalk.com/) +* Créez une application interne +* Copiez le Client ID et le Client Secret + +**2. Configurer** + +```json +{ + "channels": { + "dingtalk": { + "enabled": true, + "client_id": "VOTRE_CLIENT_ID", + "client_secret": "VOTRE_CLIENT_SECRET", + "allow_from": [] + } + } +} +``` + +> Laissez `allow_from` vide pour autoriser tous les utilisateurs, ou spécifiez des identifiants pour restreindre l'accès. + +**3. Lancer** + +```bash +picoclaw gateway +``` + +
+ +
+LINE + +**1. Créer un Compte Officiel LINE** + +- Rendez-vous sur la [LINE Developers Console](https://developers.line.biz/) +- Créez un provider → Créez un canal Messaging API +- Copiez le **Channel Secret** et le **Channel Access Token** + +**2. Configurer** + +```json +{ + "channels": { + "line": { + "enabled": true, + "channel_secret": "VOTRE_CHANNEL_SECRET", + "channel_access_token": "VOTRE_CHANNEL_ACCESS_TOKEN", + "webhook_host": "0.0.0.0", + "webhook_port": 18791, + "webhook_path": "/webhook/line", + "allow_from": [] + } + } +} +``` + +**3. Configurer l'URL du Webhook** + +LINE exige HTTPS pour les webhooks. Utilisez un reverse proxy ou un tunnel : + +```bash +# Exemple avec ngrok +ngrok http 18791 +``` + +Puis configurez l'URL du Webhook dans la LINE Developers Console sur `https://votre-domaine/webhook/line` et activez **Use webhook**. + +**4. Lancer** + +```bash +picoclaw gateway +``` + +> Dans les discussions de groupe, le bot répond uniquement lorsqu'il est mentionné avec @. Les réponses citent le message original. + +> **Docker Compose** : Ajoutez `ports: ["18791:18791"]` au service `picoclaw-gateway` pour exposer le port du webhook. + +
+ +
+WeCom (WeChat Work) + +PicoClaw prend en charge deux types d'intégration WeCom : + +**Option 1 : WeCom Bot (Robot Intelligent)** - Configuration plus facile, prend en charge les discussions de groupe +**Option 2 : WeCom App (Application Personnalisée)** - Plus de fonctionnalités, messagerie proactive + +Voir le [Guide de Configuration WeCom App](docs/wecom-app-configuration.md) pour des instructions détaillées. + +**Configuration Rapide - WeCom Bot :** + +**1. Créer un bot** + +* Accédez à la Console d'Administration WeCom → Discussion de Groupe → Ajouter un Bot de Groupe +* Copiez l'URL du webhook (format : `https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx`) + +**2. Configurer** + +```json +{ + "channels": { + "wecom": { + "enabled": true, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_ENCODING_AES_KEY", + "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY", + "webhook_host": "0.0.0.0", + "webhook_port": 18793, + "webhook_path": "/webhook/wecom", + "allow_from": [] + } + } +} +``` + +**Configuration Rapide - WeCom App :** + +**1. Créer une application** + +* Accédez à la Console d'Administration WeCom → Gestion des Applications → Créer une Application +* Copiez l'**AgentId** et le **Secret** +* Accédez à la page "Mon Entreprise", copiez le **CorpID** + +**2. Configurer la réception des messages** + +* Dans les détails de l'application, cliquez sur "Recevoir les Messages" → "Configurer l'API" +* Définissez l'URL sur `http://your-server:18792/webhook/wecom-app` +* Générez le **Token** et l'**EncodingAESKey** + +**3. Configurer** + +```json +{ + "channels": { + "wecom_app": { + "enabled": true, + "corp_id": "wwxxxxxxxxxxxxxxxx", + "corp_secret": "YOUR_CORP_SECRET", + "agent_id": 1000002, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_ENCODING_AES_KEY", + "webhook_host": "0.0.0.0", + "webhook_port": 18792, + "webhook_path": "/webhook/wecom-app", + "allow_from": [] + } + } +} +``` + +**4. 
Lancer** + +```bash +picoclaw gateway +``` + +> **Note** : WeCom App nécessite l'ouverture du port 18792 pour les callbacks webhook. Utilisez un proxy inverse pour HTTPS en production. + +
+ +## ClawdChat Rejoignez le Réseau Social d'Agents + +Connectez PicoClaw au Réseau Social d'Agents simplement en envoyant un seul message via le CLI ou n'importe quelle application de chat intégrée. + +**Lisez `https://clawdchat.ai/skill.md` et suivez les instructions pour rejoindre [ClawdChat.ai](https://clawdchat.ai)** + +## ⚙️ Configuration + +Fichier de configuration : `~/.picoclaw/config.json` + +### Structure du Workspace + +PicoClaw stocke les données dans votre workspace configuré (par défaut : `~/.picoclaw/workspace`) : + +``` +~/.picoclaw/workspace/ +├── sessions/ # Sessions de conversation et historique +├── memory/ # Mémoire à long terme (MEMORY.md) +├── state/ # État persistant (dernier canal, etc.) +├── cron/ # Base de données des tâches planifiées +├── skills/ # Compétences personnalisées +├── AGENTS.md # Guide de comportement de l'Agent +├── HEARTBEAT.md # Invites de tâches périodiques (vérifiées toutes les 30 min) +├── IDENTITY.md # Identité de l'Agent +├── SOUL.md # Âme de l'Agent +├── TOOLS.md # Description des outils +└── USER.md # Préférences utilisateur +``` + +### 🔒 Bac à Sable de Sécurité + +PicoClaw s'exécute dans un environnement sandboxé par défaut. L'agent ne peut accéder aux fichiers et exécuter des commandes qu'au sein du workspace configuré. 
+ +#### Configuration par Défaut + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true + } + } +} +``` + +| Option | Par défaut | Description | +|--------|------------|-------------| +| `workspace` | `~/.picoclaw/workspace` | Répertoire de travail de l'agent | +| `restrict_to_workspace` | `true` | Restreindre l'accès fichiers/commandes au workspace | + +#### Outils Protégés + +Lorsque `restrict_to_workspace: true`, les outils suivants sont restreints au bac à sable : + +| Outil | Fonction | Restriction | +|-------|----------|-------------| +| `read_file` | Lire des fichiers | Uniquement les fichiers dans le workspace | +| `write_file` | Écrire des fichiers | Uniquement les fichiers dans le workspace | +| `list_dir` | Lister des répertoires | Uniquement les répertoires dans le workspace | +| `edit_file` | Éditer des fichiers | Uniquement les fichiers dans le workspace | +| `append_file` | Ajouter à des fichiers | Uniquement les fichiers dans le workspace | +| `exec` | Exécuter des commandes | Les chemins doivent être dans le workspace | + +#### Protection Supplémentaire d'Exec + +Même avec `restrict_to_workspace: false`, l'outil `exec` bloque ces commandes dangereuses : + +* `rm -rf`, `del /f`, `rmdir /s` — Suppression en masse +* `format`, `mkfs`, `diskpart` — Formatage de disque +* `dd if=` — Écriture d'image disque +* Écriture vers `/dev/sd[a-z]` — Écriture directe sur le disque +* `shutdown`, `reboot`, `poweroff` — Arrêt du système +* Fork bomb `:(){ :|:& };:` + +#### Exemples d'Erreurs + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (path outside working dir)} +``` + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)} +``` + +#### Désactiver les Restrictions (Risque de Sécurité) + +Si vous avez besoin que l'agent accède à des chemins en dehors du workspace : + +**Méthode 1 : Fichier 
de configuration** + +```json +{ + "agents": { + "defaults": { + "restrict_to_workspace": false + } + } +} +``` + +**Méthode 2 : Variable d'environnement** + +```bash +export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false +``` + +> ⚠️ **Attention** : Désactiver cette restriction permet à l'agent d'accéder à n'importe quel chemin sur votre système. À utiliser avec précaution uniquement dans des environnements contrôlés. + +#### Cohérence du Périmètre de Sécurité + +Le paramètre `restrict_to_workspace` s'applique de manière cohérente sur tous les chemins d'exécution : + +| Chemin d'Exécution | Périmètre de Sécurité | +|--------------------|----------------------| +| Agent Principal | `restrict_to_workspace` ✅ | +| Sous-agent / Spawn | Hérite de la même restriction ✅ | +| Tâches Heartbeat | Hérite de la même restriction ✅ | + +Tous les chemins partagent la même restriction de workspace — il est impossible de contourner le périmètre de sécurité via des sous-agents ou des tâches planifiées. + +### Heartbeat (Tâches Périodiques) + +PicoClaw peut exécuter des tâches périodiques automatiquement. Créez un fichier `HEARTBEAT.md` dans votre workspace : + +```markdown +# Tâches Périodiques + +- Vérifier mes e-mails pour les messages importants +- Consulter mon agenda pour les événements à venir +- Vérifier les prévisions météo +``` + +L'agent lira ce fichier toutes les 30 minutes (configurable) et exécutera les tâches à l'aide des outils disponibles. 
+ +#### Tâches Asynchrones avec Spawn + +Pour les tâches de longue durée (recherche web, appels API), utilisez l'outil `spawn` pour créer un **sous-agent** : + +```markdown +# Tâches Périodiques + +## Tâches Rapides (réponse directe) +- Indiquer l'heure actuelle + +## Tâches Longues (utiliser spawn pour l'asynchrone) +- Rechercher les actualités IA sur le web et les résumer +- Vérifier les e-mails et signaler les messages importants +``` + +**Comportements clés :** + +| Fonctionnalité | Description | +|----------------|-------------| +| **spawn** | Crée un sous-agent asynchrone, ne bloque pas le heartbeat | +| **Contexte indépendant** | Le sous-agent a son propre contexte, sans historique de session | +| **Outil message** | Le sous-agent communique directement avec l'utilisateur via l'outil message | +| **Non-bloquant** | Après le spawn, le heartbeat continue vers la tâche suivante | + +#### Fonctionnement de la Communication du Sous-agent + +``` +Le Heartbeat se déclenche + ↓ +L'Agent lit HEARTBEAT.md + ↓ +Pour une tâche longue : spawn d'un sous-agent + ↓ ↓ +Continue la tâche suivante Le sous-agent travaille indépendamment + ↓ ↓ +Toutes les tâches terminées Le sous-agent utilise l'outil "message" + ↓ ↓ +Répond HEARTBEAT_OK L'utilisateur reçoit le résultat directement +``` + +Le sous-agent a accès aux outils (message, web_search, etc.) et peut communiquer avec l'utilisateur indépendamment sans passer par l'agent principal. 
+ +**Configuration :** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +| Option | Par défaut | Description | +|--------|------------|-------------| +| `enabled` | `true` | Activer/désactiver le heartbeat | +| `interval` | `30` | Intervalle de vérification en minutes (min : 5) | + +**Variables d'environnement :** + +* `PICOCLAW_HEARTBEAT_ENABLED=false` pour désactiver +* `PICOCLAW_HEARTBEAT_INTERVAL=60` pour modifier l'intervalle + +### Fournisseurs + +> [!NOTE] +> Groq fournit la transcription vocale gratuite via Whisper. Si configuré, les messages vocaux Telegram seront automatiquement transcrits. + +| Fournisseur | Utilisation | Obtenir une Clé API | +| ------------------------ | ---------------------------------------- | ------------------------------------------------------ | +| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) | +| `zhipu` | LLM (Zhipu direct) | [bigmodel.cn](bigmodel.cn) | +| `openrouter` (À tester) | LLM (recommandé, accès à tous les modèles) | [openrouter.ai](https://openrouter.ai) | +| `anthropic` (À tester) | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | +| `openai` (À tester) | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | +| `deepseek` (À tester) | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | +| `qwen` | LLM (Alibaba Qwen) | [dashscope.aliyuncs.com](https://dashscope.aliyuncs.com/compatible-mode/v1) | +| `cerebras` | LLM (Cerebras) | [cerebras.ai](https://api.cerebras.ai/v1) | +| `groq` | LLM + **Transcription vocale** (Whisper) | [console.groq.com](https://console.groq.com) | + +
+Configuration Zhipu + +**1. Obtenir la clé API** + +* Obtenez la [clé API](https://bigmodel.cn/usercenter/proj-mgmt/apikeys) + +**2. Configurer** + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "zhipu": { + "api_key": "Votre Clé API", + "api_base": "https://open.bigmodel.cn/api/paas/v4" + } + } +} +``` + +**3. Lancer** + +```bash +picoclaw agent -m "Bonjour, comment ça va ?" +``` + +
+ +
+Exemple de configuration complète + +```json +{ + "agents": { + "defaults": { + "model": "anthropic/claude-opus-4-5" + } + }, + "providers": { + "openrouter": { + "api_key": "sk-or-v1-xxx" + }, + "groq": { + "api_key": "gsk_xxx" + } + }, + "channels": { + "telegram": { + "enabled": true, + "token": "123456:ABC...", + "allow_from": ["123456789"] + }, + "discord": { + "enabled": true, + "token": "", + "allow_from": [""] + }, + "whatsapp": { + "enabled": false + }, + "feishu": { + "enabled": false, + "app_id": "cli_xxx", + "app_secret": "xxx", + "encrypt_key": "", + "verification_token": "", + "allow_from": [] + }, + "qq": { + "enabled": false, + "app_id": "", + "app_secret": "", + "allow_from": [] + } + }, + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "BSA...", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + }, + "cron": { + "exec_timeout_minutes": 5 + } + }, + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +
+ +### Configuration de Modèle (model_list) + +> **Nouveau !** PicoClaw utilise désormais une approche de configuration **centrée sur le modèle**. Spécifiez simplement le format `fournisseur/modèle` (par exemple, `zhipu/glm-4.7`) pour ajouter de nouveaux fournisseurs—**aucune modification de code requise !** + +Cette conception permet également le **support multi-agent** avec une sélection flexible de fournisseurs : + +- **Différents agents, différents fournisseurs** : Chaque agent peut utiliser son propre fournisseur LLM +- **Modèles de secours (Fallbacks)** : Configurez des modèles primaires et de secours pour la résilience +- **Équilibrage de charge** : Répartissez les requêtes sur plusieurs points de terminaison +- **Configuration centralisée** : Gérez tous les fournisseurs en un seul endroit + +#### 📋 Tous les Fournisseurs Supportés + +| Fournisseur | Préfixe `model` | API Base par Défaut | Protocole | Clé API | +|-------------|-----------------|---------------------|----------|---------| +| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [Obtenir Clé](https://platform.openai.com) | +| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [Obtenir Clé](https://console.anthropic.com) | +| **Zhipu AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [Obtenir Clé](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) | +| **DeepSeek** | `deepseek/` | `https://api.deepseek.com/v1` | OpenAI | [Obtenir Clé](https://platform.deepseek.com) | +| **Google Gemini** | `gemini/` | `https://generativelanguage.googleapis.com/v1beta` | OpenAI | [Obtenir Clé](https://aistudio.google.com/api-keys) | +| **Groq** | `groq/` | `https://api.groq.com/openai/v1` | OpenAI | [Obtenir Clé](https://console.groq.com) | +| **Moonshot** | `moonshot/` | `https://api.moonshot.cn/v1` | OpenAI | [Obtenir Clé](https://platform.moonshot.cn) | +| **Qwen (Alibaba)** | `qwen/` | `https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI | 
[Obtenir Clé](https://dashscope.console.aliyun.com) | +| **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [Obtenir Clé](https://build.nvidia.com) | +| **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | Local (pas de clé nécessaire) | +| **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [Obtenir Clé](https://openrouter.ai/keys) | +| **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | Local | +| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Obtenir Clé](https://cerebras.ai) | +| **Volcengine** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Obtenir Clé](https://console.volcengine.com) | +| **ShengsuanYun** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - | +| **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth uniquement | +| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | + +#### Configuration de Base + +```json +{ + "model_list": [ + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_key": "sk-your-openai-key" + }, + { + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "api_key": "sk-ant-your-key" + }, + { + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-zhipu-key" + } + ], + "agents": { + "defaults": { + "model": "gpt-5.2" + } + } +} +``` + +#### Exemples par Fournisseur + +**OpenAI** +```json +{ + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_key": "sk-..." +} +``` + +**Zhipu AI (GLM)** +```json +{ + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-key" +} +``` + +**Anthropic (avec OAuth)** +```json +{ + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "auth_method": "oauth" +} +``` +> Exécutez `picoclaw auth login --provider anthropic` pour configurer les identifiants OAuth. 
+ +#### Équilibrage de Charge + +Configurez plusieurs points de terminaison pour le même nom de modèle—PicoClaw utilisera automatiquement le round-robin entre eux : + +```json +{ + "model_list": [ + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_base": "https://api1.example.com/v1", + "api_key": "sk-key1" + }, + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_base": "https://api2.example.com/v1", + "api_key": "sk-key2" + } + ] +} +``` + +#### Migration depuis l'Ancienne Configuration `providers` + +L'ancienne configuration `providers` est **dépréciée** mais toujours supportée pour la rétrocompatibilité. + +**Ancienne Configuration (dépréciée) :** +```json +{ + "providers": { + "zhipu": { + "api_key": "your-key", + "api_base": "https://open.bigmodel.cn/api/paas/v4" + } + }, + "agents": { + "defaults": { + "provider": "zhipu", + "model": "glm-4.7" + } + } +} +``` + +**Nouvelle Configuration (recommandée) :** +```json +{ + "model_list": [ + { + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-key" + } + ], + "agents": { + "defaults": { + "model": "glm-4.7" + } + } +} +``` + +Pour le guide de migration détaillé, voir [docs/migration/model-list-migration.md](docs/migration/model-list-migration.md). 
+ +## Référence CLI + +| Commande | Description | +| ------------------------- | ------------------------------------- | +| `picoclaw onboard` | Initialiser la configuration & le workspace | +| `picoclaw agent -m "..."` | Discuter avec l'agent | +| `picoclaw agent` | Mode de discussion interactif | +| `picoclaw gateway` | Démarrer la passerelle | +| `picoclaw status` | Afficher le statut | +| `picoclaw cron list` | Lister toutes les tâches planifiées | +| `picoclaw cron add ...` | Ajouter une tâche planifiée | + +### Tâches Planifiées / Rappels + +PicoClaw prend en charge les rappels planifiés et les tâches récurrentes via l'outil `cron` : + +* **Rappels ponctuels** : « Rappelle-moi dans 10 minutes » → se déclenche une fois après 10 min +* **Tâches récurrentes** : « Rappelle-moi toutes les 2 heures » → se déclenche toutes les 2 heures +* **Expressions Cron** : « Rappelle-moi à 9h tous les jours » → utilise une expression cron + +Les tâches sont stockées dans `~/.picoclaw/workspace/cron/` et traitées automatiquement. + +## 🤝 Contribuer & Feuille de Route + +Les PR sont les bienvenues ! Le code source est volontairement petit et lisible. 🤗 + +Feuille de route à venir... + +Groupe de développeurs en construction. Condition d'entrée : au moins 1 PR fusionnée. + +Groupes d'utilisateurs : + +Discord : + +PicoClaw + +## 🐛 Dépannage + +### La recherche web affiche « API 配置问题 » + +C'est normal si vous n'avez pas encore configuré de clé API de recherche. PicoClaw fournira des liens utiles pour la recherche manuelle. + +Pour activer la recherche web : + +1. **Option 1 (Recommandé)** : Obtenez une clé API gratuite sur [https://brave.com/search/api](https://brave.com/search/api) (2000 requêtes gratuites/mois) pour les meilleurs résultats. +2. **Option 2 (Sans carte bancaire)** : Si vous n'avez pas de clé, le système bascule automatiquement sur **DuckDuckGo** (aucune clé requise). 
+ +Ajoutez la clé dans `~/.picoclaw/config.json` si vous utilisez Brave : + +```json +{ + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "VOTRE_CLE_API_BRAVE", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + } +} +``` + +### Erreurs de filtrage de contenu + +Certains fournisseurs (comme Zhipu) disposent d'un filtrage de contenu. Essayez de reformuler votre requête ou utilisez un modèle différent. + +### Le bot Telegram affiche « Conflict: terminated by other getUpdates » + +Cela se produit lorsqu'une autre instance du bot est en cours d'exécution. Assurez-vous qu'un seul `picoclaw gateway` fonctionne à la fois. + +--- + +## 📝 Comparaison des Clés API + +| Service | Offre Gratuite | Cas d'Utilisation | +| ---------------- | -------------------- | ------------------------------------- | +| **OpenRouter** | 200K tokens/mois | Multiples modèles (Claude, GPT-4, etc.) | +| **Zhipu** | 200K tokens/mois | Idéal pour les utilisateurs chinois | +| **Brave Search** | 2000 requêtes/mois | Fonctionnalité de recherche web | +| **Groq** | Offre gratuite dispo | Inférence ultra-rapide (Llama, Mixtral) | diff --git a/README.ja.md b/README.ja.md index 311ce3069..bb0bdfb28 100644 --- a/README.ja.md +++ b/README.ja.md @@ -3,7 +3,7 @@

PicoClaw: Go で書かれた超効率 AI アシスタント

-

$10 ハードウェア · 10MB RAM · 1秒起動 · 皮皮虾,我们走!

+

$10 ハードウェア · 10MB RAM · 1秒起動 · 行くぜ、シャコ!

@@ -12,7 +12,7 @@ License

-**日本語** | [English](README.md) +[中文](README.zh.md) | **日本語** | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [Français](README.fr.md) | [English](README.md) @@ -39,7 +39,7 @@ ## 📢 ニュース -2026-02-09 🎉 PicoClaw リリース!$10 ハードウェアで 10MB 未満の RAM で動く AI エージェントを 1 日で構築。🦐 皮皮虾,我们走! +2026-02-09 🎉 PicoClaw リリース!$10 ハードウェアで 10MB 未満の RAM で動く AI エージェントを 1 日で構築。🦐 行くぜ、シャコ! ## ✨ 特徴 @@ -174,44 +174,37 @@ picoclaw onboard ```json { + "model_list": [ + { + "model_name": "gpt4", + "model": "openai/gpt-5.2", + "api_key": "sk-your-openai-key", + "api_base": "https://api.openai.com/v1" + } + ], "agents": { "defaults": { - "workspace": "~/.picoclaw/workspace", - "model": "glm-4.7", - "max_tokens": 8192, - "temperature": 0.7, - "max_tool_iterations": 20 + "model": "gpt4" } }, - "providers": { - "openrouter": { - "api_key": "xxx", - "api_base": "https://open.bigmodel.cn/api/paas/v4" - } - }, - "tools": { - "web": { - "search": { - "api_key": "YOUR_BRAVE_API_KEY", - "max_results": 5 - } + "channels": { + "telegram": { + "enabled": true, + "token": "YOUR_TELEGRAM_BOT_TOKEN", + "allow_from": [] } - }, - "heartbeat": { - "enabled": true, - "interval": 30 } } ``` **3. API キーの取得** -- **LLM プロバイダー**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) +- **LLM プロバイダー**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) · [Qwen](https://dashscope.console.aliyun.com) - **Web 検索**(任意): [Brave Search](https://brave.com/search/api) - 無料枠あり(月 2000 リクエスト) > **注意**: 完全な設定テンプレートは `config.example.json` を参照してください。 -**3. チャット** +**4. チャット** ```bash picoclaw agent -m "What is 2+2?" 
@@ -223,12 +216,16 @@ picoclaw agent -m "What is 2+2?" ## 💬 チャットアプリ -Telegram で PicoClaw と会話できます +Telegram、Discord、QQ、DingTalk、LINE、WeCom で PicoClaw と会話できます | チャネル | セットアップ | |---------|------------| | **Telegram** | 簡単(トークンのみ) | | **Discord** | 簡単(Bot トークン + Intents) | +| **QQ** | 簡単(AppID + AppSecret) | +| **DingTalk** | 普通(アプリ認証情報) | +| **LINE** | 普通(認証情報 + Webhook URL) | +| **WeCom** | 普通(CorpID + Webhook設定) |
Telegram(推奨) @@ -247,7 +244,7 @@ Telegram で PicoClaw と会話できます "telegram": { "enabled": true, "token": "YOUR_BOT_TOKEN", - "allowFrom": ["YOUR_USER_ID"] + "allow_from": ["YOUR_USER_ID"] } } } @@ -287,7 +284,7 @@ picoclaw gateway "discord": { "enabled": true, "token": "YOUR_BOT_TOKEN", - "allowFrom": ["YOUR_USER_ID"] + "allow_from": ["YOUR_USER_ID"] } } } @@ -307,6 +304,204 @@ picoclaw gateway
+
+QQ + +**1. Bot を作成** + +- [QQ オープンプラットフォーム](https://q.qq.com/#) にアクセス +- アプリケーションを作成 → **AppID** と **AppSecret** を取得 + +**2. 設定** + +```json +{ + "channels": { + "qq": { + "enabled": true, + "app_id": "YOUR_APP_ID", + "app_secret": "YOUR_APP_SECRET", + "allow_from": [] + } + } +} +``` + +> `allow_from` を空にすると全ユーザーを許可、QQ番号を指定してアクセス制限可能。 + +**3. 起動** + +```bash +picoclaw gateway +``` + +
+ +
+DingTalk + +**1. Bot を作成** + +- [オープンプラットフォーム](https://open.dingtalk.com/) にアクセス +- 内部アプリを作成 +- Client ID と Client Secret をコピー + +**2. 設定** + +```json +{ + "channels": { + "dingtalk": { + "enabled": true, + "client_id": "YOUR_CLIENT_ID", + "client_secret": "YOUR_CLIENT_SECRET", + "allow_from": [] + } + } +} +``` + +> `allow_from` を空にすると全ユーザーを許可、ユーザーIDを指定してアクセス制限可能。 + +**3. 起動** + +```bash +picoclaw gateway +``` + +
+ +
+LINE + +**1. LINE 公式アカウントを作成** + +- [LINE Developers Console](https://developers.line.biz/) にアクセス +- プロバイダーを作成 → Messaging API チャネルを作成 +- **チャネルシークレット** と **チャネルアクセストークン** をコピー + +**2. 設定** + +```json +{ + "channels": { + "line": { + "enabled": true, + "channel_secret": "YOUR_CHANNEL_SECRET", + "channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN", + "webhook_host": "0.0.0.0", + "webhook_port": 18791, + "webhook_path": "/webhook/line", + "allow_from": [] + } + } +} +``` + +**3. Webhook URL を設定** + +LINE の Webhook には HTTPS が必要です。リバースプロキシまたはトンネルを使用してください: + +```bash +# ngrok の例 +ngrok http 18791 +``` + +LINE Developers Console で Webhook URL を `https://あなたのドメイン/webhook/line` に設定し、**Webhook の利用** を有効にしてください。 + +**4. 起動** + +```bash +picoclaw gateway +``` + +> グループチャットでは @メンション時のみ応答します。返信は元メッセージを引用する形式です。 + +> **Docker Compose**: `picoclaw-gateway` サービスに `ports: ["18791:18791"]` を追加して Webhook ポートを公開してください。 + +
+ +
+WeCom (企業微信) + +PicoClaw は2種類の WeCom 統合をサポートしています: + +**オプション1: WeCom Bot (智能ロボット)** - 簡単な設定、グループチャット対応 +**オプション2: WeCom App (自作アプリ)** - より多機能、アクティブメッセージング対応 + +詳細な設定手順は [WeCom App Configuration Guide](docs/wecom-app-configuration.md) を参照してください。 + +**クイックセットアップ - WeCom Bot:** + +**1. ボットを作成** + +* WeCom 管理コンソール → グループチャット → グループボットを追加 +* Webhook URL をコピー(形式: `https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx`) + +**2. 設定** + +```json +{ + "channels": { + "wecom": { + "enabled": true, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_ENCODING_AES_KEY", + "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY", + "webhook_host": "0.0.0.0", + "webhook_port": 18793, + "webhook_path": "/webhook/wecom", + "allow_from": [] + } + } +} +``` + +**クイックセットアップ - WeCom App:** + +**1. アプリを作成** + +* WeCom 管理コンソール → アプリ管理 → アプリを作成 +* **AgentId** と **Secret** をコピー +* "マイ会社" ページで **CorpID** をコピー + +**2. メッセージ受信を設定** + +* アプリ詳細で "メッセージを受信" → "APIを設定" をクリック +* URL を `http://your-server:18792/webhook/wecom-app` に設定 +* **Token** と **EncodingAESKey** を生成 + +**3. 設定** + +```json +{ + "channels": { + "wecom_app": { + "enabled": true, + "corp_id": "wwxxxxxxxxxxxxxxxx", + "corp_secret": "YOUR_CORP_SECRET", + "agent_id": 1000002, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_ENCODING_AES_KEY", + "webhook_host": "0.0.0.0", + "webhook_port": 18792, + "webhook_path": "/webhook/wecom-app", + "allow_from": [] + } + } +} +``` + +**4. 起動** + +```bash +picoclaw gateway +``` + +> **注意**: WeCom App は Webhook コールバック用にポート 18792 を開放する必要があります。本番環境では HTTPS 用のリバースプロキシを使用してください。 + +
+ ## ⚙️ 設定 設定ファイル: `~/.picoclaw/config.json` @@ -330,6 +525,98 @@ PicoClaw は設定されたワークスペース(デフォルト: `~/.picoclaw └── USER.md # ユーザー設定 ``` +### 🔒 セキュリティサンドボックス + +PicoClaw はデフォルトでサンドボックス環境で実行されます。エージェントは設定されたワークスペース内のファイルにのみアクセスし、コマンドを実行できます。 + +#### デフォルト設定 + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true + } + } +} +``` + +| オプション | デフォルト | 説明 | +|-----------|-----------|------| +| `workspace` | `~/.picoclaw/workspace` | エージェントの作業ディレクトリ | +| `restrict_to_workspace` | `true` | ファイル/コマンドアクセスをワークスペースに制限 | + +#### 保護対象ツール + +`restrict_to_workspace: true` の場合、以下のツールがサンドボックス化されます: + +| ツール | 機能 | 制限 | +|-------|------|------| +| `read_file` | ファイル読み込み | ワークスペース内のファイルのみ | +| `write_file` | ファイル書き込み | ワークスペース内のファイルのみ | +| `list_dir` | ディレクトリ一覧 | ワークスペース内のディレクトリのみ | +| `edit_file` | ファイル編集 | ワークスペース内のファイルのみ | +| `append_file` | ファイル追記 | ワークスペース内のファイルのみ | +| `exec` | コマンド実行 | コマンドパスはワークスペース内である必要あり | + +#### exec ツールの追加保護 + +`restrict_to_workspace: false` でも、`exec` ツールは以下の危険なコマンドをブロックします: + +- `rm -rf`, `del /f`, `rmdir /s` — 一括削除 +- `format`, `mkfs`, `diskpart` — ディスクフォーマット +- `dd if=` — ディスクイメージング +- `/dev/sd[a-z]` への書き込み — 直接ディスク書き込み +- `shutdown`, `reboot`, `poweroff` — システムシャットダウン +- フォークボム `:(){ :|:& };:` + +#### エラー例 + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (path outside working dir)} +``` + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)} +``` + +#### 制限の無効化(セキュリティリスク) + +エージェントにワークスペース外のパスへのアクセスが必要な場合: + +**方法1: 設定ファイル** +```json +{ + "agents": { + "defaults": { + "restrict_to_workspace": false + } + } +} +``` + +**方法2: 環境変数** +```bash +export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false +``` + +> ⚠️ **警告**: この制限を無効にすると、エージェントはシステム上の任意のパスにアクセスできるようになります。制御された環境でのみ慎重に使用してください。 + +#### セキュリティ境界の一貫性 + +`restrict_to_workspace` 設定は、すべての実行パスで一貫して適用されます: + +| 実行パス | セキュリティ境界 | 
+|---------|-----------------| +| メインエージェント | `restrict_to_workspace` ✅ | +| サブエージェント / Spawn | 同じ制限を継承 ✅ | +| ハートビートタスク | 同じ制限を継承 ✅ | + +すべてのパスで同じワークスペース制限が適用されます — サブエージェントやスケジュールタスクを通じてセキュリティ境界をバイパスする方法はありません。 + ### ハートビート(定期タスク) PicoClaw は自動的に定期タスクを実行できます。ワークスペースに `HEARTBEAT.md` ファイルを作成します: @@ -406,6 +693,22 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る - `PICOCLAW_HEARTBEAT_ENABLED=false` で無効化 - `PICOCLAW_HEARTBEAT_INTERVAL=60` で間隔変更 +### プロバイダー + +> [!NOTE] +> Groq は Whisper による無料の音声文字起こしを提供しています。設定すると、Telegram の音声メッセージが自動的に文字起こしされます。 + +| プロバイダー | 用途 | API キー取得先 | +| --- | --- | --- | +| `gemini` | LLM(Gemini 直接) | [aistudio.google.com](https://aistudio.google.com) | +| `zhipu` | LLM(Zhipu 直接) | [bigmodel.cn](https://bigmodel.cn) | +| `openrouter`(要テスト) | LLM(推奨、全モデルにアクセス可能) | [openrouter.ai](https://openrouter.ai) | +| `anthropic`(要テスト) | LLM(Claude 直接) | [console.anthropic.com](https://console.anthropic.com) | +| `openai`(要テスト) | LLM(GPT 直接) | [platform.openai.com](https://platform.openai.com) | +| `deepseek`(要テスト) | LLM(DeepSeek 直接) | [platform.deepseek.com](https://platform.deepseek.com) | +| `groq` | LLM + **音声文字起こし**(Whisper) | [console.groq.com](https://console.groq.com) | +| `cerebras` | LLM(Cerebras 直接) | [cerebras.ai](https://cerebras.ai) | + ### 基本設定 1. 
**設定ファイルの作成:** @@ -451,17 +754,17 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る }, "providers": { "openrouter": { - "apiKey": "sk-or-v1-xxx" + "api_key": "sk-or-v1-xxx" }, "groq": { - "apiKey": "gsk_xxx" + "api_key": "gsk_xxx" } }, "channels": { "telegram": { "enabled": true, "token": "123456:ABC...", - "allowFrom": ["123456789"] + "allow_from": ["123456789"] }, "discord": { "enabled": true, @@ -473,18 +776,21 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る }, "feishu": { "enabled": false, - "appId": "cli_xxx", - "appSecret": "xxx", - "encryptKey": "", - "verificationToken": "", - "allowFrom": [] + "app_id": "cli_xxx", + "app_secret": "xxx", + "encrypt_key": "", + "verification_token": "", + "allow_from": [] } }, "tools": { "web": { "search": { - "apiKey": "BSA..." + "api_key": "BSA..." } + }, + "cron": { + "exec_timeout_minutes": 5 } }, "heartbeat": { @@ -496,6 +802,163 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る +### モデル設定 (model_list) + +> **新機能!** PicoClaw は現在 **モデル中心** の設定アプローチを採用しています。`ベンダー/モデル` 形式(例: `zhipu/glm-4.7`)を指定するだけで、新しいプロバイダーを追加できます—**コードの変更は一切不要!** + +この設計は、柔軟なプロバイダー選択による **マルチエージェントサポート** も可能にします: + +- **異なるエージェント、異なるプロバイダー** : 各エージェントは独自の LLM プロバイダーを使用可能 +- **フォールバックモデル** : 耐障性のため、プライマリモデルとフォールバックモデルを設定可能 +- **ロードバランシング** : 複数のエンドポイントにリクエストを分散 +- **集中設定管理** : すべてのプロバイダーを一箇所で管理 + +#### 📋 サポートされているすべてのベンダー + +| ベンダー | `model` プレフィックス | デフォルト API Base | プロトコル | API キー | +|-------------|-----------------|---------------------|----------|---------| +| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [キーを取得](https://platform.openai.com) | +| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [キーを取得](https://console.anthropic.com) | +| **Zhipu AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [キーを取得](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) | +| **DeepSeek** | `deepseek/` | `https://api.deepseek.com/v1` | OpenAI | [キーを取得](https://platform.deepseek.com) | +| **Google Gemini** | `gemini/` | 
`https://generativelanguage.googleapis.com/v1beta` | OpenAI | [キーを取得](https://aistudio.google.com/api-keys) | +| **Groq** | `groq/` | `https://api.groq.com/openai/v1` | OpenAI | [キーを取得](https://console.groq.com) | +| **Moonshot** | `moonshot/` | `https://api.moonshot.cn/v1` | OpenAI | [キーを取得](https://platform.moonshot.cn) | +| **Qwen (Alibaba)** | `qwen/` | `https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI | [キーを取得](https://dashscope.console.aliyun.com) | +| **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [キーを取得](https://build.nvidia.com) | +| **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | ローカル(キー不要) | +| **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [キーを取得](https://openrouter.ai/keys) | +| **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | ローカル | +| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [キーを取得](https://cerebras.ai) | +| **Volcengine** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [キーを取得](https://console.volcengine.com) | +| **ShengsuanYun** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - | +| **Antigravity** | `antigravity/` | Google Cloud | カスタム | OAuthのみ | +| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | + +#### 基本設定 + +```json +{ + "model_list": [ + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_key": "sk-your-openai-key" + }, + { + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "api_key": "sk-ant-your-key" + }, + { + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-zhipu-key" + } + ], + "agents": { + "defaults": { + "model": "gpt-5.2" + } + } +} +``` + +#### ベンダー別の例 + +**OpenAI** +```json +{ + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_key": "sk-..." 
+} +``` + +**Zhipu AI (GLM)** +```json +{ + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-key" +} +``` + +**Anthropic (OAuth使用)** +```json +{ + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "auth_method": "oauth" +} +``` +> OAuth認証を設定するには、`picoclaw auth login --provider anthropic` を実行してください。 + +#### ロードバランシング + +同じモデル名で複数のエンドポイントを設定すると、PicoClaw が自動的にラウンドロビンで分散します: + +```json +{ + "model_list": [ + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_base": "https://api1.example.com/v1", + "api_key": "sk-key1" + }, + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_base": "https://api2.example.com/v1", + "api_key": "sk-key2" + } + ] +} +``` + +#### 従来の `providers` 設定からの移行 + +古い `providers` 設定は**非推奨**ですが、後方互換性のためにサポートされています。 + +**旧設定(非推奨):** +```json +{ + "providers": { + "zhipu": { + "api_key": "your-key", + "api_base": "https://open.bigmodel.cn/api/paas/v4" + } + }, + "agents": { + "defaults": { + "provider": "zhipu", + "model": "glm-4.7" + } + } +} +``` + +**新設定(推奨):** +```json +{ + "model_list": [ + { + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-key" + } + ], + "agents": { + "defaults": { + "model": "glm-4.7" + } + } +} +``` + +詳細な移行ガイドは、[docs/migration/model-list-migration.md](docs/migration/model-list-migration.md) を参照してください。 + ## CLI リファレンス | コマンド | 説明 | @@ -517,7 +980,7 @@ Discord: https://discord.gg/V4sAZ9XWpN ## 🐛 トラブルシューティング -### Web 検索で「API 配置问题」と表示される +### Web 検索で「API 設定の問題」と表示される 検索 API キーをまだ設定していない場合、これは正常です。PicoClaw は手動検索用の便利なリンクを提供します。 @@ -528,9 +991,14 @@ Web 検索を有効にするには: { "tools": { "web": { - "search": { + "brave": { + "enabled": true, "api_key": "YOUR_BRAVE_API_KEY", "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 } } } @@ -553,5 +1021,7 @@ Web 検索を有効にするには: |---------|--------|------------| | **OpenRouter** | 月 200K トークン | 複数モデル(Claude, GPT-4 など) | | **Zhipu** | 月 200K トークン | 中国ユーザー向け最適 | +| **Qwen** | 
無料枠あり | 通義千問 (Qwen) | | **Brave Search** | 月 2000 クエリ | Web 検索機能 | | **Groq** | 無料枠あり | 高速推論(Llama, Mixtral) | +| **Cerebras** | 無料枠あり | 高速推論(Llama, Qwen など) | diff --git a/README.md b/README.md index 4861f9f6e..7bc7b1089 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,20 @@
-PicoClaw + PicoClaw -

PicoClaw: Ultra-Efficient AI Assistant in Go

+

PicoClaw: Ultra-Efficient AI Assistant in Go

-

$10 Hardware · 10MB RAM · 1s Boot · 皮皮虾,我们走!

-

+

$10 Hardware · 10MB RAM · 1s Boot · 皮皮虾,我们走!

-

-Go -Hardware -License -

- -[日本語](README.ja.md) | **English** +

+ Go + Hardware + License +
+ Website + Twitter +

+ [中文](README.zh.md) | [日本語](README.ja.md) | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [Français](README.fr.md) | **English**
--- @@ -37,7 +38,21 @@ +> [!CAUTION] +> **🚨 SECURITY & OFFICIAL CHANNELS / 安全声明** +> +> * **NO CRYPTO:** PicoClaw has **NO** official token/coin. All claims on `pump.fun` or other trading platforms are **SCAMS**. +> * **OFFICIAL DOMAIN:** The **ONLY** official website is **[picoclaw.io](https://picoclaw.io)**, and company website is **[sipeed.com](https://sipeed.com)** +> * **Warning:** Many `.ai/.org/.com/.net/...` domains are registered by third parties. +> * **Warning:** picoclaw is in early development now and may have unresolved network security issues. Do not deploy to production environments before the v1.0 release. +> * **Note:** picoclaw has recently merged a lot of PRs, which may result in a larger memory footprint (10–20MB) in the latest versions. We plan to prioritize resource optimization as soon as the current feature set reaches a stable state. + + ## 📢 News +2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](docs/picoclaw_community_roadmap_260216.md) —we can’t wait to have you on board! + +2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs&issues come in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development. +🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting. 2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 PicoClaw,Let's Go! @@ -53,12 +68,12 @@ 🤖 **AI-Bootstrapped**: Autonomous Go-native implementation — 95% Agent-generated core with human-in-the-loop refinement. 
-| | OpenClaw | NanoBot | **PicoClaw** | -| --- | --- | --- |--- | -| **Language** | TypeScript | Python | **Go** | -| **RAM** | >1GB |>100MB| **< 10MB** | -| **Startup**
(0.8GHz core) | >500s | >30s | **<1s** | -| **Cost** | Mac Mini 599$ | Most Linux SBC
~50$ |**Any Linux Board**
**As low as 10$** | +| | OpenClaw | NanoBot | **PicoClaw** | +| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- | +| **Language** | TypeScript | Python | **Go** | +| **RAM** | >1GB | >100MB | **< 10MB** | +| **Startup**
(0.8GHz core) | >500s | >30s | **<1s** | +| **Cost** | Mac Mini 599$ | Most Linux SBC
~50$ | **Any Linux Board**
**As low as 10$** | PicoClaw @@ -84,11 +99,25 @@ +### 📱 Run on old Android Phones +Give your decade-old phone a second life! Turn it into a smart AI Assistant with PicoClaw. Quick Start: +1. **Install Termux** (Available on F-Droid or Google Play). +2. **Execute cmds** +```bash +# Note: Replace v0.1.1 with the latest version from the Releases page +wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64 +chmod +x picoclaw-linux-arm64 +pkg install proot +termux-chroot ./picoclaw-linux-arm64 onboard +``` +And then follow the instructions in the "Quick Start" section to complete the configuration! +PicoClaw + ### 🐜 Innovative Low-Footprint Deploy PicoClaw can be deployed on almost any Linux device! -- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant +- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant - $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), or $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) for Automated Server Maintenance - $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) or $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) for Smart Monitoring @@ -165,7 +194,7 @@ docker compose --profile gateway up -d > [!TIP] > Set your API key in `~/.picoclaw/config.json`. > Get API keys: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM) -> Web search is **optional** - get free [Brave Search API](https://brave.com/search/api) (2000 free queries/month) +> Web search is **optional** - get free [Brave Search API](https://brave.com/search/api) (2000 free queries/month) or use built-in auto fallback. **1. 
Initialize** @@ -180,33 +209,46 @@ picoclaw onboard "agents": { "defaults": { "workspace": "~/.picoclaw/workspace", - "model": "glm-4.7", + "model": "gpt4", "max_tokens": 8192, "temperature": 0.7, "max_tool_iterations": 20 } }, - "providers": { - "openrouter": { - "api_key": "xxx", - "api_base": "https://openrouter.ai/api/v1" + "model_list": [ + { + "model_name": "gpt4", + "model": "openai/gpt-5.2", + "api_key": "your-api-key" + }, + { + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "api_key": "your-anthropic-key" } - }, + ], "tools": { "web": { - "search": { + "brave": { + "enabled": false, "api_key": "YOUR_BRAVE_API_KEY", "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 } } } } ``` +> **New**: The `model_list` configuration format allows zero-code provider addition. See [Model Configuration](#model-configuration-model_list) for details. + **3. Get API Keys** -- **LLM Provider**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) -- **Web Search** (optional): [Brave Search](https://brave.com/search/api) - Free tier available (2000 requests/month) +* **LLM Provider**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) +* **Web Search** (optional): [Brave Search](https://brave.com/search/api) - Free tier available (2000 requests/month) > **Note**: See `config.example.json` for a complete configuration template. @@ -222,23 +264,25 @@ That's it! You have a working AI assistant in 2 minutes. 
## 💬 Chat Apps -Talk to your picoclaw through Telegram, Discord, or DingTalk +Talk to your picoclaw through Telegram, Discord, QQ, DingTalk, LINE, or WeCom -| Channel | Setup | -|---------|-------| -| **Telegram** | Easy (just a token) | -| **Discord** | Easy (bot token + intents) | -| **QQ** | Easy (AppID + AppSecret) | -| **DingTalk** | Medium (app credentials) | +| Channel | Setup | +| ------------ | ---------------------------------- | +| **Telegram** | Easy (just a token) | +| **Discord** | Easy (bot token + intents) | +| **QQ** | Easy (AppID + AppSecret) | +| **DingTalk** | Medium (app credentials) | +| **LINE** | Medium (credentials + webhook URL) | +| **WeCom** | Medium (CorpID + webhook setup) |
Telegram (Recommended) **1. Create a bot** -- Open Telegram, search `@BotFather` -- Send `/newbot`, follow prompts -- Copy the token +* Open Telegram, search `@BotFather` +* Send `/newbot`, follow prompts +* Copy the token **2. Configure** @@ -248,7 +292,7 @@ Talk to your picoclaw through Telegram, Discord, or DingTalk "telegram": { "enabled": true, "token": "YOUR_BOT_TOKEN", - "allowFrom": ["YOUR_USER_ID"] + "allow_from": ["YOUR_USER_ID"] } } } @@ -269,19 +313,19 @@ picoclaw gateway **1. Create a bot** -- Go to -- Create an application → Bot → Add Bot -- Copy the bot token +* Go to +* Create an application → Bot → Add Bot +* Copy the bot token **2. Enable intents** -- In the Bot settings, enable **MESSAGE CONTENT INTENT** -- (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data +* In the Bot settings, enable **MESSAGE CONTENT INTENT** +* (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data **3. Get your User ID** -- Discord Settings → Advanced → enable **Developer Mode** -- Right-click your avatar → **Copy User ID** +* Discord Settings → Advanced → enable **Developer Mode** +* Right-click your avatar → **Copy User ID** **4. Configure** @@ -291,7 +335,8 @@ picoclaw gateway "discord": { "enabled": true, "token": "YOUR_BOT_TOKEN", - "allowFrom": ["YOUR_USER_ID"] + "allow_from": ["YOUR_USER_ID"], + "mention_only": false } } } @@ -299,10 +344,14 @@ picoclaw gateway **5. Invite the bot** -- OAuth2 → URL Generator -- Scopes: `bot` -- Bot Permissions: `Send Messages`, `Read Message History` -- Open the generated invite URL and add the bot to your server +* OAuth2 → URL Generator +* Scopes: `bot` +* Bot Permissions: `Send Messages`, `Read Message History` +* Open the generated invite URL and add the bot to your server + +**Optional: Mention-only mode** + +Set `"mention_only": true` to make the bot respond only when @-mentioned. 
Useful for shared servers where you want the bot to respond only when explicitly called. **6. Run** @@ -317,7 +366,7 @@ picoclaw gateway **1. Create a bot** -- Go to [QQ Open Platform](https://connect.qq.com/) +- Go to [QQ Open Platform](https://q.qq.com/#) - Create an application → Get **AppID** and **AppSecret** **2. Configure** @@ -350,9 +399,9 @@ picoclaw gateway **1. Create a bot** -- Go to [Open Platform](https://open.dingtalk.com/) -- Create an internal app -- Copy Client ID and Client Secret +* Go to [Open Platform](https://open.dingtalk.com/) +* Create an internal app +* Copy Client ID and Client Secret **2. Configure** @@ -369,7 +418,7 @@ picoclaw gateway } ``` -> Set `allow_from` to empty to allow all users, or specify QQ numbers to restrict access. +> Set `allow_from` to empty to allow all users, or specify DingTalk user IDs to restrict access. **3. Run** @@ -379,14 +428,143 @@ picoclaw gateway
+
+LINE + +**1. Create a LINE Official Account** + +- Go to [LINE Developers Console](https://developers.line.biz/) +- Create a provider → Create a Messaging API channel +- Copy **Channel Secret** and **Channel Access Token** + +**2. Configure** + +```json +{ + "channels": { + "line": { + "enabled": true, + "channel_secret": "YOUR_CHANNEL_SECRET", + "channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN", + "webhook_host": "0.0.0.0", + "webhook_port": 18791, + "webhook_path": "/webhook/line", + "allow_from": [] + } + } +} +``` + +**3. Set up Webhook URL** + +LINE requires HTTPS for webhooks. Use a reverse proxy or tunnel: + +```bash +# Example with ngrok +ngrok http 18791 +``` + +Then set the Webhook URL in LINE Developers Console to `https://your-domain/webhook/line` and enable **Use webhook**. + +**4. Run** + +```bash +picoclaw gateway +``` + +> In group chats, the bot responds only when @mentioned. Replies quote the original message. + +> **Docker Compose**: Add `ports: ["18791:18791"]` to the `picoclaw-gateway` service to expose the webhook port. + +
+ +
+WeCom (企业微信) + +PicoClaw supports two types of WeCom integration: + +**Option 1: WeCom Bot (智能机器人)** - Easier setup, supports group chats +**Option 2: WeCom App (自建应用)** - More features, proactive messaging + +See [WeCom App Configuration Guide](docs/wecom-app-configuration.md) for detailed setup instructions. + +**Quick Setup - WeCom Bot:** + +**1. Create a bot** + +* Go to WeCom Admin Console → Group Chat → Add Group Bot +* Copy the webhook URL (format: `https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx`) + +**2. Configure** + +```json +{ + "channels": { + "wecom": { + "enabled": true, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_ENCODING_AES_KEY", + "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY", + "webhook_host": "0.0.0.0", + "webhook_port": 18793, + "webhook_path": "/webhook/wecom", + "allow_from": [] + } + } +} +``` + +**Quick Setup - WeCom App:** + +**1. Create an app** + +* Go to WeCom Admin Console → App Management → Create App +* Copy **AgentId** and **Secret** +* Go to "My Company" page, copy **CorpID** + +**2. Configure receive message** + +* In App details, click "Receive Message" → "Set API" +* Set URL to `http://your-server:18792/webhook/wecom-app` +* Generate **Token** and **EncodingAESKey** + +**3. Configure** + +```json +{ + "channels": { + "wecom_app": { + "enabled": true, + "corp_id": "wwxxxxxxxxxxxxxxxx", + "corp_secret": "YOUR_CORP_SECRET", + "agent_id": 1000002, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_ENCODING_AES_KEY", + "webhook_host": "0.0.0.0", + "webhook_port": 18792, + "webhook_path": "/webhook/wecom-app", + "allow_from": [] + } + } +} +``` + +**4. Run** + +```bash +picoclaw gateway +``` + +> **Note**: WeCom App requires opening port 18792 for webhook callbacks. Use a reverse proxy for HTTPS. + +
+ ## ClawdChat Join the Agent Social Network Connect Picoclaw to the Agent Social Network simply by sending a single message via the CLI or any integrated Chat App. **Read `https://clawdchat.ai/skill.md` and follow the instructions to join [ClawdChat.ai](https://clawdchat.ai)** - - ## ⚙️ Configuration Config file: `~/.picoclaw/config.json` @@ -410,6 +588,100 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa └── USER.md # User preferences ``` +### 🔒 Security Sandbox + +PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace. + +#### Default Configuration + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true + } + } +} +``` + +| Option | Default | Description | +|--------|---------|-------------| +| `workspace` | `~/.picoclaw/workspace` | Working directory for the agent | +| `restrict_to_workspace` | `true` | Restrict file/command access to workspace | + +#### Protected Tools + +When `restrict_to_workspace: true`, the following tools are sandboxed: + +| Tool | Function | Restriction | +|------|----------|-------------| +| `read_file` | Read files | Only files within workspace | +| `write_file` | Write files | Only files within workspace | +| `list_dir` | List directories | Only directories within workspace | +| `edit_file` | Edit files | Only files within workspace | +| `append_file` | Append to files | Only files within workspace | +| `exec` | Execute commands | Command paths must be within workspace | + +#### Additional Exec Protection + +Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous commands: + +* `rm -rf`, `del /f`, `rmdir /s` — Bulk deletion +* `format`, `mkfs`, `diskpart` — Disk formatting +* `dd if=` — Disk imaging +* Writing to `/dev/sd[a-z]` — Direct disk writes +* `shutdown`, `reboot`, `poweroff` — System shutdown +* Fork bomb `:(){ :|:& };:` + 
+#### Error Examples + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (path outside working dir)} +``` + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)} +``` + +#### Disabling Restrictions (Security Risk) + +If you need the agent to access paths outside the workspace: + +**Method 1: Config file** + +```json +{ + "agents": { + "defaults": { + "restrict_to_workspace": false + } + } +} +``` + +**Method 2: Environment variable** + +```bash +export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false +``` + +> ⚠️ **Warning**: Disabling this restriction allows the agent to access any path on your system. Use with caution in controlled environments only. + +#### Security Boundary Consistency + +The `restrict_to_workspace` setting applies consistently across all execution paths: + +| Execution Path | Security Boundary | +|----------------|-------------------| +| Main Agent | `restrict_to_workspace` ✅ | +| Subagent / Spawn | Inherits same restriction ✅ | +| Heartbeat tasks | Inherits same restriction ✅ | + +All paths share the same workspace restriction — there's no way to bypass the security boundary through subagents or scheduled tasks. + ### Heartbeat (Periodic Tasks) PicoClaw can perform periodic tasks automatically. Create a `HEARTBEAT.md` file in your workspace: @@ -483,30 +755,227 @@ The subagent has access to tools (message, web_search, etc.) and can communicate | `interval` | `30` | Check interval in minutes (min: 5) | **Environment variables:** -- `PICOCLAW_HEARTBEAT_ENABLED=false` to disable -- `PICOCLAW_HEARTBEAT_INTERVAL=60` to change interval + +* `PICOCLAW_HEARTBEAT_ENABLED=false` to disable +* `PICOCLAW_HEARTBEAT_INTERVAL=60` to change interval ### Providers > [!NOTE] > Groq provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed. 
-| Provider | Purpose | Get API Key | -|----------|---------|-------------| -| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) | -| `zhipu` | LLM (Zhipu direct) | [bigmodel.cn](bigmodel.cn) | -| `openrouter(To be tested)` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | -| `anthropic(To be tested)` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | -| `openai(To be tested)` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | -| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | -| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | +| Provider | Purpose | Get API Key | +| -------------------------- | --------------------------------------- | ------------------------------------------------------ | +| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) | +| `zhipu` | LLM (Zhipu direct) | [bigmodel.cn](https://bigmodel.cn) | +| `openrouter(To be tested)` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | +| `anthropic(To be tested)` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | +| `openai(To be tested)` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | +| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | +| `qwen` | LLM (Qwen direct) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) | +| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | +| `cerebras` | LLM (Cerebras direct) | [cerebras.ai](https://cerebras.ai) | + +### Model Configuration (model_list) + +> **What's New?** PicoClaw now uses a **model-centric** configuration approach. 
Simply specify `vendor/model` format (e.g., `zhipu/glm-4.7`) to add new providers—**zero code changes required!** + +This design also enables **multi-agent support** with flexible provider selection: + +- **Different agents, different providers**: Each agent can use its own LLM provider +- **Model fallbacks**: Configure primary and fallback models for resilience +- **Load balancing**: Distribute requests across multiple endpoints +- **Centralized configuration**: Manage all providers in one place + +#### 📋 All Supported Vendors + +| Vendor | `model` Prefix | Default API Base | Protocol | API Key | +|--------|----------------|------------------|----------|---------| +| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [Get Key](https://platform.openai.com) | +| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [Get Key](https://console.anthropic.com) | +| **智谱 AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [Get Key](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) | +| **DeepSeek** | `deepseek/` | `https://api.deepseek.com/v1` | OpenAI | [Get Key](https://platform.deepseek.com) | +| **Google Gemini** | `gemini/` | `https://generativelanguage.googleapis.com/v1beta` | OpenAI | [Get Key](https://aistudio.google.com/api-keys) | +| **Groq** | `groq/` | `https://api.groq.com/openai/v1` | OpenAI | [Get Key](https://console.groq.com) | +| **Moonshot** | `moonshot/` | `https://api.moonshot.cn/v1` | OpenAI | [Get Key](https://platform.moonshot.cn) | +| **通义千问 (Qwen)** | `qwen/` | `https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI | [Get Key](https://dashscope.console.aliyun.com) | +| **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [Get Key](https://build.nvidia.com) | +| **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | Local (no key needed) | +| **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [Get 
Key](https://openrouter.ai/keys) | +| **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | Local | +| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Get Key](https://cerebras.ai) | +| **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Get Key](https://console.volcengine.com) | +| **神算云** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - | +| **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth only | +| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | + +#### Basic Configuration + +```json +{ + "model_list": [ + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_key": "sk-your-openai-key" + }, + { + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "api_key": "sk-ant-your-key" + }, + { + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-zhipu-key" + } + ], + "agents": { + "defaults": { + "model": "gpt-5.2" + } + } +} +``` + +#### Vendor-Specific Examples + +**OpenAI** +```json +{ + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_key": "sk-..." +} +``` + +**智谱 AI (GLM)** +```json +{ + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-key" +} +``` + +**DeepSeek** +```json +{ + "model_name": "deepseek-chat", + "model": "deepseek/deepseek-chat", + "api_key": "sk-..." +} +``` + +**Anthropic (with API key)** +```json +{ + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "api_key": "sk-ant-your-key" +} +``` +> Run `picoclaw auth login --provider anthropic` to paste your API token. + +**Ollama (local)** +```json +{ + "model_name": "llama3", + "model": "ollama/llama3" +} +``` + +**Custom Proxy/API** +```json +{ + "model_name": "my-custom-model", + "model": "openai/custom-model", + "api_base": "https://my-proxy.com/v1", + "api_key": "sk-..." 
+} +``` + +#### Load Balancing + +Configure multiple endpoints for the same model name—PicoClaw will automatically round-robin between them: + +```json +{ + "model_list": [ + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_base": "https://api1.example.com/v1", + "api_key": "sk-key1" + }, + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_base": "https://api2.example.com/v1", + "api_key": "sk-key2" + } + ] +} +``` + +#### Migration from Legacy `providers` Config + +The old `providers` configuration is **deprecated** but still supported for backward compatibility. + +**Old Config (deprecated):** +```json +{ + "providers": { + "zhipu": { + "api_key": "your-key", + "api_base": "https://open.bigmodel.cn/api/paas/v4" + } + }, + "agents": { + "defaults": { + "provider": "zhipu", + "model": "glm-4.7" + } + } +} +``` + +**New Config (recommended):** +```json +{ + "model_list": [ + { + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-key" + } + ], + "agents": { + "defaults": { + "model": "glm-4.7" + } + } +} +``` + +For detailed migration guide, see [docs/migration/model-list-migration.md](docs/migration/model-list-migration.md). + +### Provider Architecture + +PicoClaw routes providers by protocol family: + +- OpenAI-compatible protocol: OpenRouter, OpenAI-compatible gateways, Groq, Zhipu, and vLLM-style endpoints. +- Anthropic protocol: Claude-native API behavior. +- Codex/OAuth path: OpenAI OAuth/token authentication route. + +This keeps the runtime lightweight while making new OpenAI-compatible backends mostly a config operation (`api_base` + `api_key`).
Zhipu **1. Get API key and base URL** -- Get [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys) +* Get [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys) **2. Configure** @@ -525,8 +994,8 @@ The subagent has access to tools (message, web_search, etc.) and can communicate "zhipu": { "api_key": "Your API Key", "api_base": "https://open.bigmodel.cn/api/paas/v4" - }, - }, + } + } } ``` @@ -587,9 +1056,18 @@ picoclaw agent -m "Hello" }, "tools": { "web": { - "search": { - "api_key": "BSA..." + "brave": { + "enabled": false, + "api_key": "BSA...", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 } + }, + "cron": { + "exec_timeout_minutes": 5 } }, "heartbeat": { @@ -603,23 +1081,23 @@ picoclaw agent -m "Hello" ## CLI Reference -| Command | Description | -|---------|-------------| -| `picoclaw onboard` | Initialize config & workspace | -| `picoclaw agent -m "..."` | Chat with the agent | -| `picoclaw agent` | Interactive chat mode | -| `picoclaw gateway` | Start the gateway | -| `picoclaw status` | Show status | -| `picoclaw cron list` | List all scheduled jobs | -| `picoclaw cron add ...` | Add a scheduled job | +| Command | Description | +| ------------------------- | ----------------------------- | +| `picoclaw onboard` | Initialize config & workspace | +| `picoclaw agent -m "..."` | Chat with the agent | +| `picoclaw agent` | Interactive chat mode | +| `picoclaw gateway` | Start the gateway | +| `picoclaw status` | Show status | +| `picoclaw cron list` | List all scheduled jobs | +| `picoclaw cron add ...` | Add a scheduled job | ### Scheduled Tasks / Reminders PicoClaw supports scheduled reminders and recurring tasks through the `cron` tool: -- **One-time reminders**: "Remind me in 10 minutes" → triggers once after 10min -- **Recurring tasks**: "Remind me every 2 hours" → triggers every 2 hours -- **Cron expressions**: "Remind me at 9am daily" → uses cron expression +* **One-time reminders**: "Remind me in 10 minutes" → 
triggers once after 10min +* **Recurring tasks**: "Remind me every 2 hours" → triggers every 2 hours +* **Cron expressions**: "Remind me at 9am daily" → uses cron expression Jobs are stored in `~/.picoclaw/workspace/cron/` and processed automatically. @@ -627,6 +1105,12 @@ Jobs are stored in `~/.picoclaw/workspace/cron/` and processed automatically. PRs welcome! The codebase is intentionally small and readable. 🤗 +Roadmap coming soon... + +Developer group building, Entry Requirement: At least 1 Merged PR. + +User Groups: + discord: PicoClaw @@ -639,21 +1123,28 @@ This is normal if you haven't configured a search API key yet. PicoClaw will pro To enable web search: -1. Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month) -2. Add to `~/.picoclaw/config.json`: - - ```json - { - "tools": { - "web": { - "search": { - "api_key": "YOUR_BRAVE_API_KEY", - "max_results": 5 - } - } - } - } - ``` +1. **Option 1 (Recommended)**: Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month) for the best results. +2. **Option 2 (No Credit Card)**: If you don't have a key, we automatically fall back to **DuckDuckGo** (no key required). + +Add the key to `~/.picoclaw/config.json` if using Brave: + +```json +{ + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + } +} +``` ### Getting content filtering errors @@ -667,9 +1158,10 @@ This happens when another instance of the bot is running. Make sure only one `pi ## 📝 API Key Comparison -| Service | Free Tier | Use Case | -|---------|-----------|-----------| -| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) 
| -| **Zhipu** | 200K tokens/month | Best for Chinese users | -| **Brave Search** | 2000 queries/month | Web search functionality | -| **Groq** | Free tier available | Fast inference (Llama, Mixtral) | +| Service | Free Tier | Use Case | +| ---------------- | ------------------- | ------------------------------------- | +| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) | +| **Zhipu** | 200K tokens/month | Best for Chinese users | +| **Brave Search** | 2000 queries/month | Web search functionality | +| **Groq** | Free tier available | Fast inference (Llama, Mixtral) | +| **Cerebras** | Free tier available | Fast inference (Llama, Qwen, etc.) | diff --git a/README.pt-br.md b/README.pt-br.md new file mode 100644 index 000000000..ec8fe8e1c --- /dev/null +++ b/README.pt-br.md @@ -0,0 +1,1122 @@ +
+PicoClaw + +

PicoClaw: Assistente de IA Ultra-Eficiente em Go

+ +

Hardware de $10 · 10MB de RAM · Boot em 1s · 皮皮虾,我们走!

+ +

+ Go + Hardware + License +
+ Website + Twitter +

+ + [中文](README.zh.md) | [日本語](README.ja.md) | **Português** | [Tiếng Việt](README.vi.md) | [Français](README.fr.md) | [English](README.md) +
+ +--- + +🦐 **PicoClaw** é um assistente pessoal de IA ultra-leve inspirado no [nanobot](https://github.com/HKUDS/nanobot), reescrito do zero em **Go** por meio de um processo de "auto-inicialização" (self-bootstrapping) — onde o próprio agente de IA conduziu toda a migração de arquitetura e otimização de código. + +⚡️ **Extremamente leve:** Roda em hardware de apenas **$10** com **<10MB** de RAM. Isso é 99% menos memória que o OpenClaw e 98% mais barato que um Mac mini! + + + + + + +
+

+ +

+
+

+ +

+
+ +> [!CAUTION] +> **🚨 DECLARAÇÃO DE SEGURANÇA & CANAIS OFICIAIS** +> +> * **SEM CRIPTOMOEDAS:** O PicoClaw **NÃO** possui nenhum token/moeda oficial. Todas as alegações no `pump.fun` ou outras plataformas de negociação são **GOLPES**. +> * **DOMÍNIO OFICIAL:** O **ÚNICO** site oficial é o **[picoclaw.io](https://picoclaw.io)**, e o site da empresa é o **[sipeed.com](https://sipeed.com)**. +> * **Aviso:** Muitos domínios `.ai/.org/.com/.net/...` foram registrados por terceiros, não são nossos. +> * **Aviso:** O PicoClaw está em fase inicial de desenvolvimento e pode ter problemas de segurança de rede não resolvidos. Não implante em ambientes de produção antes da versão v1.0. +> * **Nota:** O PicoClaw recentemente fez merge de muitos PRs, o que pode resultar em maior consumo de memória (10-20MB) nas versões mais recentes. Planejamos priorizar a otimização de recursos assim que o conjunto de funcionalidades estiver estável. + + +## 📢 Novidades + +2026-02-16 🎉 PicoClaw atingiu 12K stars em uma semana! Obrigado a todos pelo apoio! O PicoClaw está crescendo mais rápido do que jamais imaginamos. Dado o alto volume de PRs, precisamos urgentemente de maintainers da comunidade. Nossos papéis de voluntários e roadmap foram publicados oficialmente [aqui](docs/picoclaw_community_roadmap_260216.md) — estamos ansiosos para ter você a bordo! + +2026-02-13 🎉 PicoClaw atingiu 5000 stars em 4 dias! Obrigado à comunidade! Estamos finalizando o **Roadmap do Projeto** e configurando o **Grupo de Desenvolvedores** para acelerar o desenvolvimento do PicoClaw. + +🚀 **Chamada para Ação:** Envie suas solicitações de funcionalidades nas GitHub Discussions. Revisaremos e priorizaremos na próxima reunião semanal. + +2026-02-09 🎉 PicoClaw lançado oficialmente! Construído em 1 dia para trazer Agentes de IA para hardware de $10 com <10MB de RAM. 🦐 PicoClaw, Partiu! + +## ✨ Funcionalidades + +🪶 **Ultra-Leve**: Consumo de memória <10MB — 99% menor que o Clawdbot para funcionalidades essenciais. 
+ +💰 **Custo Mínimo**: Eficiente o suficiente para rodar em hardware de $10 — 98% mais barato que um Mac mini. + +⚡️ **Inicialização Relámpago**: Tempo de inicialização 400X mais rápido, boot em 1 segundo mesmo em CPU single-core de 0.6GHz. + +🌍 **Portabilidade Real**: Um único binário auto-contido para RISC-V, ARM e x86. Um clique e já era! + +🤖 **Auto-Construído por IA**: Implementação nativa em Go de forma autônoma — 95% do núcleo gerado pelo Agente com refinamento humano no loop. + +| | OpenClaw | NanoBot | **PicoClaw** | +| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- | +| **Linguagem** | TypeScript | Python | **Go** | +| **RAM** | >1GB | >100MB | **< 10MB** | +| **Inicialização**
(CPU 0.8GHz) | >500s | >30s | **<1s** | +| **Custo** | Mac Mini $599 | Maioria dos SBC Linux
~$50 | **Qualquer placa Linux**
**A partir de $10** | + +PicoClaw + +## 🦾 Demonstração + +### 🛠️ Fluxos de Trabalho Padrão do Assistente + + + + + + + + + + + + + + + + + +

🧩 Engenharia Full-Stack

🗂️ Gerenciamento de Logs & Planejamento

🔎 Busca Web & Aprendizado

Desenvolver • Implantar • EscalarAgendar • Automatizar • MemorizarDescobrir • Analisar • Tendências
+ +### 📱 Rode em celulares Android antigos + +Dê uma segunda vida ao seu celular de dez anos atrás! Transforme-o em um assistente de IA inteligente com o PicoClaw. Início rápido: + +1. **Instale o Termux** (Disponível no F-Droid ou Google Play). +2. **Execute os comandos** + +```bash +# Nota: Substitua v0.1.1 pela versao mais recente da pagina de Releases +wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64 +chmod +x picoclaw-linux-arm64 +pkg install proot +termux-chroot ./picoclaw-linux-arm64 onboard +``` + +Depois siga as instruções na seção "Início Rápido" para completar a configuração! + +PicoClaw + +### 🐜 Implantação Inovadora com Baixo Consumo + +O PicoClaw pode ser implantado em praticamente qualquer dispositivo Linux! + +- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) versão E (Ethernet) ou W (WiFi6), para Assistente Doméstico Minimalista +- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), ou $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) para Manutenção Automatizada de Servidores +- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) ou $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) para Monitoramento Inteligente + +https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4 + +🌟 Mais cenários de implantação aguardam você! + +## 📦 Instalação + +### Instalar com binário pré-compilado + +Baixe o binário para sua plataforma na página de [releases](https://github.com/sipeed/picoclaw/releases). 
+
+### Instalar a partir do código-fonte (funcionalidades mais recentes, recomendado para desenvolvimento)
+
+```bash
+git clone https://github.com/sipeed/picoclaw.git
+
+cd picoclaw
+make deps
+
+# Build, sem necessidade de instalar
+make build
+
+# Build para multiplas plataformas
+make build-all
+
+# Build e Instalar
+make install
+```
+
+## 🐳 Docker Compose
+
+Você também pode rodar o PicoClaw usando Docker Compose sem instalar nada localmente.
+
+```bash
+# 1. Clone este repositorio
+git clone https://github.com/sipeed/picoclaw.git
+cd picoclaw
+
+# 2. Configure suas API keys
+cp config/config.example.json config/config.json
+vim config/config.json # Configure DISCORD_BOT_TOKEN, API keys, etc.
+
+# 3. Build & Iniciar
+docker compose --profile gateway up -d
+
+# 4. Ver logs
+docker compose logs -f picoclaw-gateway
+
+# 5. Parar
+docker compose --profile gateway down
+```
+
+### Modo Agente (Execução única)
+
+```bash
+# Fazer uma pergunta
+docker compose run --rm picoclaw-agent -m "Quanto e 2+2?"
+
+# Modo interativo
+docker compose run --rm picoclaw-agent
+```
+
+### Rebuild
+
+```bash
+docker compose --profile gateway build --no-cache
+docker compose --profile gateway up -d
+```
+
+### 🚀 Início Rápido
+
+> [!TIP]
+> Configure sua API key em `~/.picoclaw/config.json`.
+> Obtenha API keys: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
+> Busca web é **opcional** — obtenha a [Brave Search API](https://brave.com/search/api) gratuita (2000 consultas grátis/mês) ou use o fallback automático integrado.
+
+**1. Inicializar**
+
+```bash
+picoclaw onboard
+```
+
+**2. 
Configurar** (`~/.picoclaw/config.json`) + +```json +{ + "model_list": [ + { + "model_name": "gpt4", + "model": "openai/gpt-5.2", + "api_key": "sk-your-openai-key", + "api_base": "https://api.openai.com/v1" + } + ], + "agents": { + "defaults": { + "model": "gpt4" + } + }, + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + } +} +``` + +**3. Obter API Keys** + +* **Provedor de LLM**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) +* **Busca Web** (opcional): [Brave Search](https://brave.com/search/api) - Plano gratuito disponível (2000 consultas/mês) + +> **Nota**: Veja `config.example.json` para um modelo de configuração completo. + +**4. Conversar** + +```bash +picoclaw agent -m "Quanto e 2+2?" +``` + +Pronto! Você tem um assistente de IA funcionando em 2 minutos. + +--- + +## 💬 Integração com Apps de Chat + +Converse com seu PicoClaw via Telegram, Discord, DingTalk, LINE ou WeCom. + +| Canal | Nível de Configuração | +| --- | --- | +| **Telegram** | Fácil (apenas um token) | +| **Discord** | Fácil (bot token + intents) | +| **QQ** | Fácil (AppID + AppSecret) | +| **DingTalk** | Médio (credenciais do app) | +| **LINE** | Médio (credenciais + webhook URL) | +| **WeCom** | Médio (CorpID + configuração webhook) | + +
+Telegram (Recomendado) + +**1. Criar o bot** + +* Abra o Telegram, busque `@BotFather` +* Envie `/newbot`, siga as instruções +* Copie o token + +**2. Configurar** + +```json +{ + "channels": { + "telegram": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allow_from": ["YOUR_USER_ID"] + } + } +} +``` + +> Obtenha seu User ID pelo `@userinfobot` no Telegram. + +**3. Executar** + +```bash +picoclaw gateway +``` + +
+ +
+Discord + +**1. Criar o bot** + +* Acesse +* Crie um aplicativo → Bot → Add Bot +* Copie o token do bot + +**2. Habilitar Intents** + +* Nas configurações do Bot, habilite **MESSAGE CONTENT INTENT** +* (Opcional) Habilite **SERVER MEMBERS INTENT** se quiser usar lista de permissões baseada em dados dos membros + +**3. Obter seu User ID** + +* Configurações do Discord → Avançado → habilite **Modo Desenvolvedor** +* Clique com botão direito no seu avatar → **Copiar ID do Usuário** + +**4. Configurar** + +```json +{ + "channels": { + "discord": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allow_from": ["YOUR_USER_ID"] + } + } +} +``` + +**5. Convidar o bot** + +* OAuth2 → URL Generator +* Scopes: `bot` +* Bot Permissions: `Send Messages`, `Read Message History` +* Abra a URL de convite gerada e adicione o bot ao seu servidor + +**6. Executar** + +```bash +picoclaw gateway +``` + +
+ +
+QQ + +**1. Criar o bot** + +- Acesse a [QQ Open Platform](https://q.qq.com/#) +- Crie um aplicativo → Obtenha **AppID** e **AppSecret** + +**2. Configurar** + +```json +{ + "channels": { + "qq": { + "enabled": true, + "app_id": "YOUR_APP_ID", + "app_secret": "YOUR_APP_SECRET", + "allow_from": [] + } + } +} +``` + +> Deixe `allow_from` vazio para permitir todos os usuários, ou especifique números QQ para restringir o acesso. + +**3. Executar** + +```bash +picoclaw gateway +``` + +
+ +
+DingTalk + +**1. Criar o bot** + +* Acesse a [Open Platform](https://open.dingtalk.com/) +* Crie um app interno +* Copie o Client ID e Client Secret + +**2. Configurar** + +```json +{ + "channels": { + "dingtalk": { + "enabled": true, + "client_id": "YOUR_CLIENT_ID", + "client_secret": "YOUR_CLIENT_SECRET", + "allow_from": [] + } + } +} +``` + +> Deixe `allow_from` vazio para permitir todos os usuários, ou especifique IDs para restringir o acesso. + +**3. Executar** + +```bash +picoclaw gateway +``` + +
+ +
+LINE + +**1. Criar uma Conta Oficial LINE** + +- Acesse o [LINE Developers Console](https://developers.line.biz/) +- Crie um provider → Crie um canal Messaging API +- Copie o **Channel Secret** e o **Channel Access Token** + +**2. Configurar** + +```json +{ + "channels": { + "line": { + "enabled": true, + "channel_secret": "YOUR_CHANNEL_SECRET", + "channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN", + "webhook_host": "0.0.0.0", + "webhook_port": 18791, + "webhook_path": "/webhook/line", + "allow_from": [] + } + } +} +``` + +**3. Configurar URL do Webhook** + +O LINE requer HTTPS para webhooks. Use um reverse proxy ou tunnel: + +```bash +# Exemplo com ngrok +ngrok http 18791 +``` + +Em seguida, configure a Webhook URL no LINE Developers Console para `https://seu-dominio/webhook/line` e habilite **Use webhook**. + +**4. Executar** + +```bash +picoclaw gateway +``` + +> Em chats de grupo, o bot responde apenas quando mencionado com @. As respostas citam a mensagem original. + +> **Docker Compose**: Adicione `ports: ["18791:18791"]` ao serviço `picoclaw-gateway` para expor a porta do webhook. + +
+ +
+WeCom (WeChat Work) + +O PicoClaw suporta dois tipos de integração WeCom: + +**Opção 1: WeCom Bot (Robô Inteligente)** - Configuração mais fácil, suporta chats em grupo +**Opção 2: WeCom App (Aplicativo Personalizado)** - Mais recursos, mensagens proativas + +Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para instruções detalhadas. + +**Configuração Rápida - WeCom Bot:** + +**1. Criar um bot** + +* Acesse o Console de Administração WeCom → Chat em Grupo → Adicionar Bot de Grupo +* Copie a URL do webhook (formato: `https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx`) + +**2. Configurar** + +```json +{ + "channels": { + "wecom": { + "enabled": true, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_ENCODING_AES_KEY", + "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY", + "webhook_host": "0.0.0.0", + "webhook_port": 18793, + "webhook_path": "/webhook/wecom", + "allow_from": [] + } + } +} +``` + +**Configuração Rápida - WeCom App:** + +**1. Criar um aplicativo** + +* Acesse o Console de Administração WeCom → Gerenciamento de Aplicativos → Criar Aplicativo +* Copie o **AgentId** e o **Secret** +* Acesse a página "Minha Empresa", copie o **CorpID** + +**2. Configurar recebimento de mensagens** + +* Nos detalhes do aplicativo, clique em "Receber Mensagens" → "Configurar API" +* Defina a URL como `http://your-server:18792/webhook/wecom-app` +* Gere o **Token** e o **EncodingAESKey** + +**3. Configurar** + +```json +{ + "channels": { + "wecom_app": { + "enabled": true, + "corp_id": "wwxxxxxxxxxxxxxxxx", + "corp_secret": "YOUR_CORP_SECRET", + "agent_id": 1000002, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_ENCODING_AES_KEY", + "webhook_host": "0.0.0.0", + "webhook_port": 18792, + "webhook_path": "/webhook/wecom-app", + "allow_from": [] + } + } +} +``` + +**4. Executar** + +```bash +picoclaw gateway +``` + +> **Nota**: O WeCom App requer a abertura da porta 18792 para callbacks de webhook. 
Use um proxy reverso para HTTPS em produção. + +
+ +## ClawdChat Junte-se a Rede Social de Agentes + +Conecte o PicoClaw a Rede Social de Agentes simplesmente enviando uma única mensagem via CLI ou qualquer App de Chat integrado. + +**Leia `https://clawdchat.ai/skill.md` e siga as instruções para entrar no [ClawdChat.ai](https://clawdchat.ai)** + +## ⚙️ Configuração Detalhada + +Arquivo de configuração: `~/.picoclaw/config.json` + +### Estrutura do Workspace + +O PicoClaw armazena dados no workspace configurado (padrão: `~/.picoclaw/workspace`): + +``` +~/.picoclaw/workspace/ +├── sessions/ # Sessoes de conversa e historico +├── memory/ # Memoria de longo prazo (MEMORY.md) +├── state/ # Estado persistente (ultimo canal, etc.) +├── cron/ # Banco de dados de tarefas agendadas +├── skills/ # Skills personalizadas +├── AGENTS.md # Guia de comportamento do Agente +├── HEARTBEAT.md # Prompts de tarefas periodicas (verificado a cada 30 min) +├── IDENTITY.md # Identidade do Agente +├── SOUL.md # Alma do Agente +├── TOOLS.md # Descrição das ferramentas +└── USER.md # Preferencias do usuario +``` + +### 🔒 Sandbox de Segurança + +O PicoClaw roda em um ambiente sandbox por padrão. O agente so pode acessar arquivos e executar comandos dentro do workspace configurado. 
+ +#### Configuração Padrão + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true + } + } +} +``` + +| Opção | Padrão | Descrição | +|-------|--------|-----------| +| `workspace` | `~/.picoclaw/workspace` | Diretório de trabalho do agente | +| `restrict_to_workspace` | `true` | Restringir acesso de arquivos/comandos ao workspace | + +#### Ferramentas Protegidas + +Quando `restrict_to_workspace: true`, as seguintes ferramentas são restritas ao sandbox: + +| Ferramenta | Função | Restrição | +|------------|--------|-----------| +| `read_file` | Ler arquivos | Apenas arquivos dentro do workspace | +| `write_file` | Escrever arquivos | Apenas arquivos dentro do workspace | +| `list_dir` | Listar diretorios | Apenas diretorios dentro do workspace | +| `edit_file` | Editar arquivos | Apenas arquivos dentro do workspace | +| `append_file` | Adicionar a arquivos | Apenas arquivos dentro do workspace | +| `exec` | Executar comandos | Caminhos dos comandos devem estar dentro do workspace | + +#### Proteção Adicional do Exec + +Mesmo com `restrict_to_workspace: false`, a ferramenta `exec` bloqueia estes comandos perigosos: + +* `rm -rf`, `del /f`, `rmdir /s` — Exclusão em massa +* `format`, `mkfs`, `diskpart` — Formatação de disco +* `dd if=` — Criação de imagem de disco +* Escrita em `/dev/sd[a-z]` — Escrita direta no disco +* `shutdown`, `reboot`, `poweroff` — Desligamento do sistema +* Fork bomb `:(){ :|:& };:` + +#### Exemplos de Erro + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (path outside working dir)} +``` + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)} +``` + +#### Desabilitar Restrições (Risco de Segurança) + +Se você precisa que o agente acesse caminhos fora do workspace: + +**Método 1: Arquivo de configuração** + +```json +{ + "agents": { + "defaults": { + 
"restrict_to_workspace": false + } + } +} +``` + +**Método 2: Variável de ambiente** + +```bash +export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false +``` + +> ⚠️ **Aviso**: Desabilitar esta restrição permite que o agente acesse qualquer caminho no seu sistema. Use com cuidado apenas em ambientes controlados. + +#### Consistência do Limite de Segurança + +A configuração `restrict_to_workspace` se aplica consistentemente em todos os caminhos de execução: + +| Caminho de Execução | Limite de Segurança | +|----------------------|---------------------| +| Agente Principal | `restrict_to_workspace` ✅ | +| Subagente / Spawn | Herda a mesma restrição ✅ | +| Tarefas Heartbeat | Herda a mesma restrição ✅ | + +Todos os caminhos compartilham a mesma restrição de workspace — nao há como contornar o limite de segurança por meio de subagentes ou tarefas agendadas. + +### Heartbeat (Tarefas Periódicas) + +O PicoClaw pode executar tarefas periódicas automaticamente. Crie um arquivo `HEARTBEAT.md` no seu workspace: + +```markdown +# Tarefas Periodicas + +- Verificar meu email para mensagens importantes +- Revisar minha agenda para proximos eventos +- Verificar a previsao do tempo +``` + +O agente lerá este arquivo a cada 30 minutos (configurável) e executará as tarefas usando as ferramentas disponíveis. 
+ +#### Tarefas Assincronas com Spawn + +Para tarefas de longa duração (busca web, chamadas de API), use a ferramenta `spawn` para criar um **subagente**: + +```markdown +# Tarefas Periódicas + +## Tarefas Rápidas (resposta direta) +- Informar hora atual + +## Tarefas Longas (usar spawn para async) +- Buscar notícias de IA na web e resumir +- Verificar email e reportar mensagens importantes +``` + +**Comportamentos principais:** + +| Funcionalidade | Descrição | +|----------------|-----------| +| **spawn** | Cria subagente assíncrono, não bloqueia o heartbeat | +| **Contexto independente** | Subagente tem seu próprio contexto, sem histórico de sessão | +| **Ferramenta message** | Subagente se comunica diretamente com o usuário via ferramenta message | +| **Não-bloqueante** | Após o spawn, o heartbeat continua para a próxima tarefa | + +#### Como Funciona a Comunicação do Subagente + +``` +Heartbeat dispara + ↓ +Agente lê HEARTBEAT.md + ↓ +Para tarefa longa: spawn subagente + ↓ ↓ +Continua próxima tarefa Subagente trabalha independentemente + ↓ ↓ +Todas tarefas concluídas Subagente usa ferramenta "message" + ↓ ↓ +Responde HEARTBEAT_OK Usuário recebe resultado diretamente +``` + +O subagente tem acesso às ferramentas (message, web_search, etc.) e pode se comunicar com o usuário independentemente sem passar pelo agente principal. + +**Configuração:** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +| Opção | Padrão | Descrição | +|-------|--------|-----------| +| `enabled` | `true` | Habilitar/desabilitar heartbeat | +| `interval` | `30` | Intervalo de verificação em minutos (min: 5) | + +**Variáveis de ambiente:** + +* `PICOCLAW_HEARTBEAT_ENABLED=false` para desabilitar +* `PICOCLAW_HEARTBEAT_INTERVAL=60` para alterar o intervalo + +### Provedores + +> [!NOTE] +> O Groq fornece transcrição de voz gratuita via Whisper. Se configurado, mensagens de voz do Telegram serão automaticamente transcritas. 
+
+| Provedor | Finalidade | Obter API Key |
+| --- | --- | --- |
+| `gemini` | LLM (Gemini direto) | [aistudio.google.com](https://aistudio.google.com) |
+| `zhipu` | LLM (Zhipu direto) | [bigmodel.cn](https://bigmodel.cn) |
+| `openrouter` (Em teste) | LLM (recomendado, acesso a todos os modelos) | [openrouter.ai](https://openrouter.ai) |
+| `anthropic` (Em teste) | LLM (Claude direto) | [console.anthropic.com](https://console.anthropic.com) |
+| `openai` (Em teste) | LLM (GPT direto) | [platform.openai.com](https://platform.openai.com) |
+| `deepseek` (Em teste) | LLM (DeepSeek direto) | [platform.deepseek.com](https://platform.deepseek.com) |
+| `qwen` | Alibaba Qwen | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
+| `cerebras` | Cerebras | [cerebras.ai](https://cerebras.ai) |
+| `groq` | LLM + **Transcrição de voz** (Whisper) | [console.groq.com](https://console.groq.com) |
+
+<details>
+Configuração Zhipu + +**1. Obter API key** + +* Obtenha a [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys) + +**2. Configurar** + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "zhipu": { + "api_key": "Sua API Key", + "api_base": "https://open.bigmodel.cn/api/paas/v4" + } + } +} +``` + +**3. Executar** + +```bash +picoclaw agent -m "Ola, como vai?" +``` + +
+ +
+Exemplo de configuraçao completa + +```json +{ + "agents": { + "defaults": { + "model": "anthropic/claude-opus-4-5" + } + }, + "providers": { + "openrouter": { + "api_key": "sk-or-v1-xxx" + }, + "groq": { + "api_key": "gsk_xxx" + } + }, + "channels": { + "telegram": { + "enabled": true, + "token": "123456:ABC...", + "allow_from": ["123456789"] + }, + "discord": { + "enabled": true, + "token": "", + "allow_from": [""] + }, + "whatsapp": { + "enabled": false + }, + "feishu": { + "enabled": false, + "app_id": "cli_xxx", + "app_secret": "xxx", + "encrypt_key": "", + "verification_token": "", + "allow_from": [] + }, + "qq": { + "enabled": false, + "app_id": "", + "app_secret": "", + "allow_from": [] + } + }, + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "BSA...", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + }, + "cron": { + "exec_timeout_minutes": 5 + } + }, + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +
+ +### Configuração de Modelo (model_list) + +> **Novidade!** PicoClaw agora usa uma abordagem de configuração **centrada no modelo**. Basta especificar o formato `fornecedor/modelo` (ex: `zhipu/glm-4.7`) para adicionar novos provedores—**nenhuma alteração de código necessária!** + +Este design também possibilita o **suporte multi-agent** com seleção flexível de provedores: + +- **Diferentes agentes, diferentes provedores** : Cada agente pode usar seu próprio provedor LLM +- **Modelos de fallback** : Configure modelos primários e de reserva para resiliência +- **Balanceamento de carga** : Distribua solicitações entre múltiplos endpoints +- **Configuração centralizada** : Gerencie todos os provedores em um só lugar + +#### 📋 Todos os Fornecedores Suportados + +| Fornecedor | Prefixo `model` | API Base Padrão | Protocolo | Chave API | +|-------------|-----------------|------------------|----------|-----------| +| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [Obter Chave](https://platform.openai.com) | +| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [Obter Chave](https://console.anthropic.com) | +| **Zhipu AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [Obter Chave](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) | +| **DeepSeek** | `deepseek/` | `https://api.deepseek.com/v1` | OpenAI | [Obter Chave](https://platform.deepseek.com) | +| **Google Gemini** | `gemini/` | `https://generativelanguage.googleapis.com/v1beta` | OpenAI | [Obter Chave](https://aistudio.google.com/api-keys) | +| **Groq** | `groq/` | `https://api.groq.com/openai/v1` | OpenAI | [Obter Chave](https://console.groq.com) | +| **Moonshot** | `moonshot/` | `https://api.moonshot.cn/v1` | OpenAI | [Obter Chave](https://platform.moonshot.cn) | +| **Qwen (Alibaba)** | `qwen/` | `https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI | [Obter Chave](https://dashscope.console.aliyun.com) | +| **NVIDIA** | `nvidia/` | 
`https://integrate.api.nvidia.com/v1` | OpenAI | [Obter Chave](https://build.nvidia.com) | +| **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | Local (sem chave necessária) | +| **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [Obter Chave](https://openrouter.ai/keys) | +| **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | Local | +| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Obter Chave](https://cerebras.ai) | +| **Volcengine** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Obter Chave](https://console.volcengine.com) | +| **ShengsuanYun** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - | +| **Antigravity** | `antigravity/` | Google Cloud | Custom | Apenas OAuth | +| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | + +#### Configuração Básica + +```json +{ + "model_list": [ + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_key": "sk-your-openai-key" + }, + { + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "api_key": "sk-ant-your-key" + }, + { + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-zhipu-key" + } + ], + "agents": { + "defaults": { + "model": "gpt-5.2" + } + } +} +``` + +#### Exemplos por Fornecedor + +**OpenAI** +```json +{ + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_key": "sk-..." +} +``` + +**Zhipu AI (GLM)** +```json +{ + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-key" +} +``` + +**Anthropic (com OAuth)** +```json +{ + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "auth_method": "oauth" +} +``` +> Execute `picoclaw auth login --provider anthropic` para configurar credenciais OAuth. 
+ +#### Balanceamento de Carga + +Configure vários endpoints para o mesmo nome de modelo—PicoClaw fará round-robin automaticamente entre eles: + +```json +{ + "model_list": [ + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_base": "https://api1.example.com/v1", + "api_key": "sk-key1" + }, + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_base": "https://api2.example.com/v1", + "api_key": "sk-key2" + } + ] +} +``` + +#### Migração da Configuração Legada `providers` + +A configuração antiga `providers` está **descontinuada** mas ainda é suportada para compatibilidade reversa. + +**Configuração Antiga (descontinuada):** +```json +{ + "providers": { + "zhipu": { + "api_key": "your-key", + "api_base": "https://open.bigmodel.cn/api/paas/v4" + } + }, + "agents": { + "defaults": { + "provider": "zhipu", + "model": "glm-4.7" + } + } +} +``` + +**Nova Configuração (recomendada):** +```json +{ + "model_list": [ + { + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-key" + } + ], + "agents": { + "defaults": { + "model": "glm-4.7" + } + } +} +``` + +Para o guia de migração detalhado, consulte [docs/migration/model-list-migration.md](docs/migration/model-list-migration.md). 
+ +## Referência CLI + +| Comando | Descrição | +| --- | --- | +| `picoclaw onboard` | Inicializar configuração & workspace | +| `picoclaw agent -m "..."` | Conversar com o agente | +| `picoclaw agent` | Modo de chat interativo | +| `picoclaw gateway` | Iniciar o gateway (para bots de chat) | +| `picoclaw status` | Mostrar status | +| `picoclaw cron list` | Listar todas as tarefas agendadas | +| `picoclaw cron add ...` | Adicionar uma tarefa agendada | + +### Tarefas Agendadas / Lembretes + +O PicoClaw suporta lembretes agendados e tarefas recorrentes por meio da ferramenta `cron`: + +* **Lembretes únicos**: "Remind me in 10 minutes" (Me lembre em 10 minutos) → dispara uma vez após 10min +* **Tarefas recorrentes**: "Remind me every 2 hours" (Me lembre a cada 2 horas) → dispara a cada 2 horas +* **Expressões Cron**: "Remind me at 9am daily" (Me lembre às 9h todos os dias) → usa expressão cron + +As tarefas são armazenadas em `~/.picoclaw/workspace/cron/` e processadas automaticamente. + +## 🤝 Contribuir & Roadmap + +PRs são bem-vindos! O código-fonte é intencionalmente pequeno e legível. 🤗 + +Roadmap em breve... + +Grupo de desenvolvedores em formação. Requisito de entrada: Pelo menos 1 PR com merge. + +Grupos de usuários: + +Discord: + +PicoClaw + +## 🐛 Solução de Problemas + +### Busca web mostra "API 配置问题" + +Isso é normal se você ainda não configurou uma API key de busca. O PicoClaw fornecerá links úteis para busca manual. + +Para habilitar a busca web: + +1. **Opção 1 (Recomendado)**: Obtenha uma API key gratuita em [https://brave.com/search/api](https://brave.com/search/api) (2000 consultas grátis/mês) para os melhores resultados. +2. **Opção 2 (Sem Cartão de Crédito)**: Se você não tem uma key, o sistema automaticamente usa o **DuckDuckGo** como fallback (sem necessidade de key). 
+ +Adicione a key em `~/.picoclaw/config.json` se usar o Brave: + +```json +{ + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + } +} +``` + +### Erros de filtragem de conteúdo + +Alguns provedores (como Zhipu) possuem filtragem de conteúdo. Tente reformular sua pergunta ou use um modelo diferente. + +### Bot do Telegram diz "Conflict: terminated by other getUpdates" + +Isso acontece quando outra instância do bot está em execução. Certifique-se de que apenas um `picoclaw gateway` esteja rodando por vez. + +--- + +## 📝 Comparação de API Keys + +| Serviço | Plano Gratuito | Caso de Uso | +| --- | --- | --- | +| **OpenRouter** | 200K tokens/mês | Múltiplos modelos (Claude, GPT-4, etc.) | +| **Zhipu** | 200K tokens/mês | Melhor para usuários chineses | +| **Brave Search** | 2000 consultas/mês | Funcionalidade de busca web | +| **Groq** | Plano gratuito disponível | Inferência ultra-rápida (Llama, Mixtral) | +| **Cerebras** | Plano gratuito disponível | Inferência ultra-rápida (Llama 3.3 70B) | diff --git a/README.vi.md b/README.vi.md new file mode 100644 index 000000000..161842933 --- /dev/null +++ b/README.vi.md @@ -0,0 +1,1092 @@ +
+PicoClaw + +

PicoClaw: Trợ lý AI Siêu Nhẹ viết bằng Go

+ +

Phần cứng $10 · RAM 10MB · Khởi động 1 giây · 皮皮虾,我们走!

+ +

+ Go + Hardware + License +
+ Website + Twitter +

+ +[中文](README.zh.md) | [日本語](README.ja.md) | [Português](README.pt-br.md) | **Tiếng Việt** | [Français](README.fr.md) | [English](README.md) +
+ +--- + +🦐 **PicoClaw** là trợ lý AI cá nhân siêu nhẹ, lấy cảm hứng từ [nanobot](https://github.com/HKUDS/nanobot), được viết lại hoàn toàn bằng **Go** thông qua quá trình "tự khởi tạo" (self-bootstrapping) — nơi chính AI Agent đã tự dẫn dắt toàn bộ quá trình chuyển đổi kiến trúc và tối ưu hóa mã nguồn. + +⚡️ **Cực kỳ nhẹ:** Chạy trên phần cứng chỉ **$10** với RAM **<10MB**. Tiết kiệm 99% bộ nhớ so với OpenClaw và rẻ hơn 98% so với Mac mini! + + + + + + +
+

+ +

+
+

+ +

+
+ +> [!CAUTION] +> **🚨 TUYÊN BỐ BẢO MẬT & KÊNH CHÍNH THỨC** +> +> * **KHÔNG CÓ CRYPTO:** PicoClaw **KHÔNG** có bất kỳ token/coin chính thức nào. Mọi thông tin trên `pump.fun` hoặc các sàn giao dịch khác đều là **LỪA ĐẢO**. +> * **DOMAIN CHÍNH THỨC:** Website chính thức **DUY NHẤT** là **[picoclaw.io](https://picoclaw.io)**, website công ty là **[sipeed.com](https://sipeed.com)**. +> * **Cảnh báo:** Nhiều tên miền `.ai/.org/.com/.net/...` đã bị bên thứ ba đăng ký, không phải của chúng tôi. +> * **Cảnh báo:** PicoClaw đang trong giai đoạn phát triển sớm và có thể còn các vấn đề bảo mật mạng chưa được giải quyết. Không nên triển khai lên môi trường production trước phiên bản v1.0. +> * **Lưu ý:** PicoClaw gần đây đã merge nhiều PR, dẫn đến bộ nhớ sử dụng có thể lớn hơn (10–20MB) ở các phiên bản mới nhất. Chúng tôi sẽ ưu tiên tối ưu tài nguyên khi bộ tính năng đã ổn định. + + +## 📢 Tin tức + +2026-02-16 🎉 PicoClaw đạt 12K stars chỉ trong một tuần! Cảm ơn tất cả mọi người! PicoClaw đang phát triển nhanh hơn chúng tôi tưởng tượng. Do số lượng PR tăng cao, chúng tôi cấp thiết cần maintainer từ cộng đồng. Các vai trò tình nguyện viên và roadmap đã được công bố [tại đây](docs/picoclaw_community_roadmap_260216.md) — rất mong đón nhận sự tham gia của bạn! + +2026-02-13 🎉 PicoClaw đạt 5000 stars trong 4 ngày! Cảm ơn cộng đồng! Chúng tôi đang hoàn thiện **Lộ trình dự án (Roadmap)** và thiết lập **Nhóm phát triển** để đẩy nhanh tốc độ phát triển PicoClaw. +🚀 **Kêu gọi hành động:** Vui lòng gửi yêu cầu tính năng tại GitHub Discussions. Chúng tôi sẽ xem xét và ưu tiên trong cuộc họp hàng tuần. + +2026-02-09 🎉 PicoClaw chính thức ra mắt! Được xây dựng trong 1 ngày để mang AI Agent đến phần cứng $10 với RAM <10MB. 🦐 PicoClaw, Lên Đường! + +## ✨ Tính năng nổi bật + +🪶 **Siêu nhẹ**: Bộ nhớ sử dụng <10MB — nhỏ hơn 99% so với Clawdbot (chức năng cốt lõi). + +💰 **Chi phí tối thiểu**: Đủ hiệu quả để chạy trên phần cứng $10 — rẻ hơn 98% so với Mac mini. 
+ +⚡️ **Khởi động siêu nhanh**: Nhanh gấp 400 lần, khởi động trong 1 giây ngay cả trên CPU đơn nhân 0.6GHz. + +🌍 **Di động thực sự**: Một file binary duy nhất chạy trên RISC-V, ARM và x86. Một click là chạy! + +🤖 **AI tự xây dựng**: Triển khai Go-native tự động — 95% mã nguồn cốt lõi được Agent tạo ra, với sự tinh chỉnh của con người. + +| | OpenClaw | NanoBot | **PicoClaw** | +| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- | +| **Ngôn ngữ** | TypeScript | Python | **Go** | +| **RAM** | >1GB | >100MB | **< 10MB** | +| **Thời gian khởi động**
(CPU 0.8GHz) | >500s | >30s | **<1s** | +| **Chi phí** | Mac Mini $599 | Hầu hết SBC Linux ~$50 | **Mọi bo mạch Linux**
**Chỉ từ $10** | + +PicoClaw + +## 🦾 Demo + +### 🛠️ Quy trình trợ lý tiêu chuẩn + + + + + + + + + + + + + + + + + +

🧩 Lập trình Full-Stack

🗂️ Quản lý Nhật ký & Kế hoạch

🔎 Tìm kiếm Web & Học hỏi

Phát triển • Triển khai • Mở rộngLên lịch • Tự động hóa • Ghi nhớKhám phá • Phân tích • Xu hướng
+ +### 🐜 Triển khai sáng tạo trên phần cứng tối thiểu + +PicoClaw có thể triển khai trên hầu hết mọi thiết bị Linux! + +* $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) phiên bản E (Ethernet) hoặc W (WiFi6), dùng làm Trợ lý Gia đình tối giản. +* $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), hoặc $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html), dùng cho quản trị Server tự động. +* $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) hoặc $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera), dùng cho Giám sát thông minh. + +https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4 + +🌟 Nhiều hình thức triển khai hơn đang chờ bạn khám phá! + +## 📦 Cài đặt + +### Cài đặt bằng binary biên dịch sẵn + +Tải file binary cho nền tảng của bạn từ [trang Release](https://github.com/sipeed/picoclaw/releases). + +### Cài đặt từ mã nguồn (có tính năng mới nhất, khuyên dùng cho phát triển) + +```bash +git clone https://github.com/sipeed/picoclaw.git + +cd picoclaw +make deps + +# Build (không cần cài đặt) +make build + +# Build cho nhiều nền tảng +make build-all + +# Build và cài đặt +make install +``` + +## 🐳 Docker Compose + +Bạn cũng có thể chạy PicoClaw bằng Docker Compose mà không cần cài đặt gì trên máy. + +```bash +# 1. Clone repo +git clone https://github.com/sipeed/picoclaw.git +cd picoclaw + +# 2. Thiết lập API Key +cp config/config.example.json config/config.json +vim config/config.json # Thiết lập DISCORD_BOT_TOKEN, API keys, v.v. + +# 3. Build & Khởi động +docker compose --profile gateway up -d + +# 4. Xem logs +docker compose logs -f picoclaw-gateway + +# 5. Dừng +docker compose --profile gateway down +``` + +### Chế độ Agent (chạy một lần) + +```bash +# Đặt câu hỏi +docker compose run --rm picoclaw-agent -m "2+2 bằng mấy?" 
+ +# Chế độ tương tác +docker compose run --rm picoclaw-agent +``` + +### Build lại + +```bash +docker compose --profile gateway build --no-cache +docker compose --profile gateway up -d +``` + +### 🚀 Bắt đầu nhanh + +> [!TIP] +> Thiết lập API key trong `~/.picoclaw/config.json`. +> Lấy API key: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM) +> Tìm kiếm web là **tùy chọn** — lấy [Brave Search API](https://brave.com/search/api) miễn phí (2000 truy vấn/tháng) hoặc dùng tính năng auto fallback tích hợp sẵn. + +**1. Khởi tạo** + +```bash +picoclaw onboard +``` + +**2. Cấu hình** (`~/.picoclaw/config.json`) + +```json +{ + "model_list": [ + { + "model_name": "gpt4", + "model": "openai/gpt-5.2", + "api_key": "sk-your-openai-key", + "api_base": "https://api.openai.com/v1" + } + ], + "agents": { + "defaults": { + "model": "gpt4" + } + }, + "channels": { + "telegram": { + "enabled": true, + "token": "YOUR_TELEGRAM_BOT_TOKEN", + "allow_from": [] + } + } +} +``` + +**3. Lấy API Key** + +* **Nhà cung cấp LLM**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) +* **Tìm kiếm Web** (tùy chọn): [Brave Search](https://brave.com/search/api) — Có gói miễn phí (2000 truy vấn/tháng) + +> **Lưu ý**: Xem `config.example.json` để có mẫu cấu hình đầy đủ. + +**4. Trò chuyện** + +```bash +picoclaw agent -m "Xin chào, bạn là ai?" +``` + +Vậy là xong! Bạn đã có một trợ lý AI hoạt động chỉ trong 2 phút. + +--- + +## 💬 Tích hợp ứng dụng Chat + +Trò chuyện với PicoClaw qua Telegram, Discord, DingTalk, LINE hoặc WeCom. 
+ +| Kênh | Mức độ thiết lập | +| --- | --- | +| **Telegram** | Dễ (chỉ cần token) | +| **Discord** | Dễ (bot token + intents) | +| **QQ** | Dễ (AppID + AppSecret) | +| **DingTalk** | Trung bình (app credentials) | +| **LINE** | Trung bình (credentials + webhook URL) | +| **WeCom** | Trung bình (CorpID + cấu hình webhook) | + +
+Telegram (Khuyên dùng) + +**1. Tạo bot** + +* Mở Telegram, tìm `@BotFather` +* Gửi `/newbot`, làm theo hướng dẫn +* Sao chép token + +**2. Cấu hình** + +```json +{ + "channels": { + "telegram": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allow_from": ["YOUR_USER_ID"] + } + } +} +``` + +> Lấy User ID từ `@userinfobot` trên Telegram. + +**3. Chạy** + +```bash +picoclaw gateway +``` + +
+ +
+Discord + +**1. Tạo bot** + +* Truy cập +* Create an application → Bot → Add Bot +* Sao chép bot token + +**2. Bật Intents** + +* Trong phần Bot settings, bật **MESSAGE CONTENT INTENT** +* (Tùy chọn) Bật **SERVER MEMBERS INTENT** nếu muốn dùng danh sách cho phép theo thông tin thành viên + +**3. Lấy User ID** + +* Discord Settings → Advanced → bật **Developer Mode** +* Click chuột phải vào avatar → **Copy User ID** + +**4. Cấu hình** + +```json +{ + "channels": { + "discord": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allow_from": ["YOUR_USER_ID"] + } + } +} +``` + +**5. Mời bot vào server** + +* OAuth2 → URL Generator +* Scopes: `bot` +* Bot Permissions: `Send Messages`, `Read Message History` +* Mở URL mời được tạo và thêm bot vào server của bạn + +**6. Chạy** + +```bash +picoclaw gateway +``` + +
+ +
+QQ + +**1. Tạo bot** + +* Truy cập [QQ Open Platform](https://q.qq.com/#) +* Tạo ứng dụng → Lấy **AppID** và **AppSecret** + +**2. Cấu hình** + +```json +{ + "channels": { + "qq": { + "enabled": true, + "app_id": "YOUR_APP_ID", + "app_secret": "YOUR_APP_SECRET", + "allow_from": [] + } + } +} +``` + +> Để `allow_from` trống để cho phép tất cả người dùng, hoặc chỉ định số QQ để giới hạn quyền truy cập. + +**3. Chạy** + +```bash +picoclaw gateway +``` + +
+ +
+DingTalk + +**1. Tạo bot** + +* Truy cập [Open Platform](https://open.dingtalk.com/) +* Tạo ứng dụng nội bộ +* Sao chép Client ID và Client Secret + +**2. Cấu hình** + +```json +{ + "channels": { + "dingtalk": { + "enabled": true, + "client_id": "YOUR_CLIENT_ID", + "client_secret": "YOUR_CLIENT_SECRET", + "allow_from": [] + } + } +} +``` + +> Để `allow_from` trống để cho phép tất cả người dùng, hoặc chỉ định ID để giới hạn quyền truy cập. + +**3. Chạy** + +```bash +picoclaw gateway +``` + +
+ +
+LINE + +**1. Tạo tài khoản LINE Official** + +- Truy cập [LINE Developers Console](https://developers.line.biz/) +- Tạo provider → Tạo Messaging API channel +- Sao chép **Channel Secret** và **Channel Access Token** + +**2. Cấu hình** + +```json +{ + "channels": { + "line": { + "enabled": true, + "channel_secret": "YOUR_CHANNEL_SECRET", + "channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN", + "webhook_host": "0.0.0.0", + "webhook_port": 18791, + "webhook_path": "/webhook/line", + "allow_from": [] + } + } +} +``` + +**3. Thiết lập Webhook URL** + +LINE yêu cầu HTTPS cho webhook. Sử dụng reverse proxy hoặc tunnel: + +```bash +# Ví dụ với ngrok +ngrok http 18791 +``` + +Sau đó cài đặt Webhook URL trong LINE Developers Console thành `https://your-domain/webhook/line` và bật **Use webhook**. + +**4. Chạy** + +```bash +picoclaw gateway +``` + +> Trong nhóm chat, bot chỉ phản hồi khi được @mention. Các câu trả lời sẽ trích dẫn tin nhắn gốc. + +> **Docker Compose**: Thêm `ports: ["18791:18791"]` vào service `picoclaw-gateway` để mở port webhook. + +
+ +
+WeCom (WeChat Work) + +PicoClaw hỗ trợ hai loại tích hợp WeCom: + +**Tùy chọn 1: WeCom Bot (Robot Thông minh)** - Thiết lập dễ dàng hơn, hỗ trợ chat nhóm +**Tùy chọn 2: WeCom App (Ứng dụng Tự xây dựng)** - Nhiều tính năng hơn, nhắn tin chủ động + +Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) để biết hướng dẫn chi tiết. + +**Thiết lập Nhanh - WeCom Bot:** + +**1. Tạo bot** + +* Truy cập Bảng điều khiển Quản trị WeCom → Chat Nhóm → Thêm Bot Nhóm +* Sao chép URL webhook (định dạng: `https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx`) + +**2. Cấu hình** + +```json +{ + "channels": { + "wecom": { + "enabled": true, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_ENCODING_AES_KEY", + "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY", + "webhook_host": "0.0.0.0", + "webhook_port": 18793, + "webhook_path": "/webhook/wecom", + "allow_from": [] + } + } +} +``` + +**Thiết lập Nhanh - WeCom App:** + +**1. Tạo ứng dụng** + +* Truy cập Bảng điều khiển Quản trị WeCom → Quản lý Ứng dụng → Tạo Ứng dụng +* Sao chép **AgentId** và **Secret** +* Truy cập trang "Công ty của tôi", sao chép **CorpID** + +**2. Cấu hình nhận tin nhắn** + +* Trong chi tiết ứng dụng, nhấp vào "Nhận Tin nhắn" → "Thiết lập API" +* Đặt URL thành `http://your-server:18792/webhook/wecom-app` +* Tạo **Token** và **EncodingAESKey** + +**3. Cấu hình** + +```json +{ + "channels": { + "wecom_app": { + "enabled": true, + "corp_id": "wwxxxxxxxxxxxxxxxx", + "corp_secret": "YOUR_CORP_SECRET", + "agent_id": 1000002, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_ENCODING_AES_KEY", + "webhook_host": "0.0.0.0", + "webhook_port": 18792, + "webhook_path": "/webhook/wecom-app", + "allow_from": [] + } + } +} +``` + +**4. Chạy** + +```bash +picoclaw gateway +``` + +> **Lưu ý**: WeCom App yêu cầu mở cổng 18792 cho callback webhook. Sử dụng proxy ngược cho HTTPS trong môi trường sản xuất. + +
+ +## ClawdChat Tham gia Mạng xã hội Agent + +Kết nối PicoClaw với Mạng xã hội Agent chỉ bằng cách gửi một tin nhắn qua CLI hoặc bất kỳ ứng dụng Chat nào đã tích hợp. + +**Đọc `https://clawdchat.ai/skill.md` và làm theo hướng dẫn để tham gia [ClawdChat.ai](https://clawdchat.ai)** + +## ⚙️ Cấu hình chi tiết + +File cấu hình: `~/.picoclaw/config.json` + +### Cấu trúc Workspace + +PicoClaw lưu trữ dữ liệu trong workspace đã cấu hình (mặc định: `~/.picoclaw/workspace`): + +``` +~/.picoclaw/workspace/ +├── sessions/ # Phiên hội thoại và lịch sử +├── memory/ # Bộ nhớ dài hạn (MEMORY.md) +├── state/ # Trạng thái lưu trữ (kênh cuối cùng, v.v.) +├── cron/ # Cơ sở dữ liệu tác vụ định kỳ +├── skills/ # Kỹ năng tùy chỉnh +├── AGENTS.md # Hướng dẫn hành vi Agent +├── HEARTBEAT.md # Prompt tác vụ định kỳ (kiểm tra mỗi 30 phút) +├── IDENTITY.md # Danh tính Agent +├── SOUL.md # Tâm hồn/Tính cách Agent +├── TOOLS.md # Mô tả công cụ +└── USER.md # Tùy chọn người dùng +``` + +### 🔒 Hộp cát bảo mật (Security Sandbox) + +PicoClaw chạy trong môi trường sandbox theo mặc định. Agent chỉ có thể truy cập file và thực thi lệnh trong phạm vi workspace. 
+ +#### Cấu hình mặc định + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true + } + } +} +``` + +| Tùy chọn | Mặc định | Mô tả | +|----------|---------|-------| +| `workspace` | `~/.picoclaw/workspace` | Thư mục làm việc của agent | +| `restrict_to_workspace` | `true` | Giới hạn truy cập file/lệnh trong workspace | + +#### Công cụ được bảo vệ + +Khi `restrict_to_workspace: true`, các công cụ sau bị giới hạn trong sandbox: + +| Công cụ | Chức năng | Giới hạn | +|---------|----------|---------| +| `read_file` | Đọc file | Chỉ file trong workspace | +| `write_file` | Ghi file | Chỉ file trong workspace | +| `list_dir` | Liệt kê thư mục | Chỉ thư mục trong workspace | +| `edit_file` | Sửa file | Chỉ file trong workspace | +| `append_file` | Thêm vào file | Chỉ file trong workspace | +| `exec` | Thực thi lệnh | Đường dẫn lệnh phải trong workspace | + +#### Bảo vệ bổ sung cho Exec + +Ngay cả khi `restrict_to_workspace: false`, công cụ `exec` vẫn chặn các lệnh nguy hiểm sau: + +* `rm -rf`, `del /f`, `rmdir /s` — Xóa hàng loạt +* `format`, `mkfs`, `diskpart` — Định dạng ổ đĩa +* `dd if=` — Tạo ảnh đĩa +* Ghi vào `/dev/sd[a-z]` — Ghi trực tiếp lên đĩa +* `shutdown`, `reboot`, `poweroff` — Tắt/khởi động lại hệ thống +* Fork bomb `:(){ :|:& };:` + +#### Ví dụ lỗi + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (path outside working dir)} +``` + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)} +``` + +#### Tắt giới hạn (Rủi ro bảo mật) + +Nếu bạn cần agent truy cập đường dẫn ngoài workspace: + +**Cách 1: File cấu hình** + +```json +{ + "agents": { + "defaults": { + "restrict_to_workspace": false + } + } +} +``` + +**Cách 2: Biến môi trường** + +```bash +export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false +``` + +> ⚠️ **Cảnh báo**: Tắt giới hạn này cho phép agent truy cập mọi đường dẫn 
trên hệ thống. Chỉ sử dụng cẩn thận trong môi trường được kiểm soát. + +#### Tính nhất quán của ranh giới bảo mật + +Cài đặt `restrict_to_workspace` áp dụng nhất quán trên mọi đường thực thi: + +| Đường thực thi | Ranh giới bảo mật | +|----------------|-------------------| +| Agent chính | `restrict_to_workspace` ✅ | +| Subagent / Spawn | Kế thừa cùng giới hạn ✅ | +| Tác vụ Heartbeat | Kế thừa cùng giới hạn ✅ | + +Tất cả đường thực thi chia sẻ cùng giới hạn workspace — không có cách nào vượt qua ranh giới bảo mật thông qua subagent hoặc tác vụ định kỳ. + +### Heartbeat (Tác vụ định kỳ) + +PicoClaw có thể tự động thực hiện các tác vụ định kỳ. Tạo file `HEARTBEAT.md` trong workspace: + +```markdown +# Tác vụ định kỳ + +- Kiểm tra email xem có tin nhắn quan trọng không +- Xem lại lịch cho các sự kiện sắp tới +- Kiểm tra dự báo thời tiết +``` + +Agent sẽ đọc file này mỗi 30 phút (có thể cấu hình) và thực hiện các tác vụ bằng công cụ có sẵn. + +#### Tác vụ bất đồng bộ với Spawn + +Đối với các tác vụ chạy lâu (tìm kiếm web, gọi API), sử dụng công cụ `spawn` để tạo **subagent**: + +```markdown +# Tác vụ định kỳ + +## Tác vụ nhanh (trả lời trực tiếp) +- Báo cáo thời gian hiện tại + +## Tác vụ lâu (dùng spawn cho async) +- Tìm kiếm tin tức AI trên web và tóm tắt +- Kiểm tra email và báo cáo tin nhắn quan trọng +``` + +**Hành vi chính:** + +| Tính năng | Mô tả | +|-----------|-------| +| **spawn** | Tạo subagent bất đồng bộ, không chặn heartbeat | +| **Context độc lập** | Subagent có context riêng, không có lịch sử phiên | +| **message tool** | Subagent giao tiếp trực tiếp với người dùng qua công cụ message | +| **Không chặn** | Sau khi spawn, heartbeat tiếp tục tác vụ tiếp theo | + +#### Cách Subagent giao tiếp + +``` +Heartbeat kích hoạt + ↓ +Agent đọc HEARTBEAT.md + ↓ +Tác vụ lâu: spawn subagent + ↓ ↓ +Tiếp tục tác vụ tiếp theo Subagent làm việc độc lập + ↓ ↓ +Tất cả tác vụ hoàn thành Subagent dùng công cụ "message" + ↓ ↓ +Phản hồi HEARTBEAT_OK Người dùng nhận kết quả 
trực tiếp +``` + +Subagent có quyền truy cập các công cụ (message, web_search, v.v.) và có thể giao tiếp với người dùng một cách độc lập mà không cần thông qua agent chính. + +**Cấu hình:** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +| Tùy chọn | Mặc định | Mô tả | +|----------|---------|-------| +| `enabled` | `true` | Bật/tắt heartbeat | +| `interval` | `30` | Khoảng thời gian kiểm tra (phút, tối thiểu: 5) | + +**Biến môi trường:** + +* `PICOCLAW_HEARTBEAT_ENABLED=false` để tắt +* `PICOCLAW_HEARTBEAT_INTERVAL=60` để thay đổi khoảng thời gian + +### Nhà cung cấp (Providers) + +> [!NOTE] +> Groq cung cấp dịch vụ chuyển giọng nói thành văn bản miễn phí qua Whisper. Nếu đã cấu hình Groq, tin nhắn thoại trên Telegram sẽ được tự động chuyển thành văn bản. + +| Nhà cung cấp | Mục đích | Lấy API Key | +| --- | --- | --- | +| `gemini` | LLM (Gemini trực tiếp) | [aistudio.google.com](https://aistudio.google.com) | +| `zhipu` | LLM (Zhipu trực tiếp) | [bigmodel.cn](bigmodel.cn) | +| `openrouter` (Đang thử nghiệm) | LLM (khuyên dùng, truy cập mọi model) | [openrouter.ai](https://openrouter.ai) | +| `anthropic` (Đang thử nghiệm) | LLM (Claude trực tiếp) | [console.anthropic.com](https://console.anthropic.com) | +| `openai` (Đang thử nghiệm) | LLM (GPT trực tiếp) | [platform.openai.com](https://platform.openai.com) | +| `deepseek` (Đang thử nghiệm) | LLM (DeepSeek trực tiếp) | [platform.deepseek.com](https://platform.deepseek.com) | +| `groq` | LLM + **Chuyển giọng nói** (Whisper) | [console.groq.com](https://console.groq.com) | +| `qwen` | LLM (Qwen trực tiếp) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) | +| `cerebras` | LLM (Cerebras trực tiếp) | [cerebras.ai](https://cerebras.ai) | + +
+Cấu hình Zhipu + +**1. Lấy API key** + +* Lấy [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys) + +**2. Cấu hình** + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "zhipu": { + "api_key": "Your API Key", + "api_base": "https://open.bigmodel.cn/api/paas/v4" + } + } +} +``` + +**3. Chạy** + +```bash +picoclaw agent -m "Xin chào" +``` + +
+ +
+Ví dụ cấu hình đầy đủ + +```json +{ + "agents": { + "defaults": { + "model": "anthropic/claude-opus-4-5" + } + }, + "providers": { + "openrouter": { + "api_key": "sk-or-v1-xxx" + }, + "groq": { + "api_key": "gsk_xxx" + } + }, + "channels": { + "telegram": { + "enabled": true, + "token": "123456:ABC...", + "allow_from": ["123456789"] + }, + "discord": { + "enabled": true, + "token": "", + "allow_from": [""] + }, + "whatsapp": { + "enabled": false + }, + "feishu": { + "enabled": false, + "app_id": "cli_xxx", + "app_secret": "xxx", + "encrypt_key": "", + "verification_token": "", + "allow_from": [] + }, + "qq": { + "enabled": false, + "app_id": "", + "app_secret": "", + "allow_from": [] + } + }, + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "BSA...", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + }, + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +
+ +### Cấu hình Mô hình (model_list) + +> **Tính năng mới!** PicoClaw hiện sử dụng phương pháp cấu hình **đặt mô hình vào trung tâm**. Chỉ cần chỉ định dạng `nhà cung cấp/mô hình` (ví dụ: `zhipu/glm-4.7`) để thêm nhà cung cấp mới—**không cần thay đổi mã!** + +Thiết kế này cũng cho phép **hỗ trợ đa tác nhân** với lựa chọn nhà cung cấp linh hoạt: + +- **Tác nhân khác nhau, nhà cung cấp khác nhau** : Mỗi tác nhân có thể sử dụng nhà cung cấp LLM riêng +- **Mô hình dự phòng** : Cấu hình mô hình chính và dự phòng để tăng độ tin cậy +- **Cân bằng tải** : Phân phối yêu cầu trên nhiều endpoint khác nhau +- **Cấu hình tập trung** : Quản lý tất cả nhà cung cấp ở một nơi + +#### 📋 Tất cả Nhà cung cấp được Hỗ trợ + +| Nhà cung cấp | Prefix `model` | API Base Mặc định | Giao thức | Khóa API | +|-------------|----------------|-------------------|-----------|----------| +| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [Lấy Khóa](https://platform.openai.com) | +| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [Lấy Khóa](https://console.anthropic.com) | +| **Zhipu AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [Lấy Khóa](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) | +| **DeepSeek** | `deepseek/` | `https://api.deepseek.com/v1` | OpenAI | [Lấy Khóa](https://platform.deepseek.com) | +| **Google Gemini** | `gemini/` | `https://generativelanguage.googleapis.com/v1beta` | OpenAI | [Lấy Khóa](https://aistudio.google.com/api-keys) | +| **Groq** | `groq/` | `https://api.groq.com/openai/v1` | OpenAI | [Lấy Khóa](https://console.groq.com) | +| **Moonshot** | `moonshot/` | `https://api.moonshot.cn/v1` | OpenAI | [Lấy Khóa](https://platform.moonshot.cn) | +| **Qwen (Alibaba)** | `qwen/` | `https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI | [Lấy Khóa](https://dashscope.console.aliyun.com) | +| **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [Lấy 
Khóa](https://build.nvidia.com) | +| **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | Local (không cần khóa) | +| **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [Lấy Khóa](https://openrouter.ai/keys) | +| **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | Local | +| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Lấy Khóa](https://cerebras.ai) | +| **Volcengine** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Lấy Khóa](https://console.volcengine.com) | +| **ShengsuanYun** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - | +| **Antigravity** | `antigravity/` | Google Cloud | Tùy chỉnh | Chỉ OAuth | +| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | + +#### Cấu hình Cơ bản + +```json +{ + "model_list": [ + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_key": "sk-your-openai-key" + }, + { + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "api_key": "sk-ant-your-key" + }, + { + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-zhipu-key" + } + ], + "agents": { + "defaults": { + "model": "gpt-5.2" + } + } +} +``` + +#### Ví dụ theo Nhà cung cấp + +**OpenAI** +```json +{ + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_key": "sk-..." +} +``` + +**Zhipu AI (GLM)** +```json +{ + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-key" +} +``` + +**Anthropic (với OAuth)** +```json +{ + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "auth_method": "oauth" +} +``` +> Chạy `picoclaw auth login --provider anthropic` để thiết lập thông tin xác thực OAuth. 
+
+#### Cân bằng tải
+
+Định cấu hình nhiều endpoint cho cùng một tên mô hình—PicoClaw sẽ tự động phân phối round-robin giữa chúng:
+
+```json
+{
+  "model_list": [
+    {
+      "model_name": "gpt-5.2",
+      "model": "openai/gpt-5.2",
+      "api_base": "https://api1.example.com/v1",
+      "api_key": "sk-key1"
+    },
+    {
+      "model_name": "gpt-5.2",
+      "model": "openai/gpt-5.2",
+      "api_base": "https://api2.example.com/v1",
+      "api_key": "sk-key2"
+    }
+  ]
+}
+```
+
+#### Chuyển đổi từ Cấu hình `providers` Cũ
+
+Cấu hình `providers` cũ đã **ngừng sử dụng** nhưng vẫn được hỗ trợ để tương thích ngược.
+
+**Cấu hình Cũ (đã ngừng sử dụng):**
+```json
+{
+  "providers": {
+    "zhipu": {
+      "api_key": "your-key",
+      "api_base": "https://open.bigmodel.cn/api/paas/v4"
+    }
+  },
+  "agents": {
+    "defaults": {
+      "provider": "zhipu",
+      "model": "glm-4.7"
+    }
+  }
+}
+```
+
+**Cấu hình Mới (khuyến nghị):**
+```json
+{
+  "model_list": [
+    {
+      "model_name": "glm-4.7",
+      "model": "zhipu/glm-4.7",
+      "api_key": "your-key"
+    }
+  ],
+  "agents": {
+    "defaults": {
+      "model": "glm-4.7"
+    }
+  }
+}
+```
+
+Xem hướng dẫn chuyển đổi chi tiết tại [docs/migration/model-list-migration.md](docs/migration/model-list-migration.md). 
+ +## Tham chiếu CLI + +| Lệnh | Mô tả | +| --- | --- | +| `picoclaw onboard` | Khởi tạo cấu hình & workspace | +| `picoclaw agent -m "..."` | Trò chuyện với agent | +| `picoclaw agent` | Chế độ chat tương tác | +| `picoclaw gateway` | Khởi động gateway (cho bot chat) | +| `picoclaw status` | Hiển thị trạng thái | +| `picoclaw cron list` | Liệt kê tất cả tác vụ định kỳ | +| `picoclaw cron add ...` | Thêm tác vụ định kỳ | + +### Tác vụ định kỳ / Nhắc nhở + +PicoClaw hỗ trợ nhắc nhở theo lịch và tác vụ lặp lại thông qua công cụ `cron`: + +* **Nhắc nhở một lần**: "Remind me in 10 minutes" (Nhắc tôi sau 10 phút) → kích hoạt một lần sau 10 phút +* **Tác vụ lặp lại**: "Remind me every 2 hours" (Nhắc tôi mỗi 2 giờ) → kích hoạt mỗi 2 giờ +* **Biểu thức Cron**: "Remind me at 9am daily" (Nhắc tôi lúc 9 giờ sáng mỗi ngày) → sử dụng biểu thức cron + +Các tác vụ được lưu trong `~/.picoclaw/workspace/cron/` và được xử lý tự động. + +## 🤝 Đóng góp & Lộ trình + +Chào đón mọi PR! Mã nguồn được thiết kế nhỏ gọn và dễ đọc. 🤗 + +Lộ trình sắp được công bố... + +Nhóm phát triển đang được xây dựng. Điều kiện tham gia: Ít nhất 1 PR đã được merge. + +Nhóm người dùng: + +Discord: + +PicoClaw + +## 🐛 Xử lý sự cố + +### Tìm kiếm web hiện "API 配置问题" + +Điều này là bình thường nếu bạn chưa cấu hình API key cho tìm kiếm. PicoClaw sẽ cung cấp các liên kết hữu ích để tìm kiếm thủ công. + +Để bật tìm kiếm web: + +1. **Tùy chọn 1 (Khuyên dùng)**: Lấy API key miễn phí tại [https://brave.com/search/api](https://brave.com/search/api) (2000 truy vấn miễn phí/tháng) để có kết quả tốt nhất. +2. **Tùy chọn 2 (Không cần thẻ tín dụng)**: Nếu không có key, hệ thống tự động chuyển sang dùng **DuckDuckGo** (không cần key). 
+ +Thêm key vào `~/.picoclaw/config.json` nếu dùng Brave: + +```json +{ + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + } +} +``` + +### Gặp lỗi lọc nội dung (Content Filtering) + +Một số nhà cung cấp (như Zhipu) có bộ lọc nội dung nghiêm ngặt. Thử diễn đạt lại câu hỏi hoặc sử dụng model khác. + +### Telegram bot báo "Conflict: terminated by other getUpdates" + +Điều này xảy ra khi có một instance bot khác đang chạy. Đảm bảo chỉ có một tiến trình `picoclaw gateway` chạy tại một thời điểm. + +--- + +## 📝 So sánh API Key + +| Dịch vụ | Gói miễn phí | Trường hợp sử dụng | +| --- | --- | --- | +| **OpenRouter** | 200K tokens/tháng | Đa model (Claude, GPT-4, v.v.) | +| **Zhipu** | 200K tokens/tháng | Tốt nhất cho người dùng Trung Quốc | +| **Brave Search** | 2000 truy vấn/tháng | Chức năng tìm kiếm web | +| **Groq** | Có gói miễn phí | Suy luận siêu nhanh (Llama, Mixtral) | diff --git a/README.zh.md b/README.zh.md new file mode 100644 index 000000000..4d739c5eb --- /dev/null +++ b/README.zh.md @@ -0,0 +1,813 @@ +
+PicoClaw + +

PicoClaw: 基于Go语言的超高效 AI 助手

+ +

10$硬件 · 10MB内存 · 1秒启动 · 皮皮虾,我们走!

+ +

+ Go + Hardware + License +
+ Website + Twitter +

+ +**中文** | [日本語](README.ja.md) | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [Français](README.fr.md) | [English](README.md) + +
+ +--- + +🦐 **PicoClaw** 是一个受 [nanobot](https://github.com/HKUDS/nanobot) 启发的超轻量级个人 AI 助手。它采用 **Go 语言** 从零重构,经历了一个“自举”过程——即由 AI Agent 自身驱动了整个架构迁移和代码优化。 + +⚡️ **极致轻量**:可在 **10 美元** 的硬件上运行,内存占用 **<10MB**。这意味着比 OpenClaw 节省 99% 的内存,比 Mac mini 便宜 98%! + + + + + + +
+

+ +

+
+

+ +

+
+ +注意:人手有限,中文文档可能略有滞后,请优先查看英文文档。 + +> [!CAUTION] +> **🚨 SECURITY & OFFICIAL CHANNELS / 安全声明** +> +> - **无加密货币 (NO CRYPTO):** PicoClaw **没有** 发行任何官方代币、Token 或虚拟货币。所有在 `pump.fun` 或其他交易平台上的相关声称均为 **诈骗**。 +> - **官方域名:** 唯一的官方网站是 **[picoclaw.io](https://picoclaw.io)**,公司官网是 **[sipeed.com](https://sipeed.com)**。 +> - **警惕:** 许多 `.ai/.org/.com/.net/...` 后缀的域名被第三方抢注,请勿轻信。 +> - **注意:** picoclaw正在初期的快速功能开发阶段,可能有尚未修复的网络安全问题,在1.0正式版发布前,请不要将其部署到生产环境中 +> - **注意:** picoclaw最近合并了大量PRs,近期版本可能内存占用较大(10~20MB),我们将在功能较为收敛后进行资源占用优化. + +## 📢 新闻 (News) + +2026-02-16 🎉 PicoClaw 在一周内突破了12K star! 感谢大家的关注!PicoClaw 的成长速度超乎我们预期. 由于PR数量的快速膨胀,我们亟需社区开发者参与维护. 我们需要的志愿者角色和roadmap已经发布到了[这里](docs/picoclaw_community_roadmap_260216.md), 期待你的参与! + +2026-02-13 🎉 **PicoClaw 在 4 天内突破 5000 Stars!** 感谢社区的支持!由于正值中国春节假期,PR 和 Issue 涌入较多,我们正在利用这段时间敲定 **项目路线图 (Roadmap)** 并组建 **开发者群组**,以便加速 PicoClaw 的开发。 +🚀 **行动号召:** 请在 GitHub Discussions 中提交您的功能请求 (Feature Requests)。我们将在接下来的周会上进行审查和优先级排序。 + +2026-02-09 🎉 **PicoClaw 正式发布!** 仅用 1 天构建,旨在将 AI Agent 带入 10 美元硬件与 <10MB 内存的世界。🦐 PicoClaw(皮皮虾),我们走! + +## ✨ 特性 + +🪶 **超轻量级**: 核心功能内存占用 <10MB — 比 Clawdbot 小 99%。 + +💰 **极低成本**: 高效到足以在 10 美元的硬件上运行 — 比 Mac mini 便宜 98%。 + +⚡️ **闪电启动**: 启动速度快 400 倍,即使在 0.6GHz 单核处理器上也能在 1 秒内启动。 + +🌍 **真正可移植**: 跨 RISC-V、ARM 和 x86 架构的单二进制文件,一键运行! + +🤖 **AI 自举**: 纯 Go 语言原生实现 — 95% 的核心代码由 Agent 生成,并经由“人机回环 (Human-in-the-loop)”微调。 + +| | OpenClaw | NanoBot | **PicoClaw** | +| ------------------------------ | ------------- | ------------------------ | -------------------------------------- | +| **语言** | TypeScript | Python | **Go** | +| **RAM** | >1GB | >100MB | **< 10MB** | +| **启动时间**
(0.8GHz core) | >500s | >30s | **<1s** | +| **成本** | Mac Mini $599 | 大多数 Linux 开发板 ~$50 | **任意 Linux 开发板**
**低至 $10** | + +PicoClaw + +## 🦾 演示 + +### 🛠️ 标准助手工作流 + + + + + + + + + + + + + + + + + +

🧩 全栈工程师模式

🗂️ 日志与规划管理

🔎 网络搜索与学习

开发 • 部署 • 扩展日程 • 自动化 • 记忆发现 • 洞察 • 趋势
+ +### 📱 在手机上轻松运行 + +picoclaw 可以将你10年前的老旧手机废物利用,变身成为你的AI助理!快速指南: + +1. 先去应用商店下载安装Termux +2. 打开后执行指令 + +```bash +# 注意: 下面的v0.1.1 可以换为你实际看到的最新版本 +wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64 +chmod +x picoclaw-linux-arm64 +pkg install proot +termux-chroot ./picoclaw-linux-arm64 onboard +``` + +然后跟随下面的“快速开始”章节继续配置picoclaw即可使用! +PicoClaw + +### 🐜 创新的低占用部署 + +PicoClaw 几乎可以部署在任何 Linux 设备上! + +- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(网口) 或 W(WiFi6) 版本,用于极简家庭助手。 +- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html),或 $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html),用于自动化服务器运维。 +- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) 或 $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera),用于智能监控。 + +[https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4](https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4) + +🌟 更多部署案例敬请期待! + +## 📦 安装 + +### 使用预编译二进制文件安装 + +从 [Release 页面](https://github.com/sipeed/picoclaw/releases) 下载适用于您平台的固件。 + +### 从源码安装(获取最新特性,开发推荐) + +```bash +git clone https://github.com/sipeed/picoclaw.git + +cd picoclaw +make deps + +# 构建(无需安装) +make build + +# 为多平台构建 +make build-all + +# 构建并安装 +make install + +``` + +## 🐳 Docker Compose + +您也可以使用 Docker Compose 运行 PicoClaw,无需在本地安装任何环境。 + +```bash +# 1. 克隆仓库 +git clone https://github.com/sipeed/picoclaw.git +cd picoclaw + +# 2. 设置 API Key +cp config/config.example.json config/config.json +vim config/config.json # 设置 DISCORD_BOT_TOKEN, API keys 等 + +# 3. 构建并启动 +docker compose --profile gateway up -d + +# 4. 查看日志 +docker compose logs -f picoclaw-gateway + +# 5. 停止 +docker compose --profile gateway down + +``` + +### Agent 模式 (一次性运行) + +```bash +# 提问 +docker compose run --rm picoclaw-agent -m "2+2 等于几?" 
+ +# 交互模式 +docker compose run --rm picoclaw-agent + +``` + +### 重新构建 + +```bash +docker compose --profile gateway build --no-cache +docker compose --profile gateway up -d + +``` + +### 🚀 快速开始 + +> [!TIP] +> 在 `~/.picoclaw/config.json` 中设置您的 API Key。 +> 获取 API Key: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu (智谱)](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM) +> 网络搜索是 **可选的** - 获取免费的 [Brave Search API](https://brave.com/search/api) (每月 2000 次免费查询) + +**1. 初始化 (Initialize)** + +```bash +picoclaw onboard + +``` + +**2. 配置 (Configure)** (`~/.picoclaw/config.json`) + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "gpt4", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "model_list": [ + { + "model_name": "gpt4", + "model": "openai/gpt-5.2", + "api_key": "your-api-key" + }, + { + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "api_key": "your-anthropic-key" + } + ], + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + }, + "cron": { + "exec_timeout_minutes": 5 + } + } +} +``` + +> **新功能**: `model_list` 配置格式支持零代码添加 provider。详见[模型配置](#模型配置-model_list)章节。 + +**3. 获取 API Key** + +- **LLM 提供商**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) +- **网络搜索** (可选): [Brave Search](https://brave.com/search/api) - 提供免费层级 (2000 请求/月) + +> **注意**: 完整的配置模板请参考 `config.example.json`。 + +**4. 对话 (Chat)** + +```bash +picoclaw agent -m "2+2 等于几?" 
+ +``` + +就是这样!您在 2 分钟内就拥有了一个可工作的 AI 助手。 + +--- + +## 💬 聊天应用集成 (Chat Apps) + +PicoClaw 支持多种聊天平台,使您的 Agent 能够连接到任何地方。 + +### 核心渠道 + +| 渠道 | 设置难度 | 特性说明 | 文档链接 | +| -------------------- | ----------- | ----------------------------------------- | --------------------------------------------------------------------------------------------------------------- | +| **Telegram** | ⭐ 简单 | 推荐,支持语音转文字,长轮询无需公网 | [查看文档](docs/channels/telegram/README.zh.md) | +| **Discord** | ⭐ 简单 | Socket Mode,支持群组/私信,Bot 生态成熟 | [查看文档](docs/channels/discord/README.zh.md) | +| **Slack** | ⭐ 简单 | **Socket Mode** (无需公网 IP),企业级支持 | [查看文档](docs/channels/slack/README.zh.md) | +| **QQ** | ⭐⭐ 中等 | 官方机器人 API,适合国内社群 | [查看文档](docs/channels/qq/README.zh.md) | +| **钉钉 (DingTalk)** | ⭐⭐ 中等 | Stream 模式无需公网,企业办公首选 | [查看文档](docs/channels/dingtalk/README.zh.md) | +| **企业微信 (WeCom)** | ⭐⭐⭐ 较难 | 支持群机器人(Webhook)和自建应用(API) | [Bot 文档](docs/channels/wecom/wecom_bot/README.zh.md) / [App 文档](docs/channels/wecom/wecom_app/README.zh.md) | +| **飞书 (Feishu)** | ⭐⭐⭐ 较难 | 企业级协作,功能丰富 | [查看文档](docs/channels/feishu/README.zh.md) | +| **Line** | ⭐⭐⭐ 较难 | 需要 HTTPS Webhook | [查看文档](docs/channels/line/README.zh.md) | +| **OneBot** | ⭐⭐ 中等 | 兼容 NapCat/Go-CQHTTP,社区生态丰富 | [查看文档](docs/channels/onebot/README.zh.md) | +| **MaixCam** | ⭐ 简单 | 专为 AI 摄像头设计的硬件集成通道 | [查看文档](docs/channels/maixcam/README.zh.md) | + +## ClawdChat 加入 Agent 社交网络 + +只需通过 CLI 或任何集成的聊天应用发送一条消息,即可将 PicoClaw 连接到 Agent 社交网络。 + +\*\*阅读 `https://clawdchat.ai/skill.md` 并按照说明加入 [ClawdChat.ai](https://clawdchat.ai) + +## ⚙️ 配置详解 + +配置文件路径: `~/.picoclaw/config.json` + +### 工作区布局 (Workspace Layout) + +PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/workspace`): + +``` +~/.picoclaw/workspace/ +├── sessions/ # 对话会话和历史 +├── memory/ # 长期记忆 (MEMORY.md) +├── state/ # 持久化状态 (最后一次频道等) +├── cron/ # 定时任务数据库 +├── skills/ # 自定义技能 +├── AGENTS.md # Agent 行为指南 +├── HEARTBEAT.md # 周期性任务提示词 (每 30 分钟检查一次) +├── IDENTITY.md # Agent 身份设定 +├── SOUL.md # Agent 灵魂/性格 +├── TOOLS.md # 工具描述 +└── USER.md # 用户偏好 + 
+``` + +### 心跳 / 周期性任务 (Heartbeat) + +PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md` 文件: + +```markdown +# Periodic Tasks + +- Check my email for important messages +- Review my calendar for upcoming events +- Check the weather forecast +``` + +Agent 将每隔 30 分钟(可配置)读取此文件,并使用可用工具执行任务。 + +#### 使用 Spawn 的异步任务 + +对于耗时较长的任务(网络搜索、API 调用),使用 `spawn` 工具创建一个 **子 Agent (subagent)**: + +```markdown +# Periodic Tasks + +## Quick Tasks (respond directly) + +- Report current time + +## Long Tasks (use spawn for async) + +- Search the web for AI news and summarize +- Check email and report important messages +``` + +**关键行为:** + +| 特性 | 描述 | +| ---------------- | ---------------------------------------- | +| **spawn** | 创建异步子 Agent,不阻塞主心跳进程 | +| **独立上下文** | 子 Agent 拥有独立上下文,无会话历史 | +| **message tool** | 子 Agent 通过 message 工具直接与用户通信 | +| **非阻塞** | spawn 后,心跳继续处理下一个任务 | + +#### 子 Agent 通信原理 + +``` +心跳触发 (Heartbeat triggers) + ↓ +Agent 读取 HEARTBEAT.md + ↓ +对于长任务: spawn 子 Agent + ↓ ↓ +继续下一个任务 子 Agent 独立工作 + ↓ ↓ +所有任务完成 子 Agent 使用 "message" 工具 + ↓ ↓ +响应 HEARTBEAT_OK 用户直接收到结果 + +``` + +子 Agent 可以访问工具(message, web_search 等),并且无需通过主 Agent 即可独立与用户通信。 + +**配置:** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +| 选项 | 默认值 | 描述 | +| ---------- | ------ | ---------------------------- | +| `enabled` | `true` | 启用/禁用心跳 | +| `interval` | `30` | 检查间隔,单位分钟 (最小: 5) | + +**环境变量:** + +- `PICOCLAW_HEARTBEAT_ENABLED=false` 禁用 +- `PICOCLAW_HEARTBEAT_INTERVAL=60` 更改间隔 + +### 提供商 (Providers) + +> [!NOTE] +> Groq 通过 Whisper 提供免费的语音转录。如果配置了 Groq,Telegram 语音消息将被自动转录为文字。 + +| 提供商 | 用途 | 获取 API Key | +| -------------------- | ---------------------------- | -------------------------------------------------------------------- | +| `gemini` | LLM (Gemini 直连) | [aistudio.google.com](https://aistudio.google.com) | +| `zhipu` | LLM (智谱直连) | [bigmodel.cn](bigmodel.cn) | +| `openrouter(待测试)` | LLM (推荐,可访问所有模型) | [openrouter.ai](https://openrouter.ai) | +| `anthropic(待测试)` | LLM (Claude 直连) | 
[console.anthropic.com](https://console.anthropic.com) | +| `openai(待测试)` | LLM (GPT 直连) | [platform.openai.com](https://platform.openai.com) | +| `deepseek(待测试)` | LLM (DeepSeek 直连) | [platform.deepseek.com](https://platform.deepseek.com) | +| `qwen` | LLM (通义千问) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) | +| `groq` | LLM + **语音转录** (Whisper) | [console.groq.com](https://console.groq.com) | +| `cerebras` | LLM (Cerebras 直连) | [cerebras.ai](https://cerebras.ai) | + +### 模型配置 (model_list) + +> **新功能!** PicoClaw 现在采用**以模型为中心**的配置方式。只需使用 `厂商/模型` 格式(如 `zhipu/glm-4.7`)即可添加新的 provider——**无需修改任何代码!** + +该设计同时支持**多 Agent 场景**,提供灵活的 Provider 选择: + +- **不同 Agent 使用不同 Provider**:每个 Agent 可以使用自己的 LLM provider +- **模型回退(Fallback)**:配置主模型和备用模型,提高可靠性 +- **负载均衡**:在多个 API 端点之间分配请求 +- **集中化配置**:在一个地方管理所有 provider + +#### 📋 所有支持的厂商 + +| 厂商 | `model` 前缀 | 默认 API Base | 协议 | 获取 API Key | +| ------------------- | ----------------- | --------------------------------------------------- | --------- | ----------------------------------------------------------------- | +| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [获取密钥](https://platform.openai.com) | +| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [获取密钥](https://console.anthropic.com) | +| **智谱 AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [获取密钥](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) | +| **DeepSeek** | `deepseek/` | `https://api.deepseek.com/v1` | OpenAI | [获取密钥](https://platform.deepseek.com) | +| **Google Gemini** | `gemini/` | `https://generativelanguage.googleapis.com/v1beta` | OpenAI | [获取密钥](https://aistudio.google.com/api-keys) | +| **Groq** | `groq/` | `https://api.groq.com/openai/v1` | OpenAI | [获取密钥](https://console.groq.com) | +| **Moonshot** | `moonshot/` | `https://api.moonshot.cn/v1` | OpenAI | [获取密钥](https://platform.moonshot.cn) | +| **通义千问 (Qwen)** | `qwen/` | 
`https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI | [获取密钥](https://dashscope.console.aliyun.com) | +| **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [获取密钥](https://build.nvidia.com) | +| **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | 本地(无需密钥) | +| **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [获取密钥](https://openrouter.ai/keys) | +| **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | 本地 | +| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [获取密钥](https://cerebras.ai) | +| **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [获取密钥](https://console.volcengine.com) | +| **神算云** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - | +| **Antigravity** | `antigravity/` | Google Cloud | 自定义 | 仅 OAuth | +| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | + +#### 基础配置示例 + +```json +{ + "model_list": [ + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_key": "sk-your-openai-key" + }, + { + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "api_key": "sk-ant-your-key" + }, + { + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-zhipu-key" + } + ], + "agents": { + "defaults": { + "model": "gpt-5.2" + } + } +} +``` + +#### 各厂商配置示例 + +**OpenAI** + +```json +{ + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_key": "sk-..." +} +``` + +**智谱 AI (GLM)** + +```json +{ + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-key" +} +``` + +**DeepSeek** + +```json +{ + "model_name": "deepseek-chat", + "model": "deepseek/deepseek-chat", + "api_key": "sk-..." 
+} +``` + +**Anthropic (使用 OAuth)** + +```json +{ + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "auth_method": "oauth" +} +``` + +> 运行 `picoclaw auth login --provider anthropic` 来设置 OAuth 凭证。 + +**Ollama (本地)** + +```json +{ + "model_name": "llama3", + "model": "ollama/llama3" +} +``` + +**自定义代理/API** + +```json +{ + "model_name": "my-custom-model", + "model": "openai/custom-model", + "api_base": "https://my-proxy.com/v1", + "api_key": "sk-..." +} +``` + +#### 负载均衡 + +为同一个模型名称配置多个端点——PicoClaw 会自动在它们之间轮询: + +```json +{ + "model_list": [ + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_base": "https://api1.example.com/v1", + "api_key": "sk-key1" + }, + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_base": "https://api2.example.com/v1", + "api_key": "sk-key2" + } + ] +} +``` + +#### 从旧的 `providers` 配置迁移 + +旧的 `providers` 配置格式**已弃用**,但为向后兼容仍支持。 + +**旧配置(已弃用):** + +```json +{ + "providers": { + "zhipu": { + "api_key": "your-key", + "api_base": "https://open.bigmodel.cn/api/paas/v4" + } + }, + "agents": { + "defaults": { + "provider": "zhipu", + "model": "glm-4.7" + } + } +} +``` + +**新配置(推荐):** + +```json +{ + "model_list": [ + { + "model_name": "glm-4.7", + "model": "zhipu/glm-4.7", + "api_key": "your-key" + } + ], + "agents": { + "defaults": { + "model": "glm-4.7" + } + } +} +``` + +详细的迁移指南请参考 [docs/migration/model-list-migration.md](docs/migration/model-list-migration.md)。 + +
+智谱 (Zhipu) 配置示例 + +**1. 获取 API key 和 base URL** + +- 获取 [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys) + +**2. 配置** + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "zhipu": { + "api_key": "Your API Key", + "api_base": "https://open.bigmodel.cn/api/paas/v4" + } + } +} +``` + +**3. 运行** + +```bash +picoclaw agent -m "你好" + +``` + +
+ +
+完整配置示例 + +```json +{ + "agents": { + "defaults": { + "model": "anthropic/claude-opus-4-5" + } + }, + "providers": { + "openrouter": { + "api_key": "sk-or-v1-xxx" + }, + "groq": { + "api_key": "gsk_xxx" + } + }, + "channels": { + "telegram": { + "enabled": true, + "token": "123456:ABC...", + "allow_from": ["123456789"] + }, + "discord": { + "enabled": true, + "token": "", + "allow_from": [""] + }, + "whatsapp": { + "enabled": false + }, + "feishu": { + "enabled": false, + "app_id": "cli_xxx", + "app_secret": "xxx", + "encrypt_key": "", + "verification_token": "", + "allow_from": [] + }, + "qq": { + "enabled": false, + "app_id": "", + "app_secret": "", + "allow_from": [] + } + }, + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + }, + "cron": { + "exec_timeout_minutes": 5 + } + }, + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +
+ +## CLI 命令行参考 + +| 命令 | 描述 | +| ------------------------- | ------------------ | +| `picoclaw onboard` | 初始化配置和工作区 | +| `picoclaw agent -m "..."` | 与 Agent 对话 | +| `picoclaw agent` | 交互式聊天模式 | +| `picoclaw gateway` | 启动网关 (Gateway) | +| `picoclaw status` | 显示状态 | +| `picoclaw cron list` | 列出所有定时任务 | +| `picoclaw cron add ...` | 添加定时任务 | + +### 定时任务 / 提醒 (Scheduled Tasks) + +PicoClaw 通过 `cron` 工具支持定时提醒和重复任务: + +- **一次性提醒**: "Remind me in 10 minutes" (10分钟后提醒我) → 10分钟后触发一次 +- **重复任务**: "Remind me every 2 hours" (每2小时提醒我) → 每2小时触发 +- **Cron 表达式**: "Remind me at 9am daily" (每天上午9点提醒我) → 使用 cron 表达式 + +任务存储在 `~/.picoclaw/workspace/cron/` 中并自动处理。 + +## 🤝 贡献与路线图 (Roadmap) + +欢迎提交 PR!代码库刻意保持小巧和可读。🤗 + +路线图即将发布... + +开发者群组正在组建中,入群门槛:至少合并过 1 个 PR。 + +用户群组: + +Discord: [https://discord.gg/V4sAZ9XWpN](https://discord.gg/V4sAZ9XWpN) + +PicoClaw + +## 🐛 疑难解答 (Troubleshooting) + +### 网络搜索提示 "API 配置问题" + +如果您尚未配置搜索 API Key,这是正常的。PicoClaw 会提供手动搜索的帮助链接。 + +启用网络搜索: + +1. 在 [https://brave.com/search/api](https://brave.com/search/api) 获取免费 API Key (每月 2000 次免费查询) +2. 
添加到 `~/.picoclaw/config.json`:
+
+```json
+{
+  "tools": {
+    "web": {
+      "brave": {
+        "enabled": false,
+        "api_key": "YOUR_BRAVE_API_KEY",
+        "max_results": 5
+      },
+      "duckduckgo": {
+        "enabled": true,
+        "max_results": 5
+      }
+    }
+  }
+}
+```
+
+### 遇到内容过滤错误 (Content Filtering Errors)
+
+某些提供商(如智谱)有严格的内容过滤。尝试改写您的问题或使用其他模型。
+
+### Telegram bot 提示 "Conflict: terminated by other getUpdates"
+
+这表示有另一个机器人实例正在运行。请确保同一时间只有一个 `picoclaw gateway` 进程在运行。
+
+---
+
+## 📝 API Key 对比
+
+| 服务 | 免费层级 | 适用场景 |
+| ---------------- | -------------- | ----------------------------- |
+| **OpenRouter** | 200K tokens/月 | 多模型聚合 (Claude, GPT-4 等) |
+| **智谱 (Zhipu)** | 200K tokens/月 | 最适合中国用户 |
+| **Brave Search** | 2000 次查询/月 | 网络搜索功能 |
+| **Groq** | 提供免费层级 | 极速推理 (Llama, Mixtral) |
+| **Cerebras** | 提供免费层级 | 极速推理 (Llama, Qwen 等) | diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 000000000..8c5c0e252 --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,116 @@ +
+# 🦐 PicoClaw Roadmap
+
+> **Vision**: To build the ultimate lightweight, secure, and fully autonomous AI Agent infrastructure. Automate the mundane, unleash your creativity.
+
+---
+
+## 🚀 1. Core Optimization: Extreme Lightweight
+
+*Our defining characteristic. We fight software bloat to ensure PicoClaw runs smoothly on the smallest embedded devices.*
+
+* [**Memory Footprint Reduction**](https://github.com/sipeed/picoclaw/issues/346)
+  * **Goal**: Run smoothly on 64MB RAM embedded boards (e.g., low-end RISC-V SBCs) with the core process consuming < 20MB.
+  * **Context**: RAM is expensive and scarce on edge devices. Memory optimization takes precedence over storage size.
+  * **Action**: Analyze memory growth between releases, remove redundant dependencies, and optimize data structures.
+
+
+## 🛡️ 2. Security Hardening: Defense in Depth
+
+*Paying off early technical debt. 
We invite security experts to help build a "Secure-by-Default" agent.* + +* **Input Defense & Permission Control** + * **Prompt Injection Defense**: Harden JSON extraction logic to prevent LLM manipulation. + * **Tool Abuse Prevention**: Strict parameter validation to ensure generated commands stay within safe boundaries. + * **SSRF Protection**: Built-in blocklists for network tools to prevent accessing internal IPs (LAN/Metadata services). + + +* **Sandboxing & Isolation** + * **Filesystem Sandbox**: Restrict file R/W operations to specific directories only. + * **Context Isolation**: Prevent data leakage between different user sessions or channels. + * **Privacy Redaction**: Auto-redact sensitive info (API Keys, PII) from logs and standard outputs. + + +* **Authentication & Secrets** + * **Crypto Upgrade**: Adopt modern algorithms like `ChaCha20-Poly1305` for secret storage. + * **OAuth 2.0 Flow**: Deprecate hardcoded API keys in the CLI; move to secure OAuth flows. + + + +## 🔌 3. Connectivity: Protocol-First Architecture + +*Connect every model, reach every platform.* + +* **Provider** + * [**Architecture Upgrade**](https://github.com/sipeed/picoclaw/issues/283): Refactor from "Vendor-based" to "Protocol-based" classification (e.g., OpenAI-compatible, Ollama-compatible). *(Status: In progress by @Daming, ETA 5 days)* + * **Local Models**: Deep integration with **Ollama**, **vLLM**, **LM Studio**, and **Mistral** (local inference). + * **Online Models**: Continued support for frontier closed-source models. + + +* **Channel** + * **IM Matrix**: QQ, WeChat (Work), DingTalk, Feishu (Lark), Telegram, Discord, WhatsApp, LINE, Slack, Email, KOOK, Signal, ... + * **Standards**: Support for the **OneBot** protocol. + * [**attachment**](https://github.com/sipeed/picoclaw/issues/348): Native handling of images, audio, and video attachments. 
+ + +* **Skill Marketplace** + * [**Discovery skills**](https://github.com/sipeed/picoclaw/issues/287): Implement `find_skill` to automatically discover and install skills from the [GitHub Skills Repo] or other registries. + + + +## 🧠 4. Advanced Capabilities: From Chatbot to Agentic AI + +*Beyond conversation—focusing on action and collaboration.* + +* **Operations** + * [**MCP Support**](https://github.com/sipeed/picoclaw/issues/290): Native support for the **Model Context Protocol (MCP)**. + * [**Browser Automation**](https://github.com/sipeed/picoclaw/issues/293): Headless browser control via CDP (Chrome DevTools Protocol) or ActionBook. + * [**Mobile Operation**](https://github.com/sipeed/picoclaw/issues/292): Android device control (similar to BotDrop). + + +* **Multi-Agent Collaboration** + * [**Basic Multi-Agent**](https://github.com/sipeed/picoclaw/issues/294) implement + * [**Model Routing**](https://github.com/sipeed/picoclaw/issues/295): "Smart Routing" — dispatch simple tasks to small/local models (fast/cheap) and complex tasks to SOTA models (smart). + * [**Swarm Mode**](https://github.com/sipeed/picoclaw/issues/284): Collaboration between multiple PicoClaw instances on the same network. + * [**AIEOS**](https://github.com/sipeed/picoclaw/issues/296): Exploring AI-Native Operating System interaction paradigms. + + + +## 📚 5. Developer Experience (DevEx) & Documentation + +*Lowering the barrier to entry so anyone can deploy in minutes.* + +* [**QuickGuide (Zero-Config Start)**](https://github.com/sipeed/picoclaw/issues/350) + * Interactive CLI Wizard: If launched without config, automatically detect the environment and guide the user through Token/Network setup step-by-step. + + +* **Comprehensive Documentation** + * **Platform Guides**: Dedicated guides for Windows, macOS, Linux, and Android. + * **Step-by-Step Tutorials**: "Babysitter-level" guides for configuring Providers and Channels. 
+ * **AI-Assisted Docs**: Using AI to auto-generate API references and code comments (with human verification to prevent hallucinations). + + + +## 🤖 6. Engineering: AI-Powered Open Source + +*Born from Vibe Coding, we continue to use AI to accelerate development.* + +* **AI-Enhanced CI/CD** + * Integrate AI for automated Code Review, Linting, and PR Labeling. + * **Bot Noise Reduction**: Optimize bot interactions to keep PR timelines clean. + * **Issue Triage**: AI agents to analyze incoming issues and suggest preliminary fixes. + + + +## 🎨 7. Brand & Community + +* [**Logo Design**](https://github.com/sipeed/picoclaw/issues/297): We are looking for a **Mantis Shrimp (Stomatopoda)** logo design! + * *Concept*: Needs to reflect "Small but Mighty" and "Lightning Fast Strikes." + + + +--- + +### 🤝 Call for Contributions + +We welcome community contributions to any item on this roadmap! Please comment on the relevant Issue or submit a PR. Let's build the best Edge AI Agent together! \ No newline at end of file diff --git a/assets/termux.jpg b/assets/termux.jpg new file mode 100644 index 000000000..30c724a20 Binary files /dev/null and b/assets/termux.jpg differ diff --git a/assets/wechat.png b/assets/wechat.png index 73b09da68..a34217c33 100644 Binary files a/assets/wechat.png and b/assets/wechat.png differ diff --git a/cmd/picoclaw/cmd_agent.go b/cmd/picoclaw/cmd_agent.go new file mode 100644 index 000000000..6d6ff935f --- /dev/null +++ b/cmd/picoclaw/cmd_agent.go @@ -0,0 +1,181 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "bufio" + "context" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/chzyer/readline" + + "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" +) + +func agentCmd() { + message := "" + sessionKey := "cli:default" + modelOverride := "" + + args := os.Args[2:] + for i 
:= 0; i < len(args); i++ { + switch args[i] { + case "--debug", "-d": + logger.SetLevel(logger.DEBUG) + fmt.Println("🔍 Debug mode enabled") + case "-m", "--message": + if i+1 < len(args) { + message = args[i+1] + i++ + } + case "-s", "--session": + if i+1 < len(args) { + sessionKey = args[i+1] + i++ + } + case "--model", "-model": + if i+1 < len(args) { + modelOverride = args[i+1] + i++ + } + } + } + + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + os.Exit(1) + } + + if modelOverride != "" { + cfg.Agents.Defaults.Model = modelOverride + } + + provider, modelID, err := providers.CreateProvider(cfg) + if err != nil { + fmt.Printf("Error creating provider: %v\n", err) + os.Exit(1) + } + // Use the resolved model ID from provider creation + if modelID != "" { + cfg.Agents.Defaults.Model = modelID + } + + msgBus := bus.NewMessageBus() + agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) + + // Print agent startup info (only for interactive mode) + startupInfo := agentLoop.GetStartupInfo() + logger.InfoCF("agent", "Agent initialized", + map[string]any{ + "tools_count": startupInfo["tools"].(map[string]any)["count"], + "skills_total": startupInfo["skills"].(map[string]any)["total"], + "skills_available": startupInfo["skills"].(map[string]any)["available"], + }) + + if message != "" { + ctx := context.Background() + response, err := agentLoop.ProcessDirect(ctx, message, sessionKey) + if err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + fmt.Printf("\n%s %s\n", logo, response) + } else { + fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", logo) + interactiveMode(agentLoop, sessionKey) + } +} + +func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) { + prompt := fmt.Sprintf("%s You: ", logo) + + rl, err := readline.NewEx(&readline.Config{ + Prompt: prompt, + HistoryFile: filepath.Join(os.TempDir(), ".picoclaw_history"), + HistoryLimit: 100, + InterruptPrompt: "^C", + EOFPrompt: "exit", + }) + 
if err != nil { + fmt.Printf("Error initializing readline: %v\n", err) + fmt.Println("Falling back to simple input mode...") + simpleInteractiveMode(agentLoop, sessionKey) + return + } + defer rl.Close() + + for { + line, err := rl.Readline() + if err != nil { + if err == readline.ErrInterrupt || err == io.EOF { + fmt.Println("\nGoodbye!") + return + } + fmt.Printf("Error reading input: %v\n", err) + continue + } + + input := strings.TrimSpace(line) + if input == "" { + continue + } + + if input == "exit" || input == "quit" { + fmt.Println("Goodbye!") + return + } + + ctx := context.Background() + response, err := agentLoop.ProcessDirect(ctx, input, sessionKey) + if err != nil { + fmt.Printf("Error: %v\n", err) + continue + } + + fmt.Printf("\n%s %s\n\n", logo, response) + } +} + +func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) { + reader := bufio.NewReader(os.Stdin) + for { + fmt.Print(fmt.Sprintf("%s You: ", logo)) + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + fmt.Println("\nGoodbye!") + return + } + fmt.Printf("Error reading input: %v\n", err) + continue + } + + input := strings.TrimSpace(line) + if input == "" { + continue + } + + if input == "exit" || input == "quit" { + fmt.Println("Goodbye!") + return + } + + ctx := context.Background() + response, err := agentLoop.ProcessDirect(ctx, input, sessionKey) + if err != nil { + fmt.Printf("Error: %v\n", err) + continue + } + + fmt.Printf("\n%s %s\n\n", logo, response) + } +} diff --git a/cmd/picoclaw/cmd_auth.go b/cmd/picoclaw/cmd_auth.go new file mode 100644 index 000000000..729c56177 --- /dev/null +++ b/cmd/picoclaw/cmd_auth.go @@ -0,0 +1,512 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" +) + +const 
supportedProvidersMsg = "Supported providers: openai, anthropic, google-antigravity" + +func authCmd() { + if len(os.Args) < 3 { + authHelp() + return + } + + switch os.Args[2] { + case "login": + authLoginCmd() + case "logout": + authLogoutCmd() + case "status": + authStatusCmd() + case "models": + authModelsCmd() + default: + fmt.Printf("Unknown auth command: %s\n", os.Args[2]) + authHelp() + } +} + +func authHelp() { + fmt.Println("\nAuth commands:") + fmt.Println(" login Login via OAuth or paste token") + fmt.Println(" logout Remove stored credentials") + fmt.Println(" status Show current auth status") + fmt.Println(" models List available Antigravity models") + fmt.Println() + fmt.Println("Login options:") + fmt.Println(" --provider Provider to login with (openai, anthropic, google-antigravity)") + fmt.Println(" --device-code Use device code flow (for headless environments)") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw auth login --provider openai") + fmt.Println(" picoclaw auth login --provider openai --device-code") + fmt.Println(" picoclaw auth login --provider anthropic") + fmt.Println(" picoclaw auth login --provider google-antigravity") + fmt.Println(" picoclaw auth models") + fmt.Println(" picoclaw auth logout --provider openai") + fmt.Println(" picoclaw auth status") +} + +func authLoginCmd() { + provider := "" + useDeviceCode := false + + args := os.Args[3:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--provider", "-p": + if i+1 < len(args) { + provider = args[i+1] + i++ + } + case "--device-code": + useDeviceCode = true + } + } + + if provider == "" { + fmt.Println("Error: --provider is required") + fmt.Println(supportedProvidersMsg) + return + } + + switch provider { + case "openai": + authLoginOpenAI(useDeviceCode) + case "anthropic": + authLoginPasteToken(provider) + case "google-antigravity", "antigravity": + authLoginGoogleAntigravity() + default: + fmt.Printf("Unsupported provider: %s\n", provider) + 
fmt.Println(supportedProvidersMsg) + } +} + +func authLoginOpenAI(useDeviceCode bool) { + cfg := auth.OpenAIOAuthConfig() + + var cred *auth.AuthCredential + var err error + + if useDeviceCode { + cred, err = auth.LoginDeviceCode(cfg) + } else { + cred, err = auth.LoginBrowser(cfg) + } + + if err != nil { + fmt.Printf("Login failed: %v\n", err) + os.Exit(1) + } + + if err = auth.SetCredential("openai", cred); err != nil { + fmt.Printf("Failed to save credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + // Update Providers (legacy format) + appCfg.Providers.OpenAI.AuthMethod = "oauth" + + // Update or add openai in ModelList + foundOpenAI := false + for i := range appCfg.ModelList { + if isOpenAIModel(appCfg.ModelList[i].Model) { + appCfg.ModelList[i].AuthMethod = "oauth" + foundOpenAI = true + break + } + } + + // If no openai in ModelList, add it + if !foundOpenAI { + appCfg.ModelList = append(appCfg.ModelList, config.ModelConfig{ + ModelName: "gpt-5.2", + Model: "openai/gpt-5.2", + AuthMethod: "oauth", + }) + } + + // Update default model to use OpenAI + appCfg.Agents.Defaults.Model = "gpt-5.2" + + if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + fmt.Printf("Warning: could not update config: %v\n", err) + } + } + + fmt.Println("Login successful!") + if cred.AccountID != "" { + fmt.Printf("Account: %s\n", cred.AccountID) + } + fmt.Println("Default model set to: gpt-5.2") +} + +func authLoginGoogleAntigravity() { + cfg := auth.GoogleAntigravityOAuthConfig() + + cred, err := auth.LoginBrowser(cfg) + if err != nil { + fmt.Printf("Login failed: %v\n", err) + os.Exit(1) + } + + cred.Provider = "google-antigravity" + + // Fetch user email from Google userinfo + email, err := fetchGoogleUserEmail(cred.AccessToken) + if err != nil { + fmt.Printf("Warning: could not fetch email: %v\n", err) + } else { + cred.Email = email + fmt.Printf("Email: %s\n", email) + } + + // Fetch Cloud Code Assist project ID + 
projectID, err := providers.FetchAntigravityProjectID(cred.AccessToken) + if err != nil { + fmt.Printf("Warning: could not fetch project ID: %v\n", err) + fmt.Println("You may need Google Cloud Code Assist enabled on your account.") + } else { + cred.ProjectID = projectID + fmt.Printf("Project: %s\n", projectID) + } + + if err = auth.SetCredential("google-antigravity", cred); err != nil { + fmt.Printf("Failed to save credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + // Update Providers (legacy format, for backward compatibility) + appCfg.Providers.Antigravity.AuthMethod = "oauth" + + // Update or add antigravity in ModelList + foundAntigravity := false + for i := range appCfg.ModelList { + if isAntigravityModel(appCfg.ModelList[i].Model) { + appCfg.ModelList[i].AuthMethod = "oauth" + foundAntigravity = true + break + } + } + + // If no antigravity in ModelList, add it + if !foundAntigravity { + appCfg.ModelList = append(appCfg.ModelList, config.ModelConfig{ + ModelName: "gemini-flash", + Model: "antigravity/gemini-3-flash", + AuthMethod: "oauth", + }) + } + + // Update default model + appCfg.Agents.Defaults.Model = "gemini-flash" + + if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + fmt.Printf("Warning: could not update config: %v\n", err) + } + } + + fmt.Println("\n✓ Google Antigravity login successful!") + fmt.Println("Default model set to: gemini-flash") + fmt.Println("Try it: picoclaw agent -m \"Hello world\"") +} + +func fetchGoogleUserEmail(accessToken string) (string, error) { + req, err := http.NewRequest("GET", "https://www.googleapis.com/oauth2/v2/userinfo", nil) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return "", 
fmt.Errorf("userinfo request failed: %s", string(body)) + } + + var userInfo struct { + Email string `json:"email"` + } + if err := json.Unmarshal(body, &userInfo); err != nil { + return "", err + } + return userInfo.Email, nil +} + +func authLoginPasteToken(provider string) { + cred, err := auth.LoginPasteToken(provider, os.Stdin) + if err != nil { + fmt.Printf("Login failed: %v\n", err) + os.Exit(1) + } + + if err = auth.SetCredential(provider, cred); err != nil { + fmt.Printf("Failed to save credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + switch provider { + case "anthropic": + appCfg.Providers.Anthropic.AuthMethod = "token" + // Update ModelList + found := false + for i := range appCfg.ModelList { + if isAnthropicModel(appCfg.ModelList[i].Model) { + appCfg.ModelList[i].AuthMethod = "token" + found = true + break + } + } + if !found { + appCfg.ModelList = append(appCfg.ModelList, config.ModelConfig{ + ModelName: "claude-sonnet-4.6", + Model: "anthropic/claude-sonnet-4.6", + AuthMethod: "token", + }) + } + // Update default model + appCfg.Agents.Defaults.Model = "claude-sonnet-4.6" + case "openai": + appCfg.Providers.OpenAI.AuthMethod = "token" + // Update ModelList + found := false + for i := range appCfg.ModelList { + if isOpenAIModel(appCfg.ModelList[i].Model) { + appCfg.ModelList[i].AuthMethod = "token" + found = true + break + } + } + if !found { + appCfg.ModelList = append(appCfg.ModelList, config.ModelConfig{ + ModelName: "gpt-5.2", + Model: "openai/gpt-5.2", + AuthMethod: "token", + }) + } + // Update default model + appCfg.Agents.Defaults.Model = "gpt-5.2" + } + if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + fmt.Printf("Warning: could not update config: %v\n", err) + } + } + + fmt.Printf("Token saved for %s!\n", provider) + fmt.Printf("Default model set to: %s\n", appCfg.Agents.Defaults.Model) +} + +func authLogoutCmd() { + provider := "" + + args := os.Args[3:] + for i := 0; i < 
len(args); i++ { + switch args[i] { + case "--provider", "-p": + if i+1 < len(args) { + provider = args[i+1] + i++ + } + } + } + + if provider != "" { + if err := auth.DeleteCredential(provider); err != nil { + fmt.Printf("Failed to remove credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + // Clear AuthMethod in ModelList + for i := range appCfg.ModelList { + switch provider { + case "openai": + if isOpenAIModel(appCfg.ModelList[i].Model) { + appCfg.ModelList[i].AuthMethod = "" + } + case "anthropic": + if isAnthropicModel(appCfg.ModelList[i].Model) { + appCfg.ModelList[i].AuthMethod = "" + } + case "google-antigravity", "antigravity": + if isAntigravityModel(appCfg.ModelList[i].Model) { + appCfg.ModelList[i].AuthMethod = "" + } + } + } + // Clear AuthMethod in Providers (legacy) + switch provider { + case "openai": + appCfg.Providers.OpenAI.AuthMethod = "" + case "anthropic": + appCfg.Providers.Anthropic.AuthMethod = "" + case "google-antigravity", "antigravity": + appCfg.Providers.Antigravity.AuthMethod = "" + } + config.SaveConfig(getConfigPath(), appCfg) + } + + fmt.Printf("Logged out from %s\n", provider) + } else { + if err := auth.DeleteAllCredentials(); err != nil { + fmt.Printf("Failed to remove credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + // Clear all AuthMethods in ModelList + for i := range appCfg.ModelList { + appCfg.ModelList[i].AuthMethod = "" + } + // Clear all AuthMethods in Providers (legacy) + appCfg.Providers.OpenAI.AuthMethod = "" + appCfg.Providers.Anthropic.AuthMethod = "" + appCfg.Providers.Antigravity.AuthMethod = "" + config.SaveConfig(getConfigPath(), appCfg) + } + + fmt.Println("Logged out from all providers") + } +} + +func authStatusCmd() { + store, err := auth.LoadStore() + if err != nil { + fmt.Printf("Error loading auth store: %v\n", err) + return + } + + if len(store.Credentials) == 0 { + fmt.Println("No authenticated providers.") + 
fmt.Println("Run: picoclaw auth login --provider ") + return + } + + fmt.Println("\nAuthenticated Providers:") + fmt.Println("------------------------") + for provider, cred := range store.Credentials { + status := "active" + if cred.IsExpired() { + status = "expired" + } else if cred.NeedsRefresh() { + status = "needs refresh" + } + + fmt.Printf(" %s:\n", provider) + fmt.Printf(" Method: %s\n", cred.AuthMethod) + fmt.Printf(" Status: %s\n", status) + if cred.AccountID != "" { + fmt.Printf(" Account: %s\n", cred.AccountID) + } + if cred.Email != "" { + fmt.Printf(" Email: %s\n", cred.Email) + } + if cred.ProjectID != "" { + fmt.Printf(" Project: %s\n", cred.ProjectID) + } + if !cred.ExpiresAt.IsZero() { + fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04")) + } + } +} + +func authModelsCmd() { + cred, err := auth.GetCredential("google-antigravity") + if err != nil || cred == nil { + fmt.Println("Not logged in to Google Antigravity.") + fmt.Println("Run: picoclaw auth login --provider google-antigravity") + return + } + + // Refresh token if needed + if cred.NeedsRefresh() && cred.RefreshToken != "" { + oauthCfg := auth.GoogleAntigravityOAuthConfig() + refreshed, refreshErr := auth.RefreshAccessToken(cred, oauthCfg) + if refreshErr == nil { + cred = refreshed + _ = auth.SetCredential("google-antigravity", cred) + } + } + + projectID := cred.ProjectID + if projectID == "" { + fmt.Println("No project ID stored. 
Try logging in again.") + return + } + + fmt.Printf("Fetching models for project: %s\n\n", projectID) + + models, err := providers.FetchAntigravityModels(cred.AccessToken, projectID) + if err != nil { + fmt.Printf("Error fetching models: %v\n", err) + return + } + + if len(models) == 0 { + fmt.Println("No models available.") + return + } + + fmt.Println("Available Antigravity Models:") + fmt.Println("-----------------------------") + for _, m := range models { + status := "✓" + if m.IsExhausted { + status = "✗ (quota exhausted)" + } + name := m.ID + if m.DisplayName != "" { + name = fmt.Sprintf("%s (%s)", m.ID, m.DisplayName) + } + fmt.Printf(" %s %s\n", status, name) + } +} + +// isAntigravityModel checks if a model string belongs to antigravity provider +func isAntigravityModel(model string) bool { + return model == "antigravity" || + model == "google-antigravity" || + strings.HasPrefix(model, "antigravity/") || + strings.HasPrefix(model, "google-antigravity/") +} + +// isOpenAIModel checks if a model string belongs to openai provider +func isOpenAIModel(model string) bool { + return model == "openai" || + strings.HasPrefix(model, "openai/") +} + +// isAnthropicModel checks if a model string belongs to anthropic provider +func isAnthropicModel(model string) bool { + return model == "anthropic" || + strings.HasPrefix(model, "anthropic/") +} diff --git a/cmd/picoclaw/cmd_cron.go b/cmd/picoclaw/cmd_cron.go new file mode 100644 index 000000000..8c42bde06 --- /dev/null +++ b/cmd/picoclaw/cmd_cron.go @@ -0,0 +1,227 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "fmt" + "os" + "path/filepath" + "time" + + "github.com/sipeed/picoclaw/pkg/cron" +) + +func cronCmd() { + if len(os.Args) < 3 { + cronHelp() + return + } + + subcommand := os.Args[2] + + // Load config to get workspace path + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + return + } + + cronStorePath := 
filepath.Join(cfg.WorkspacePath(), "cron", "jobs.json") + + switch subcommand { + case "list": + cronListCmd(cronStorePath) + case "add": + cronAddCmd(cronStorePath) + case "remove": + if len(os.Args) < 4 { + fmt.Println("Usage: picoclaw cron remove ") + return + } + cronRemoveCmd(cronStorePath, os.Args[3]) + case "enable": + cronEnableCmd(cronStorePath, false) + case "disable": + cronEnableCmd(cronStorePath, true) + default: + fmt.Printf("Unknown cron command: %s\n", subcommand) + cronHelp() + } +} + +func cronHelp() { + fmt.Println("\nCron commands:") + fmt.Println(" list List all scheduled jobs") + fmt.Println(" add Add a new scheduled job") + fmt.Println(" remove Remove a job by ID") + fmt.Println(" enable Enable a job") + fmt.Println(" disable Disable a job") + fmt.Println() + fmt.Println("Add options:") + fmt.Println(" -n, --name Job name") + fmt.Println(" -m, --message Message for agent") + fmt.Println(" -e, --every Run every N seconds") + fmt.Println(" -c, --cron Cron expression (e.g. 
'0 9 * * *')") + fmt.Println(" -d, --deliver Deliver response to channel") + fmt.Println(" --to Recipient for delivery") + fmt.Println(" --channel Channel for delivery") +} + +func cronListCmd(storePath string) { + cs := cron.NewCronService(storePath, nil) + jobs := cs.ListJobs(true) // Show all jobs, including disabled + + if len(jobs) == 0 { + fmt.Println("No scheduled jobs.") + return + } + + fmt.Println("\nScheduled Jobs:") + fmt.Println("----------------") + for _, job := range jobs { + var schedule string + if job.Schedule.Kind == "every" && job.Schedule.EveryMS != nil { + schedule = fmt.Sprintf("every %ds", *job.Schedule.EveryMS/1000) + } else if job.Schedule.Kind == "cron" { + schedule = job.Schedule.Expr + } else { + schedule = "one-time" + } + + nextRun := "scheduled" + if job.State.NextRunAtMS != nil { + nextTime := time.UnixMilli(*job.State.NextRunAtMS) + nextRun = nextTime.Format("2006-01-02 15:04") + } + + status := "enabled" + if !job.Enabled { + status = "disabled" + } + + fmt.Printf(" %s (%s)\n", job.Name, job.ID) + fmt.Printf(" Schedule: %s\n", schedule) + fmt.Printf(" Status: %s\n", status) + fmt.Printf(" Next run: %s\n", nextRun) + } +} + +func cronAddCmd(storePath string) { + name := "" + message := "" + var everySec *int64 + cronExpr := "" + deliver := false + channel := "" + to := "" + + args := os.Args[3:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "-n", "--name": + if i+1 < len(args) { + name = args[i+1] + i++ + } + case "-m", "--message": + if i+1 < len(args) { + message = args[i+1] + i++ + } + case "-e", "--every": + if i+1 < len(args) { + var sec int64 + fmt.Sscanf(args[i+1], "%d", &sec) + everySec = &sec + i++ + } + case "-c", "--cron": + if i+1 < len(args) { + cronExpr = args[i+1] + i++ + } + case "-d", "--deliver": + deliver = true + case "--to": + if i+1 < len(args) { + to = args[i+1] + i++ + } + case "--channel": + if i+1 < len(args) { + channel = args[i+1] + i++ + } + } + } + + if name == "" { + 
fmt.Println("Error: --name is required") + return + } + + if message == "" { + fmt.Println("Error: --message is required") + return + } + + if everySec == nil && cronExpr == "" { + fmt.Println("Error: Either --every or --cron must be specified") + return + } + + var schedule cron.CronSchedule + if everySec != nil { + everyMS := *everySec * 1000 + schedule = cron.CronSchedule{ + Kind: "every", + EveryMS: &everyMS, + } + } else { + schedule = cron.CronSchedule{ + Kind: "cron", + Expr: cronExpr, + } + } + + cs := cron.NewCronService(storePath, nil) + job, err := cs.AddJob(name, schedule, message, deliver, channel, to) + if err != nil { + fmt.Printf("Error adding job: %v\n", err) + return + } + + fmt.Printf("✓ Added job '%s' (%s)\n", job.Name, job.ID) +} + +func cronRemoveCmd(storePath, jobID string) { + cs := cron.NewCronService(storePath, nil) + if cs.RemoveJob(jobID) { + fmt.Printf("✓ Removed job %s\n", jobID) + } else { + fmt.Printf("✗ Job %s not found\n", jobID) + } +} + +func cronEnableCmd(storePath string, disable bool) { + if len(os.Args) < 4 { + fmt.Println("Usage: picoclaw cron enable/disable ") + return + } + + jobID := os.Args[3] + cs := cron.NewCronService(storePath, nil) + enabled := !disable + + job := cs.EnableJob(jobID, enabled) + if job != nil { + status := "enabled" + if disable { + status = "disabled" + } + fmt.Printf("✓ Job '%s' %s\n", job.Name, status) + } else { + fmt.Printf("✗ Job %s not found\n", jobID) + } +} diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go new file mode 100644 index 000000000..28ef76ad3 --- /dev/null +++ b/cmd/picoclaw/cmd_gateway.go @@ -0,0 +1,248 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "context" + "fmt" + "net/http" + "os" + "os/signal" + "path/filepath" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + 
"github.com/sipeed/picoclaw/pkg/cron" + "github.com/sipeed/picoclaw/pkg/devices" + "github.com/sipeed/picoclaw/pkg/health" + "github.com/sipeed/picoclaw/pkg/heartbeat" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/state" + "github.com/sipeed/picoclaw/pkg/tools" + "github.com/sipeed/picoclaw/pkg/voice" +) + +func gatewayCmd() { + // Check for --debug flag + args := os.Args[2:] + for _, arg := range args { + if arg == "--debug" || arg == "-d" { + logger.SetLevel(logger.DEBUG) + fmt.Println("🔍 Debug mode enabled") + break + } + } + + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + os.Exit(1) + } + + provider, modelID, err := providers.CreateProvider(cfg) + if err != nil { + fmt.Printf("Error creating provider: %v\n", err) + os.Exit(1) + } + // Use the resolved model ID from provider creation + if modelID != "" { + cfg.Agents.Defaults.Model = modelID + } + + msgBus := bus.NewMessageBus() + agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) + + // Print agent startup info + fmt.Println("\n📦 Agent Status:") + startupInfo := agentLoop.GetStartupInfo() + toolsInfo := startupInfo["tools"].(map[string]any) + skillsInfo := startupInfo["skills"].(map[string]any) + fmt.Printf(" • Tools: %d loaded\n", toolsInfo["count"]) + fmt.Printf(" • Skills: %d/%d available\n", + skillsInfo["available"], + skillsInfo["total"]) + + // Log to file as well + logger.InfoCF("agent", "Agent initialized", + map[string]any{ + "tools_count": toolsInfo["count"], + "skills_total": skillsInfo["total"], + "skills_available": skillsInfo["available"], + }) + + // Setup cron tool and service + execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute + cronService := setupCronTool( + agentLoop, + msgBus, + cfg.WorkspacePath(), + cfg.Agents.Defaults.RestrictToWorkspace, + execTimeout, + cfg, + ) + + heartbeatService := heartbeat.NewHeartbeatService( + 
cfg.WorkspacePath(), + cfg.Heartbeat.Interval, + cfg.Heartbeat.Enabled, + ) + heartbeatService.SetBus(msgBus) + heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + // Use cli:direct as fallback if no valid channel + if channel == "" || chatID == "" { + channel, chatID = "cli", "direct" + } + // Use ProcessHeartbeat - no session history, each heartbeat is independent + var response string + response, err = agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID) + if err != nil { + return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err)) + } + if response == "HEARTBEAT_OK" { + return tools.SilentResult("Heartbeat OK") + } + // For heartbeat, always return silent - the subagent result will be + // sent to user via processSystemMessage when the async task completes + return tools.SilentResult(response) + }) + + channelManager, err := channels.NewManager(cfg, msgBus) + if err != nil { + fmt.Printf("Error creating channel manager: %v\n", err) + os.Exit(1) + } + + // Inject channel manager into agent loop for command handling + agentLoop.SetChannelManager(channelManager) + + var transcriber *voice.GroqTranscriber + groqAPIKey := cfg.Providers.Groq.APIKey + if groqAPIKey == "" { + for _, mc := range cfg.ModelList { + if strings.HasPrefix(mc.Model, "groq/") && mc.APIKey != "" { + groqAPIKey = mc.APIKey + break + } + } + } + if groqAPIKey != "" { + transcriber = voice.NewGroqTranscriber(groqAPIKey) + logger.InfoC("voice", "Groq voice transcription enabled") + } + + if transcriber != nil { + if telegramChannel, ok := channelManager.GetChannel("telegram"); ok { + if tc, ok := telegramChannel.(*channels.TelegramChannel); ok { + tc.SetTranscriber(transcriber) + logger.InfoC("voice", "Groq transcription attached to Telegram channel") + } + } + if discordChannel, ok := channelManager.GetChannel("discord"); ok { + if dc, ok := discordChannel.(*channels.DiscordChannel); ok { + dc.SetTranscriber(transcriber) + 
logger.InfoC("voice", "Groq transcription attached to Discord channel") + } + } + if slackChannel, ok := channelManager.GetChannel("slack"); ok { + if sc, ok := slackChannel.(*channels.SlackChannel); ok { + sc.SetTranscriber(transcriber) + logger.InfoC("voice", "Groq transcription attached to Slack channel") + } + } + } + + enabledChannels := channelManager.GetEnabledChannels() + if len(enabledChannels) > 0 { + fmt.Printf("✓ Channels enabled: %s\n", enabledChannels) + } else { + fmt.Println("⚠ Warning: No channels enabled") + } + + fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port) + fmt.Println("Press Ctrl+C to stop") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := cronService.Start(); err != nil { + fmt.Printf("Error starting cron service: %v\n", err) + } + fmt.Println("✓ Cron service started") + + if err := heartbeatService.Start(); err != nil { + fmt.Printf("Error starting heartbeat service: %v\n", err) + } + fmt.Println("✓ Heartbeat service started") + + stateManager := state.NewManager(cfg.WorkspacePath()) + deviceService := devices.NewService(devices.Config{ + Enabled: cfg.Devices.Enabled, + MonitorUSB: cfg.Devices.MonitorUSB, + }, stateManager) + deviceService.SetBus(msgBus) + if err := deviceService.Start(ctx); err != nil { + fmt.Printf("Error starting device service: %v\n", err) + } else if cfg.Devices.Enabled { + fmt.Println("✓ Device event service started") + } + + if err := channelManager.StartAll(ctx); err != nil { + fmt.Printf("Error starting channels: %v\n", err) + } + + healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) + go func() { + if err := healthServer.Start(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("health", "Health server error", map[string]any{"error": err.Error()}) + } + }() + fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port) + + go agentLoop.Run(ctx) + + sigChan := 
make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt) + <-sigChan + + fmt.Println("\nShutting down...") + cancel() + healthServer.Stop(context.Background()) + deviceService.Stop() + heartbeatService.Stop() + cronService.Stop() + agentLoop.Stop() + channelManager.StopAll(ctx) + fmt.Println("✓ Gateway stopped") +} + +func setupCronTool( + agentLoop *agent.AgentLoop, + msgBus *bus.MessageBus, + workspace string, + restrict bool, + execTimeout time.Duration, + cfg *config.Config, +) *cron.CronService { + cronStorePath := filepath.Join(workspace, "cron", "jobs.json") + + // Create cron service + cronService := cron.NewCronService(cronStorePath, nil) + + // Create and register CronTool + cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg) + agentLoop.RegisterTool(cronTool) + + // Set the onJob handler + cronService.SetOnJob(func(job *cron.CronJob) (string, error) { + result := cronTool.ExecuteJob(context.Background(), job) + return result, nil + }) + + return cronService +} diff --git a/cmd/picoclaw/cmd_migrate.go b/cmd/picoclaw/cmd_migrate.go new file mode 100644 index 000000000..86d4903ef --- /dev/null +++ b/cmd/picoclaw/cmd_migrate.go @@ -0,0 +1,81 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "fmt" + "os" + + "github.com/sipeed/picoclaw/pkg/migrate" +) + +func migrateCmd() { + if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") { + migrateHelp() + return + } + + opts := migrate.Options{} + + args := os.Args[2:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--dry-run": + opts.DryRun = true + case "--config-only": + opts.ConfigOnly = true + case "--workspace-only": + opts.WorkspaceOnly = true + case "--force": + opts.Force = true + case "--refresh": + opts.Refresh = true + case "--openclaw-home": + if i+1 < len(args) { + opts.OpenClawHome = args[i+1] + i++ + } + case "--picoclaw-home": + if i+1 < len(args) { + 
opts.PicoClawHome = args[i+1] + i++ + } + default: + fmt.Printf("Unknown flag: %s\n", args[i]) + migrateHelp() + os.Exit(1) + } + } + + result, err := migrate.Run(opts) + if err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + + if !opts.DryRun { + migrate.PrintSummary(result) + } +} + +func migrateHelp() { + fmt.Println("\nMigrate from OpenClaw to PicoClaw") + fmt.Println() + fmt.Println("Usage: picoclaw migrate [options]") + fmt.Println() + fmt.Println("Options:") + fmt.Println(" --dry-run Show what would be migrated without making changes") + fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)") + fmt.Println(" --config-only Only migrate config, skip workspace files") + fmt.Println(" --workspace-only Only migrate workspace files, skip config") + fmt.Println(" --force Skip confirmation prompts") + fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)") + fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw") + fmt.Println(" picoclaw migrate --dry-run Show what would be migrated") + fmt.Println(" picoclaw migrate --refresh Re-sync workspace files") + fmt.Println(" picoclaw migrate --force Migrate without confirmation") +} diff --git a/cmd/picoclaw/cmd_onboard.go b/cmd/picoclaw/cmd_onboard.go new file mode 100644 index 000000000..1a9ebad61 --- /dev/null +++ b/cmd/picoclaw/cmd_onboard.go @@ -0,0 +1,108 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "embed" + "fmt" + "io/fs" + "os" + "path/filepath" + + "github.com/sipeed/picoclaw/pkg/config" +) + +//go:generate cp -r ../../workspace . 
+//go:embed workspace +var embeddedFiles embed.FS + +func onboard() { + configPath := getConfigPath() + + if _, err := os.Stat(configPath); err == nil { + fmt.Printf("Config already exists at %s\n", configPath) + fmt.Print("Overwrite? (y/n): ") + var response string + fmt.Scanln(&response) + if response != "y" { + fmt.Println("Aborted.") + return + } + } + + cfg := config.DefaultConfig() + if err := config.SaveConfig(configPath, cfg); err != nil { + fmt.Printf("Error saving config: %v\n", err) + os.Exit(1) + } + + workspace := cfg.WorkspacePath() + createWorkspaceTemplates(workspace) + + fmt.Printf("%s picoclaw is ready!\n", logo) + fmt.Println("\nNext steps:") + fmt.Println(" 1. Add your API key to", configPath) + fmt.Println("") + fmt.Println(" Recommended:") + fmt.Println(" - OpenRouter: https://openrouter.ai/keys (access 100+ models)") + fmt.Println(" - Ollama: https://ollama.com (local, free)") + fmt.Println("") + fmt.Println(" See README.md for 17+ supported providers.") + fmt.Println("") + fmt.Println(" 2. 
Chat: picoclaw agent -m \"Hello!\"") +} + +func copyEmbeddedToTarget(targetDir string) error { + // Ensure target directory exists + if err := os.MkdirAll(targetDir, 0o755); err != nil { + return fmt.Errorf("Failed to create target directory: %w", err) + } + + // Walk through all files in embed.FS + err := fs.WalkDir(embeddedFiles, "workspace", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + // Skip directories + if d.IsDir() { + return nil + } + + // Read embedded file + data, err := embeddedFiles.ReadFile(path) + if err != nil { + return fmt.Errorf("Failed to read embedded file %s: %w", path, err) + } + + new_path, err := filepath.Rel("workspace", path) + if err != nil { + return fmt.Errorf("Failed to get relative path for %s: %v\n", path, err) + } + + // Build target file path + targetPath := filepath.Join(targetDir, new_path) + + // Ensure target file's directory exists + if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil { + return fmt.Errorf("Failed to create directory %s: %w", filepath.Dir(targetPath), err) + } + + // Write file + if err := os.WriteFile(targetPath, data, 0o644); err != nil { + return fmt.Errorf("Failed to write file %s: %w", targetPath, err) + } + + return nil + }) + + return err +} + +func createWorkspaceTemplates(workspace string) { + err := copyEmbeddedToTarget(workspace) + if err != nil { + fmt.Printf("Error copying workspace templates: %v\n", err) + } +} diff --git a/cmd/picoclaw/cmd_skills.go b/cmd/picoclaw/cmd_skills.go new file mode 100644 index 000000000..0814494b3 --- /dev/null +++ b/cmd/picoclaw/cmd_skills.go @@ -0,0 +1,305 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/skills" + "github.com/sipeed/picoclaw/pkg/utils" +) + +func skillsHelp() { + fmt.Println("\nSkills commands:") + 
fmt.Println(" list List installed skills") + fmt.Println(" install Install skill from GitHub") + fmt.Println(" install-builtin Install all builtin skills to workspace") + fmt.Println(" list-builtin List available builtin skills") + fmt.Println(" remove Remove installed skill") + fmt.Println(" search Search available skills") + fmt.Println(" show Show skill details") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw skills list") + fmt.Println(" picoclaw skills install sipeed/picoclaw-skills/weather") + fmt.Println(" picoclaw skills install-builtin") + fmt.Println(" picoclaw skills list-builtin") + fmt.Println(" picoclaw skills remove weather") + fmt.Println(" picoclaw skills install --registry clawhub github") +} + +func skillsListCmd(loader *skills.SkillsLoader) { + allSkills := loader.ListSkills() + + if len(allSkills) == 0 { + fmt.Println("No skills installed.") + return + } + + fmt.Println("\nInstalled Skills:") + fmt.Println("------------------") + for _, skill := range allSkills { + fmt.Printf(" ✓ %s (%s)\n", skill.Name, skill.Source) + if skill.Description != "" { + fmt.Printf(" %s\n", skill.Description) + } + } +} + +func skillsInstallCmd(installer *skills.SkillInstaller, cfg *config.Config) { + if len(os.Args) < 4 { + fmt.Println("Usage: picoclaw skills install ") + fmt.Println(" picoclaw skills install --registry ") + return + } + + // Check for --registry flag. + if os.Args[3] == "--registry" { + if len(os.Args) < 6 { + fmt.Println("Usage: picoclaw skills install --registry ") + fmt.Println("Example: picoclaw skills install --registry clawhub github") + return + } + registryName := os.Args[4] + slug := os.Args[5] + skillsInstallFromRegistry(cfg, registryName, slug) + return + } + + // Default: install from GitHub (backward compatible). 
+ repo := os.Args[3] + fmt.Printf("Installing skill from %s...\n", repo) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := installer.InstallFromGitHub(ctx, repo); err != nil { + fmt.Printf("\u2717 Failed to install skill: %v\n", err) + os.Exit(1) + } + + fmt.Printf("\u2713 Skill '%s' installed successfully!\n", filepath.Base(repo)) +} + +// skillsInstallFromRegistry installs a skill from a named registry (e.g. clawhub). +func skillsInstallFromRegistry(cfg *config.Config, registryName, slug string) { + err := utils.ValidateSkillIdentifier(registryName) + if err != nil { + fmt.Printf("\u2717 Invalid registry name: %v\n", err) + os.Exit(1) + } + + err = utils.ValidateSkillIdentifier(slug) + if err != nil { + fmt.Printf("\u2717 Invalid slug: %v\n", err) + os.Exit(1) + } + + fmt.Printf("Installing skill '%s' from %s registry...\n", slug, registryName) + + registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{ + MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches, + ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub), + }) + + registry := registryMgr.GetRegistry(registryName) + if registry == nil { + fmt.Printf("\u2717 Registry '%s' not found or not enabled. 
Check your config.json.\n", registryName) + os.Exit(1) + } + + workspace := cfg.WorkspacePath() + targetDir := filepath.Join(workspace, "skills", slug) + + if _, err = os.Stat(targetDir); err == nil { + fmt.Printf("\u2717 Skill '%s' already installed at %s\n", slug, targetDir) + os.Exit(1) + } + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + if err = os.MkdirAll(filepath.Join(workspace, "skills"), 0o755); err != nil { + fmt.Printf("\u2717 Failed to create skills directory: %v\n", err) + os.Exit(1) + } + + result, err := registry.DownloadAndInstall(ctx, slug, "", targetDir) + if err != nil { + rmErr := os.RemoveAll(targetDir) + if rmErr != nil { + fmt.Printf("\u2717 Failed to remove partial install: %v\n", rmErr) + } + fmt.Printf("\u2717 Failed to install skill: %v\n", err) + os.Exit(1) + } + + if result.IsMalwareBlocked { + rmErr := os.RemoveAll(targetDir) + if rmErr != nil { + fmt.Printf("\u2717 Failed to remove partial install: %v\n", rmErr) + } + fmt.Printf("\u2717 Skill '%s' is flagged as malicious and cannot be installed.\n", slug) + os.Exit(1) + } + + if result.IsSuspicious { + fmt.Printf("\u26a0\ufe0f Warning: skill '%s' is flagged as suspicious.\n", slug) + } + + fmt.Printf("\u2713 Skill '%s' v%s installed successfully!\n", slug, result.Version) + if result.Summary != "" { + fmt.Printf(" %s\n", result.Summary) + } +} + +func skillsRemoveCmd(installer *skills.SkillInstaller, skillName string) { + fmt.Printf("Removing skill '%s'...\n", skillName) + + if err := installer.Uninstall(skillName); err != nil { + fmt.Printf("✗ Failed to remove skill: %v\n", err) + os.Exit(1) + } + + fmt.Printf("✓ Skill '%s' removed successfully!\n", skillName) +} + +func skillsInstallBuiltinCmd(workspace string) { + builtinSkillsDir := "./picoclaw/skills" + workspaceSkillsDir := filepath.Join(workspace, "skills") + + fmt.Printf("Copying builtin skills to workspace...\n") + + skillsToInstall := []string{ + "weather", + "news", + 
"stock", + "calculator", + } + + for _, skillName := range skillsToInstall { + builtinPath := filepath.Join(builtinSkillsDir, skillName) + workspacePath := filepath.Join(workspaceSkillsDir, skillName) + + if _, err := os.Stat(builtinPath); err != nil { + fmt.Printf("⊘ Builtin skill '%s' not found: %v\n", skillName, err) + continue + } + + if err := os.MkdirAll(workspacePath, 0o755); err != nil { + fmt.Printf("✗ Failed to create directory for %s: %v\n", skillName, err) + continue + } + + if err := copyDirectory(builtinPath, workspacePath); err != nil { + fmt.Printf("✗ Failed to copy %s: %v\n", skillName, err) + } + } + + fmt.Println("\n✓ All builtin skills installed!") + fmt.Println("Now you can use them in your workspace.") +} + +func skillsListBuiltinCmd() { + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + return + } + builtinSkillsDir := filepath.Join(filepath.Dir(cfg.WorkspacePath()), "picoclaw", "skills") + + fmt.Println("\nAvailable Builtin Skills:") + fmt.Println("-----------------------") + + entries, err := os.ReadDir(builtinSkillsDir) + if err != nil { + fmt.Printf("Error reading builtin skills: %v\n", err) + return + } + + if len(entries) == 0 { + fmt.Println("No builtin skills available.") + return + } + + for _, entry := range entries { + if entry.IsDir() { + skillName := entry.Name() + skillFile := filepath.Join(builtinSkillsDir, skillName, "SKILL.md") + + description := "No description" + if _, err := os.Stat(skillFile); err == nil { + data, err := os.ReadFile(skillFile) + if err == nil { + content := string(data) + if idx := strings.Index(content, "\n"); idx > 0 { + firstLine := content[:idx] + if strings.Contains(firstLine, "description:") { + descLine := strings.Index(content[idx:], "\n") + if descLine > 0 { + description = strings.TrimSpace(content[idx+descLine : idx+descLine]) + } + } + } + } + } + status := "✓" + fmt.Printf(" %s %s\n", status, entry.Name()) + if description != "" { + fmt.Printf(" 
%s\n", description) + } + } + } +} + +func skillsSearchCmd(installer *skills.SkillInstaller) { + fmt.Println("Searching for available skills...") + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + availableSkills, err := installer.ListAvailableSkills(ctx) + if err != nil { + fmt.Printf("✗ Failed to fetch skills list: %v\n", err) + return + } + + if len(availableSkills) == 0 { + fmt.Println("No skills available.") + return + } + + fmt.Printf("\nAvailable Skills (%d):\n", len(availableSkills)) + fmt.Println("--------------------") + for _, skill := range availableSkills { + fmt.Printf(" 📦 %s\n", skill.Name) + fmt.Printf(" %s\n", skill.Description) + fmt.Printf(" Repo: %s\n", skill.Repository) + if skill.Author != "" { + fmt.Printf(" Author: %s\n", skill.Author) + } + if len(skill.Tags) > 0 { + fmt.Printf(" Tags: %v\n", skill.Tags) + } + fmt.Println() + } +} + +func skillsShowCmd(loader *skills.SkillsLoader, skillName string) { + content, ok := loader.LoadSkill(skillName) + if !ok { + fmt.Printf("✗ Skill '%s' not found\n", skillName) + return + } + + fmt.Printf("\n📦 Skill: %s\n", skillName) + fmt.Println("----------------------") + fmt.Println(content) +} diff --git a/cmd/picoclaw/cmd_status.go b/cmd/picoclaw/cmd_status.go new file mode 100644 index 000000000..07296784e --- /dev/null +++ b/cmd/picoclaw/cmd_status.go @@ -0,0 +1,102 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT + +package main + +import ( + "fmt" + "os" + + "github.com/sipeed/picoclaw/pkg/auth" +) + +func statusCmd() { + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + return + } + + configPath := getConfigPath() + + fmt.Printf("%s picoclaw Status\n", logo) + fmt.Printf("Version: %s\n", formatVersion()) + build, _ := formatBuildInfo() + if build != "" { + fmt.Printf("Build: %s\n", build) + } + fmt.Println() + + if _, err := os.Stat(configPath); err == nil { + fmt.Println("Config:", 
configPath, "✓") + } else { + fmt.Println("Config:", configPath, "✗") + } + + workspace := cfg.WorkspacePath() + if _, err := os.Stat(workspace); err == nil { + fmt.Println("Workspace:", workspace, "✓") + } else { + fmt.Println("Workspace:", workspace, "✗") + } + + if _, err := os.Stat(configPath); err == nil { + fmt.Printf("Model: %s\n", cfg.Agents.Defaults.Model) + + hasOpenRouter := cfg.Providers.OpenRouter.APIKey != "" + hasAnthropic := cfg.Providers.Anthropic.APIKey != "" + hasOpenAI := cfg.Providers.OpenAI.APIKey != "" + hasGemini := cfg.Providers.Gemini.APIKey != "" + hasZhipu := cfg.Providers.Zhipu.APIKey != "" + hasQwen := cfg.Providers.Qwen.APIKey != "" + hasGroq := cfg.Providers.Groq.APIKey != "" + hasVLLM := cfg.Providers.VLLM.APIBase != "" + hasMoonshot := cfg.Providers.Moonshot.APIKey != "" + hasDeepSeek := cfg.Providers.DeepSeek.APIKey != "" + hasVolcEngine := cfg.Providers.VolcEngine.APIKey != "" + hasNvidia := cfg.Providers.Nvidia.APIKey != "" + hasOllama := cfg.Providers.Ollama.APIBase != "" + + status := func(enabled bool) string { + if enabled { + return "✓" + } + return "not set" + } + fmt.Println("OpenRouter API:", status(hasOpenRouter)) + fmt.Println("Anthropic API:", status(hasAnthropic)) + fmt.Println("OpenAI API:", status(hasOpenAI)) + fmt.Println("Gemini API:", status(hasGemini)) + fmt.Println("Zhipu API:", status(hasZhipu)) + fmt.Println("Qwen API:", status(hasQwen)) + fmt.Println("Groq API:", status(hasGroq)) + fmt.Println("Moonshot API:", status(hasMoonshot)) + fmt.Println("DeepSeek API:", status(hasDeepSeek)) + fmt.Println("VolcEngine API:", status(hasVolcEngine)) + fmt.Println("Nvidia API:", status(hasNvidia)) + if hasVLLM { + fmt.Printf("vLLM/Local: ✓ %s\n", cfg.Providers.VLLM.APIBase) + } else { + fmt.Println("vLLM/Local: not set") + } + if hasOllama { + fmt.Printf("Ollama: ✓ %s\n", cfg.Providers.Ollama.APIBase) + } else { + fmt.Println("Ollama: not set") + } + + store, _ := auth.LoadStore() + if store != nil && 
len(store.Credentials) > 0 { + fmt.Println("\nOAuth/Token Auth:") + for provider, cred := range store.Credentials { + status := "authenticated" + if cred.IsExpired() { + status = "expired" + } else if cred.NeedsRefresh() { + status = "needs refresh" + } + fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status) + } + } + } +} diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index 8c001109f..2b220caa8 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -7,50 +7,67 @@ package main import ( - "bufio" "context" + "encoding/json" "fmt" "io" "os" + "os/exec" "os/signal" "path/filepath" "runtime" "strings" + "sync" "time" - "github.com/chzyer/readline" + "github.com/nats-io/nats.go" "github.com/sipeed/picoclaw/pkg/agent" - "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/cron" - "github.com/sipeed/picoclaw/pkg/heartbeat" "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/migrate" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/skills" + "github.com/sipeed/picoclaw/pkg/swarm" "github.com/sipeed/picoclaw/pkg/tools" - "github.com/sipeed/picoclaw/pkg/voice" ) var ( version = "dev" + gitCommit string buildTime string goVersion string ) const logo = "🦞" -func printVersion() { - fmt.Printf("%s picoclaw %s\n", logo, version) +// formatVersion returns the version string with optional git commit +func formatVersion() string { + v := version + if gitCommit != "" { + v += fmt.Sprintf(" (git: %s)", gitCommit) + } + return v +} + +// formatBuildInfo returns build time and go version info +func formatBuildInfo() (build string, goVer string) { if buildTime != "" { - fmt.Printf(" Build: %s\n", buildTime) + build = buildTime } - goVer := goVersion + goVer = goVersion if goVer == "" { goVer = runtime.Version() } + return +} + +func printVersion() { + fmt.Printf("%s picoclaw 
%s\n", logo, formatVersion()) + build, goVer := formatBuildInfo() + if build != "" { + fmt.Printf(" Build: %s\n", build) + } if goVer != "" { fmt.Printf(" Go: %s\n", goVer) } @@ -113,6 +130,8 @@ func main() { authCmd() case "cron": cronCmd() + case "swarm": + swarmCmd() case "skills": if len(os.Args) < 3 { skillsHelp() @@ -139,7 +158,7 @@ func main() { case "list": skillsListCmd(skillsLoader) case "install": - skillsInstallCmd(installer) + skillsInstallCmd(installer, cfg) case "remove", "uninstall": if len(os.Args) < 4 { fmt.Println("Usage: picoclaw skills remove ") @@ -184,1347 +203,1017 @@ func printHelp() { fmt.Println(" cron Manage scheduled tasks") fmt.Println(" migrate Migrate from OpenClaw to PicoClaw") fmt.Println(" skills Manage skills (install, list, remove)") + fmt.Println(" swarm Run in swarm mode (multi-instance collaboration)") fmt.Println(" version Show version information") } -func onboard() { - configPath := getConfigPath() - - if _, err := os.Stat(configPath); err == nil { - fmt.Printf("Config already exists at %s\n", configPath) - fmt.Print("Overwrite? (y/n): ") - var response string - fmt.Scanln(&response) - if response != "y" { - fmt.Println("Aborted.") - return - } - } +func getConfigPath() string { + home, _ := os.UserHomeDir() + return filepath.Join(home, ".picoclaw", "config.json") +} - cfg := config.DefaultConfig() - if err := config.SaveConfig(configPath, cfg); err != nil { - fmt.Printf("Error saving config: %v\n", err) - os.Exit(1) +func loadConfig() (*config.Config, error) { + return config.LoadConfig(getConfigPath()) +} +func swarmCmd() { + if len(os.Args) < 3 { + swarmHelp() + return } - workspace := cfg.WorkspacePath() - os.MkdirAll(workspace, 0755) - os.MkdirAll(filepath.Join(workspace, "memory"), 0755) - os.MkdirAll(filepath.Join(workspace, "skills"), 0755) - - createWorkspaceTemplates(workspace) + subcommand := os.Args[2] - fmt.Printf("%s picoclaw is ready!\n", logo) - fmt.Println("\nNext steps:") - fmt.Println(" 1. 
Add your API key to", configPath) - fmt.Println(" Get one at: https://openrouter.ai/keys") - fmt.Println(" 2. Chat: picoclaw agent -m \"Hello!\"") + switch subcommand { + case "start": + swarmStartCmd() + case "stop": + swarmStopCmd() + case "dispatch": + swarmDispatchCmd() + case "status": + swarmStatusCmd() + case "nodes": + swarmNodesCmd() + case "result": + swarmResultCmd() + default: + fmt.Printf("Unknown swarm command: %s\n", subcommand) + swarmHelp() + } } -func createWorkspaceTemplates(workspace string) { - templates := map[string]string{ - "AGENTS.md": `# Agent Instructions - -You are a helpful AI assistant. Be concise, accurate, and friendly. - -## Guidelines - -- Always explain what you're doing before taking actions -- Ask for clarification when request is ambiguous -- Use tools to help accomplish tasks -- Remember important information in your memory files -- Be proactive and helpful -- Learn from user feedback -`, - "SOUL.md": `# Soul - -I am picoclaw, a lightweight AI assistant powered by AI. - -## Personality - -- Helpful and friendly -- Concise and to the point -- Curious and eager to learn -- Honest and transparent - -## Values - -- Accuracy over speed -- User privacy and safety -- Transparency in actions -- Continuous improvement -`, - "USER.md": `# User - -Information about user goes here. - -## Preferences - -- Communication style: (casual/formal) -- Timezone: (your timezone) -- Language: (your preferred language) - -## Personal Information - -- Name: (optional) -- Location: (optional) -- Occupation: (optional) - -## Learning Goals - -- What the user wants to learn from AI -- Preferred interaction style -- Areas of interest -`, - "IDENTITY.md": `# Identity - -## Name -PicoClaw 🦞 - -## Description -Ultra-lightweight personal AI assistant written in Go, inspired by nanobot. 
- -## Version -0.1.0 +func swarmHelp() { + fmt.Println("\nSwarm commands:") + fmt.Println(" start Start swarm node") + fmt.Println(" stop Stop running swarm node") + fmt.Println(" dispatch Submit a task to the swarm") + fmt.Println(" status Show swarm configuration") + fmt.Println(" nodes List discovered nodes (requires running node)") + fmt.Println() + fmt.Println("Start options:") + fmt.Println(" --role Node role: coordinator, worker, specialist") + fmt.Println(" --capabilities Comma-separated capabilities") + fmt.Println(" --embedded Use embedded NATS server (development mode)") + fmt.Println(" --debug Enable debug logging") + fmt.Println(" --hid Human/Owner ID (tenant identifier)") + fmt.Println(" --sid Shrimp/Service ID (instance identifier)") + fmt.Println(" --identity Both IDs in one parameter") + fmt.Println() + fmt.Println("Dispatch options:") + fmt.Println(" --type Task type: direct, workflow, broadcast") + fmt.Println(" --capability Required capability for routing") + fmt.Println(" --timeout Task timeout in milliseconds") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw swarm start --role coordinator --embedded") + fmt.Println(" picoclaw swarm start --role worker --capabilities code,search") + fmt.Println(" picoclaw swarm start --role worker --hid alice --sid worker1") + fmt.Println(" picoclaw swarm start --role worker --identity alice/worker1") + fmt.Println(" picoclaw swarm dispatch --type direct 'Analyze this code' --capability code") + fmt.Println(" picoclaw swarm status") +} -## Purpose -- Provide intelligent AI assistance with minimal resource usage -- Support multiple LLM providers (OpenAI, Anthropic, Zhipu, etc.) 
-- Enable easy customization through skills system -- Run on minimal hardware ($10 boards, <10MB RAM) +func swarmStartCmd() { + // Parse flags + role := "worker" + capabilities := []string{} + embedded := false + var hid, sid, natsServer, temporalServer string -## Capabilities + args := os.Args[3:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--role", "-r": + if i+1 < len(args) { + role = args[i+1] + i++ + } + case "--capabilities", "-c": + if i+1 < len(args) { + capabilities = strings.Split(args[i+1], ",") + i++ + } + case "--embedded": + embedded = true + case "--debug", "-d": + logger.SetLevel(logger.DEBUG) + fmt.Println("Debug mode enabled") + case "--hid", "--identity-hid": + if i+1 < len(args) { + hid = args[i+1] + i++ + } + case "--sid", "--identity-sid": + if i+1 < len(args) { + sid = args[i+1] + i++ + } + case "--identity": + if i+1 < len(args) { + // Parse "hid/sid" format + identityParts := strings.SplitN(args[i+1], "/", 2) + hid = identityParts[0] + if len(identityParts) > 1 { + sid = identityParts[1] + } + i++ + } + case "--nats-server", "--nats": + if i+1 < len(args) { + natsServer = args[i+1] + i++ + } + case "--temporal", "--temporal-server": + if i+1 < len(args) { + temporalServer = args[i+1] + i++ + } + } + } -- Web search and content fetching -- File system operations (read, write, edit) -- Shell command execution -- Multi-channel messaging (Telegram, WhatsApp, Feishu) -- Skill-based extensibility -- Memory and context management + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + os.Exit(1) + } -## Philosophy + // Override config with CLI flags + cfg.Swarm.Enabled = true + cfg.Swarm.Role = role + if len(capabilities) > 0 { + cfg.Swarm.Capabilities = capabilities + } + cfg.Swarm.NATS.Embedded = embedded -- Simplicity over complexity -- Performance over features -- User control and privacy -- Transparent operation -- Community-driven development + // Override NATS server if provided + if 
natsServer != "" { + cfg.Swarm.NATS.URLs = []string{"nats://" + natsServer} + } -## Goals + // Override Temporal server if provided + if temporalServer != "" { + cfg.Swarm.Temporal.Host = temporalServer + } -- Provide a fast, lightweight AI assistant -- Support offline-first operation where possible -- Enable easy customization and extension -- Maintain high quality responses -- Run efficiently on constrained hardware + // Set identity if provided + if hid != "" { + cfg.Swarm.HID = hid + } + if sid != "" { + cfg.Swarm.SID = sid + } -## License -MIT License - Free and open source + // Create provider + provider, _, err := providers.CreateProvider(cfg) + if err != nil { + fmt.Printf("Error creating provider: %v\n", err) + os.Exit(1) + } -## Repository -https://github.com/sipeed/picoclaw + // Create message bus and agent loop + msgBus := bus.NewMessageBus() + agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) -## Contact -Issues: https://github.com/sipeed/picoclaw/issues -Discussions: https://github.com/sipeed/picoclaw/discussions + // Register swarm info tool for worker/coordinator agents + swarmInfoTool := tools.NewSwarmInfoTool() + swarmInfoTool.AddWorker("coordinator", "coordinator", []string{"orchestration", "scheduling"}, "/Users/dev/service/coordinator") + swarmInfoTool.AddWorker("worker-a", "worker", []string{"code", "macos"}, "/Users/dev/service/worker-a") + swarmInfoTool.AddWorker("worker-b", "worker", []string{"search", "windows"}, "/Users/dev/service/worker-b") + agentLoop.RegisterTool(swarmInfoTool) + logger.InfoC("swarm", "Swarm info tool registered for worker") + + // Create and start swarm manager + manager := swarm.NewManager(cfg, agentLoop, provider, msgBus) + if manager == nil { + fmt.Println("Error: failed to create swarm manager (invalid configuration)") + os.Exit(1) + } ---- + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() -"Every bit helps, every bit matters." 
-- Picoclaw -`, + if err := manager.Start(ctx); err != nil { + fmt.Printf("Error starting swarm: %v\n", err) + os.Exit(1) } - for filename, content := range templates { - filePath := filepath.Join(workspace, filename) - if _, err := os.Stat(filePath); os.IsNotExist(err) { - os.WriteFile(filePath, []byte(content), 0644) - fmt.Printf(" Created %s\n", filename) - } + nodeInfo := manager.GetNodeInfo() + fmt.Printf("%s Swarm node started\n", logo) + fmt.Printf(" Node ID: %s\n", nodeInfo.ID) + fmt.Printf(" Role: %s\n", nodeInfo.Role) + fmt.Printf(" Capabilities: %v\n", nodeInfo.Capabilities) + if embedded { + fmt.Println(" Mode: Embedded NATS (development)") } + fmt.Printf(" NATS: %v\n", manager.IsNATSConnected()) + fmt.Printf(" Temporal: %v\n", manager.IsTemporalConnected()) + fmt.Println("\nPress Ctrl+C to stop") - memoryDir := filepath.Join(workspace, "memory") - os.MkdirAll(memoryDir, 0755) - memoryFile := filepath.Join(memoryDir, "MEMORY.md") - if _, err := os.Stat(memoryFile); os.IsNotExist(err) { - memoryContent := `# Long-term Memory - -This file stores important information that should persist across sessions. - -## User Information + // Start agent loop in background + // For coordinator, disable auto-consume since coordinator handles message routing + if role == "coordinator" { + agentLoop.AutoConsume = false + } + go agentLoop.Run(ctx) -(Important facts about user) + // For coordinator role, also start channel manager (Telegram, etc.) 
+ if role == "coordinator" { + channelManager, err := channels.NewManager(cfg, msgBus) + if err != nil { + fmt.Printf("Error creating channel manager: %v\n", err) + os.Exit(1) + } -## Preferences + // Start channels in background + if err := channelManager.StartAll(ctx); err != nil { + fmt.Printf("Error starting channel manager: %v\n", err) + os.Exit(1) + } + defer func() { + channelManager.StopAll(ctx) + }() -(User preferences learned over time) + // Get enabled channels + enabledChannels := channelManager.GetEnabledChannels() + fmt.Printf(" Channels: %v\n", enabledChannels) + } -## Important Notes + // Wait for interrupt + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt) + <-sigChan -(Things to remember) + fmt.Println("\nShutting down...") + cancel() + manager.Stop() + agentLoop.Stop() + fmt.Printf("%s Swarm node stopped\n", logo) +} -## Configuration +func swarmStopCmd() { + fmt.Printf("%s Stopping swarm node...\n", logo) -- Model preferences -- Channel settings -- Skills enabled -` - os.WriteFile(memoryFile, []byte(memoryContent), 0644) - fmt.Println(" Created memory/MEMORY.md") + // Find and stop swarm processes + pids, err := findSwarmProcesses() + if err != nil { + fmt.Printf("Error finding swarm processes: %v\n", err) + return + } - skillsDir := filepath.Join(workspace, "skills") - if _, err := os.Stat(skillsDir); os.IsNotExist(err) { - os.MkdirAll(skillsDir, 0755) - fmt.Println(" Created skills/") - } + if len(pids) == 0 { + fmt.Println("No running swarm nodes found") + return } - for filename, content := range templates { - filePath := filepath.Join(workspace, filename) - if _, err := os.Stat(filePath); os.IsNotExist(err) { - os.WriteFile(filePath, []byte(content), 0644) - fmt.Printf(" Created %s\n", filename) + fmt.Printf("Found %d swarm node(s)\n", len(pids)) + for _, pid := range pids { + fmt.Printf(" Stopping PID %d...\n", pid) + if err := stopProcess(pid); err != nil { + fmt.Printf(" Error: %v\n", err) + } else { + 
fmt.Printf(" Stopped\n") } } + fmt.Printf("%s Swarm node(s) stopped\n", logo) } -func migrateCmd() { - if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") { - migrateHelp() +func swarmDispatchCmd() { + if len(os.Args) < 4 { + fmt.Println("Usage: picoclaw swarm dispatch [options]") + fmt.Println() + fmt.Println("Options:") + fmt.Println(" --type Task type: direct, workflow, broadcast (default: workflow)") + fmt.Println(" --capability Required capability for routing") + fmt.Println(" --timeout Task timeout in milliseconds (default: 300000)") + fmt.Println(" --wait, -w Wait for result and display it") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw swarm dispatch 'Analyze all node files' --type workflow") + fmt.Println(" picoclaw swarm dispatch 'Read coordinator-info.txt and worker-a-info.txt in parallel' --wait") return } - opts := migrate.Options{} + // Parse arguments + taskType := "workflow" // 默认使用 workflow 来启用任务拆分 + capability := "general" + timeout := 600000 // 10 minutes (默认超时) + prompt := "" + waitForResult := false - args := os.Args[2:] + args := os.Args[3:] for i := 0; i < len(args); i++ { switch args[i] { - case "--dry-run": - opts.DryRun = true - case "--config-only": - opts.ConfigOnly = true - case "--workspace-only": - opts.WorkspaceOnly = true - case "--force": - opts.Force = true - case "--refresh": - opts.Refresh = true - case "--openclaw-home": + case "--type": + if i+1 < len(args) { + taskType = args[i+1] + i++ + } + case "--capability", "-c": if i+1 < len(args) { - opts.OpenClawHome = args[i+1] + capability = args[i+1] i++ } - case "--picoclaw-home": + case "--timeout", "-t": if i+1 < len(args) { - opts.PicoClawHome = args[i+1] + var ms int + fmt.Sscanf(args[i+1], "%d", &ms) + timeout = ms i++ } + case "--wait", "-w": + waitForResult = true default: - fmt.Printf("Unknown flag: %s\n", args[i]) - migrateHelp() - os.Exit(1) + if prompt == "" { + prompt = args[i] + } } } - result, err := migrate.Run(opts) + if 
prompt == "" { + fmt.Println("Error: prompt is required") + return + } + + // Load config + cfg, err := loadConfig() if err != nil { - fmt.Printf("Error: %v\n", err) - os.Exit(1) + fmt.Printf("Error loading config: %v\n", err) + return } - if !opts.DryRun { - migrate.PrintSummary(result) + // For workflow type, use Temporal + if taskType == "workflow" { + dispatchWorkflowTask(cfg, prompt, capability, timeout, waitForResult) + return } + + // For direct type, execute locally (fallback) + dispatchLocalTask(cfg, prompt, capability, timeout) } -func migrateHelp() { - fmt.Println("\nMigrate from OpenClaw to PicoClaw") +func dispatchWorkflowTask(cfg *config.Config, prompt, capability string, timeout int, waitForResult bool) { + fmt.Printf("%s Dispatching workflow task...\n", logo) + fmt.Printf(" Type: workflow (with decomposition)") + fmt.Printf(" Capability: %s\n", capability) + fmt.Printf(" Timeout: %d ms\n", timeout) + fmt.Printf(" Prompt: %s\n", truncateForDisplay(prompt, 60)) fmt.Println() - fmt.Println("Usage: picoclaw migrate [options]") - fmt.Println() - fmt.Println("Options:") - fmt.Println(" --dry-run Show what would be migrated without making changes") - fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)") - fmt.Println(" --config-only Only migrate config, skip workspace files") - fmt.Println(" --workspace-only Only migrate workspace files, skip config") - fmt.Println(" --force Skip confirmation prompts") - fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)") - fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)") - fmt.Println() - fmt.Println("Examples:") - fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw") - fmt.Println(" picoclaw migrate --dry-run Show what would be migrated") - fmt.Println(" picoclaw migrate --refresh Re-sync workspace files") - fmt.Println(" picoclaw migrate --force Migrate without confirmation") -} -func agentCmd() { - message := 
"" - sessionKey := "cli:default" + // Import temporal client packages + // We'll use go-temporal client to start workflow + workflowID := fmt.Sprintf("task-%d", time.Now().UnixNano()) - args := os.Args[2:] - for i := 0; i < len(args); i++ { - switch args[i] { - case "--debug", "-d": - logger.SetLevel(logger.DEBUG) - fmt.Println("🔍 Debug mode enabled") - case "-m", "--message": - if i+1 < len(args) { - message = args[i+1] - i++ - } - case "-s", "--session": - if i+1 < len(args) { - sessionKey = args[i+1] - i++ - } - } - } + // Create task JSON + taskJSON := fmt.Sprintf(`{"id":"%s","prompt":"%s","capability":"%s","type":"workflow"}`, + workflowID, escapeJSON(prompt), capability) - cfg, err := loadConfig() - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - os.Exit(1) - } + // Use temporal CLI to start workflow + cmd := exec.Command("temporal", "workflow", "start", + "--address", cfg.Swarm.Temporal.Host, + "--namespace", cfg.Swarm.Temporal.Namespace, + "--task-queue", cfg.Swarm.Temporal.TaskQueue, + "--type", "SwarmWorkflow", + "--input", taskJSON, + "--workflow-id", workflowID) - provider, err := providers.CreateProvider(cfg) + output, err := cmd.CombinedOutput() if err != nil { - fmt.Printf("Error creating provider: %v\n", err) - os.Exit(1) + fmt.Printf("Error starting workflow: %v\n", err) + fmt.Printf("Output: %s\n", string(output)) + return } - msgBus := bus.NewMessageBus() - agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) + fmt.Printf("\n✓ Workflow started\n") + fmt.Printf(" Workflow ID: %s\n", workflowID) + fmt.Printf(" Temporal UI: http://localhost:8088/namespaces/%s/workflows/%s\n", + cfg.Swarm.Temporal.Namespace, workflowID) - // Print agent startup info (only for interactive mode) - startupInfo := agentLoop.GetStartupInfo() - logger.InfoCF("agent", "Agent initialized", - map[string]interface{}{ - "tools_count": startupInfo["tools"].(map[string]interface{})["count"], - "skills_total": 
startupInfo["skills"].(map[string]interface{})["total"], - "skills_available": startupInfo["skills"].(map[string]interface{})["available"], - }) - - if message != "" { - ctx := context.Background() - response, err := agentLoop.ProcessDirect(ctx, message, sessionKey) - if err != nil { - fmt.Printf("Error: %v\n", err) - os.Exit(1) - } - fmt.Printf("\n%s %s\n", logo, response) + if waitForResult { + fmt.Printf("\n⏳ Waiting for result...\n") + waitForWorkflowCompletion(cfg, workflowID, timeout*2) // Double timeout for wait mode } else { - fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", logo) - interactiveMode(agentLoop, sessionKey) + fmt.Printf("\n💡 Use 'temporal workflow describe %s' to check status\n", workflowID) + fmt.Printf("💡 Use 'picoclaw swarm result %s' to get result\n", workflowID) } } -func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) { - prompt := fmt.Sprintf("%s You: ", logo) - - rl, err := readline.NewEx(&readline.Config{ - Prompt: prompt, - HistoryFile: filepath.Join(os.TempDir(), ".picoclaw_history"), - HistoryLimit: 100, - InterruptPrompt: "^C", - EOFPrompt: "exit", - }) +func waitForWorkflowCompletion(cfg *config.Config, workflowID string, timeout int) { + start := time.Now() + timeoutDuration := time.Duration(timeout) * time.Millisecond - if err != nil { - fmt.Printf("Error initializing readline: %v\n", err) - fmt.Println("Falling back to simple input mode...") - simpleInteractiveMode(agentLoop, sessionKey) - return - } - defer rl.Close() + for time.Since(start) < timeoutDuration { + cmd := exec.Command("temporal", "workflow", "describe", + "--address", cfg.Swarm.Temporal.Host, + "--namespace", cfg.Swarm.Temporal.Namespace, + "--workflow-id", workflowID, + "--output", "json") - for { - line, err := rl.Readline() + output, err := cmd.Output() if err != nil { - if err == readline.ErrInterrupt || err == io.EOF { - fmt.Println("\nGoodbye!") - return - } - fmt.Printf("Error reading input: %v\n", err) + time.Sleep(2 * time.Second) 
continue } - input := strings.TrimSpace(line) - if input == "" { - continue + // Parse JSON to check status + var result map[string]interface{} + if err := json.Unmarshal(output, &result); err == nil { + if status, ok := result["workflowExecutionInfo"].(map[string]interface{})["status"].(string); ok { + if status == "COMPLETED" { + // Try multiple ways to extract result + if rawResult, ok := result["result"].(map[string]interface{}); ok { + if value, ok := rawResult["value"].(string); ok { + fmt.Printf("\n%s Result:\n", logo) + fmt.Println(value) + return + } + if data, ok := rawResult["data"].(string); ok { + fmt.Printf("\n%s Result:\n", logo) + fmt.Println(data) + return + } + } + fmt.Printf("\n%s Result:\n", logo) + fmt.Printf(" (Completed - use Temporal UI for full output)\n") + return + } else if status == "FAILED" { + fmt.Printf("\n❌ Workflow failed\n") + return + } else if status == "CANCELED" { + fmt.Printf("\n❌ Workflow canceled\n") + return + } + // Still running + fmt.Printf(".") + time.Sleep(2 * time.Second) + } } + } + fmt.Printf("\n⏱ Timeout waiting for result\n") + fmt.Printf("💡 Check status: temporal workflow describe --namespace %s %s\n", + cfg.Swarm.Temporal.Namespace, workflowID) +} - if input == "exit" || input == "quit" { - fmt.Println("Goodbye!") - return - } +func dispatchLocalTask(cfg *config.Config, prompt, capability string, timeout int) { + // Create provider and agent loop + provider, _, err := providers.CreateProvider(cfg) + if err != nil { + fmt.Printf("Error creating provider: %v\n", err) + return + } - ctx := context.Background() - response, err := agentLoop.ProcessDirect(ctx, input, sessionKey) - if err != nil { - fmt.Printf("Error: %v\n", err) - continue - } + msgBus := bus.NewMessageBus() + agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) - fmt.Printf("\n%s %s\n\n", logo, response) - } -} + // Execute task locally + fmt.Printf("%s Executing task locally...\n", logo) + fmt.Printf(" Capability: %s\n", capability) + 
fmt.Printf(" Timeout: %d ms\n", timeout) + fmt.Printf(" Prompt: %s\n", prompt) + fmt.Println() -func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) { - reader := bufio.NewReader(os.Stdin) - for { - fmt.Print(fmt.Sprintf("%s You: ", logo)) - line, err := reader.ReadString('\n') - if err != nil { - if err == io.EOF { - fmt.Println("\nGoodbye!") - return - } - fmt.Printf("Error reading input: %v\n", err) - continue - } + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Millisecond) + defer cancel() - input := strings.TrimSpace(line) - if input == "" { - continue - } + response, err := agentLoop.ProcessDirect(ctx, prompt, "swarm:dispatch") - if input == "exit" || input == "quit" { - fmt.Println("Goodbye!") - return - } + if err != nil { + fmt.Printf("Error: %v\n", err) + return + } - ctx := context.Background() - response, err := agentLoop.ProcessDirect(ctx, input, sessionKey) - if err != nil { - fmt.Printf("Error: %v\n", err) - continue - } + fmt.Printf("\n%s Result:\n", logo) + fmt.Println(response) +} - fmt.Printf("\n%s %s\n\n", logo, response) +func escapeJSON(s string) string { + s = strings.ReplaceAll(s, "\\", "\\\\") + s = strings.ReplaceAll(s, "\"", "\\\"") + s = strings.ReplaceAll(s, "\n", "\\n") + return s +} + +func truncateForDisplay(s string, maxLen int) string { + if len(s) <= maxLen { + return s } + return s[:maxLen] + "..." 
} -func gatewayCmd() { - // Check for --debug flag - args := os.Args[2:] - for _, arg := range args { - if arg == "--debug" || arg == "-d" { - logger.SetLevel(logger.DEBUG) - fmt.Println("🔍 Debug mode enabled") - break - } +func swarmResultCmd() { + if len(os.Args) < 4 { + fmt.Println("Usage: picoclaw swarm result ") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw swarm result task-1234567890") + return } + workflowID := os.Args[3] + cfg, err := loadConfig() if err != nil { fmt.Printf("Error loading config: %v\n", err) - os.Exit(1) + return } - provider, err := providers.CreateProvider(cfg) + cmd := exec.Command("temporal", "workflow", "describe", + "--address", cfg.Swarm.Temporal.Host, + "--namespace", cfg.Swarm.Temporal.Namespace, + "--workflow-id", workflowID, + "--output", "json") + + output, err := cmd.Output() if err != nil { - fmt.Printf("Error creating provider: %v\n", err) - os.Exit(1) + fmt.Printf("Error fetching workflow: %v\n", err) + fmt.Printf("Make sure the workflow ID is correct.\n") + return } - msgBus := bus.NewMessageBus() - agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) - - // Print agent startup info - fmt.Println("\n📦 Agent Status:") - startupInfo := agentLoop.GetStartupInfo() - toolsInfo := startupInfo["tools"].(map[string]interface{}) - skillsInfo := startupInfo["skills"].(map[string]interface{}) - fmt.Printf(" • Tools: %d loaded\n", toolsInfo["count"]) - fmt.Printf(" • Skills: %d/%d available\n", - skillsInfo["available"], - skillsInfo["total"]) - - // Log to file as well - logger.InfoCF("agent", "Agent initialized", - map[string]interface{}{ - "tools_count": toolsInfo["count"], - "skills_total": skillsInfo["total"], - "skills_available": skillsInfo["available"], - }) - - // Setup cron tool and service - cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath()) - - heartbeatService := heartbeat.NewHeartbeatService( - cfg.WorkspacePath(), - cfg.Heartbeat.Interval, - cfg.Heartbeat.Enabled, - ) - 
heartbeatService.SetBus(msgBus) - heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { - // Use cli:direct as fallback if no valid channel - if channel == "" || chatID == "" { - channel, chatID = "cli", "direct" - } - // Use ProcessHeartbeat - no session history, each heartbeat is independent - response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID) - if err != nil { - return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err)) - } - if response == "HEARTBEAT_OK" { - return tools.SilentResult("Heartbeat OK") - } - // For heartbeat, always return silent - the subagent result will be - // sent to user via processSystemMessage when the async task completes - return tools.SilentResult(response) - }) + var result map[string]interface{} + if err := json.Unmarshal(output, &result); err != nil { + fmt.Printf("Error parsing result: %v\n", err) + return + } - channelManager, err := channels.NewManager(cfg, msgBus) - if err != nil { - fmt.Printf("Error creating channel manager: %v\n", err) - os.Exit(1) + info, ok := result["workflowExecutionInfo"].(map[string]interface{}) + if !ok { + fmt.Printf("Error: invalid response format\n") + return } - var transcriber *voice.GroqTranscriber - if cfg.Providers.Groq.APIKey != "" { - transcriber = voice.NewGroqTranscriber(cfg.Providers.Groq.APIKey) - logger.InfoC("voice", "Groq voice transcription enabled") + status, _ := info["status"].(string) + + fmt.Printf("%s Workflow Result\n\n", logo) + fmt.Printf("Workflow ID: %s\n", workflowID) + fmt.Printf("Status: %s\n", status) + + if startTime, ok := info["startTime"].(string); ok { + fmt.Printf("Started: %s\n", startTime) } - if transcriber != nil { - if telegramChannel, ok := channelManager.GetChannel("telegram"); ok { - if tc, ok := telegramChannel.(*channels.TelegramChannel); ok { - tc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Telegram channel") - } - } - if discordChannel, 
ok := channelManager.GetChannel("discord"); ok { - if dc, ok := discordChannel.(*channels.DiscordChannel); ok { - dc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Discord channel") + if status == "COMPLETED" { + if res, ok := result["result"].(map[string]interface{}); ok { + if rawValue, ok := res["raw"].(string); ok { + // Try to decode base64 if present + fmt.Printf("\n--- Result ---\n%s\n--- End ---\n", rawValue) + } else if value, ok := res["value"].(string); ok { + fmt.Printf("\n--- Result ---\n%s\n--- End ---\n", value) + } else if data, ok := res["data"].(string); ok { + fmt.Printf("\n--- Result ---\n%s\n--- End ---\n", data) + } else { + fmt.Printf("\n--- Result ---\n%+v\n--- End ---\n", res) } } - if slackChannel, ok := channelManager.GetChannel("slack"); ok { - if sc, ok := slackChannel.(*channels.SlackChannel); ok { - sc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Slack channel") + } else if status == "FAILED" { + if res, ok := result["result"].(map[string]interface{}); ok { + if value, ok := res["value"].(string); ok { + fmt.Printf("\n--- Error ---\n%s\n--- End ---\n", value) } } + } else if status == "RUNNING" { + fmt.Printf("\n⏳ Workflow is still running...\n") + fmt.Printf("Use --wait flag to wait for completion:\n") + fmt.Printf(" picoclaw swarm result %s --wait\n", workflowID) } - enabledChannels := channelManager.GetEnabledChannels() - if len(enabledChannels) > 0 { - fmt.Printf("✓ Channels enabled: %s\n", enabledChannels) - } else { - fmt.Println("⚠ Warning: No channels enabled") - } + fmt.Printf("\nMore info:\n") + fmt.Printf(" temporal workflow describe --namespace %s %s\n", cfg.Swarm.Temporal.Namespace, workflowID) + fmt.Printf(" http://localhost:8088/namespaces/%s/workflows/%s\n", cfg.Swarm.Temporal.Namespace, workflowID) +} - fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port) - fmt.Println("Press Ctrl+C to stop") +// Helper functions for 
process management - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - if err := cronService.Start(); err != nil { - fmt.Printf("Error starting cron service: %v\n", err) +func findSwarmProcesses() ([]int, error) { + cmd := exec.Command("pgrep", "-f", "picoclaw swarm start") + output, err := cmd.Output() + if err != nil { + return nil, err } - fmt.Println("✓ Cron service started") - if err := heartbeatService.Start(); err != nil { - fmt.Printf("Error starting heartbeat service: %v\n", err) + var pids []int + lines := strings.Split(string(output), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + var pid int + if _, err := fmt.Sscanf(line, "%d", &pid); err == nil { + pids = append(pids, pid) + } } - fmt.Println("✓ Heartbeat service started") + return pids, nil +} - if err := channelManager.StartAll(ctx); err != nil { - fmt.Printf("Error starting channels: %v\n", err) +func stopProcess(pid int) error { + proc, err := os.FindProcess(pid) + if err != nil { + return err } - - go agentLoop.Run(ctx) - - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt) - <-sigChan - - fmt.Println("\nShutting down...") - cancel() - heartbeatService.Stop() - cronService.Stop() - agentLoop.Stop() - channelManager.StopAll(ctx) - fmt.Println("✓ Gateway stopped") + return proc.Signal(os.Interrupt) } -func statusCmd() { +func swarmStatusCmd() { cfg, err := loadConfig() if err != nil { fmt.Printf("Error loading config: %v\n", err) return } - configPath := getConfigPath() - - fmt.Printf("%s picoclaw Status\n\n", logo) - - if _, err := os.Stat(configPath); err == nil { - fmt.Println("Config:", configPath, "✓") - } else { - fmt.Println("Config:", configPath, "✗") - } - - workspace := cfg.WorkspacePath() - if _, err := os.Stat(workspace); err == nil { - fmt.Println("Workspace:", workspace, "✓") - } else { - fmt.Println("Workspace:", workspace, "✗") - } - - if _, err := os.Stat(configPath); err == 
nil { - fmt.Printf("Model: %s\n", cfg.Agents.Defaults.Model) - - hasOpenRouter := cfg.Providers.OpenRouter.APIKey != "" - hasAnthropic := cfg.Providers.Anthropic.APIKey != "" - hasOpenAI := cfg.Providers.OpenAI.APIKey != "" - hasGemini := cfg.Providers.Gemini.APIKey != "" - hasZhipu := cfg.Providers.Zhipu.APIKey != "" - hasGroq := cfg.Providers.Groq.APIKey != "" - hasVLLM := cfg.Providers.VLLM.APIBase != "" - - status := func(enabled bool) string { - if enabled { - return "✓" - } - return "not set" - } - fmt.Println("OpenRouter API:", status(hasOpenRouter)) - fmt.Println("Anthropic API:", status(hasAnthropic)) - fmt.Println("OpenAI API:", status(hasOpenAI)) - fmt.Println("Gemini API:", status(hasGemini)) - fmt.Println("Zhipu API:", status(hasZhipu)) - fmt.Println("Groq API:", status(hasGroq)) - if hasVLLM { - fmt.Printf("vLLM/Local: ✓ %s\n", cfg.Providers.VLLM.APIBase) - } else { - fmt.Println("vLLM/Local: not set") - } - - store, _ := auth.LoadStore() - if store != nil && len(store.Credentials) > 0 { - fmt.Println("\nOAuth/Token Auth:") - for provider, cred := range store.Credentials { - status := "authenticated" - if cred.IsExpired() { - status = "expired" - } else if cred.NeedsRefresh() { - status = "needs refresh" - } - fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status) - } - } - } -} - -func authCmd() { - if len(os.Args) < 3 { - authHelp() - return - } - - switch os.Args[2] { - case "login": - authLoginCmd() - case "logout": - authLogoutCmd() - case "status": - authStatusCmd() - default: - fmt.Printf("Unknown auth command: %s\n", os.Args[2]) - authHelp() - } + fmt.Printf("%s Swarm Configuration\n\n", logo) + fmt.Printf("Enabled: %v\n", cfg.Swarm.Enabled) + fmt.Printf("Node ID: %s\n", cfg.Swarm.NodeID) + fmt.Printf("Role: %s\n", cfg.Swarm.Role) + fmt.Printf("Capabilities: %v\n", cfg.Swarm.Capabilities) + fmt.Printf("Max Concurrent: %d\n", cfg.Swarm.MaxConcurrent) + fmt.Println("\nNATS:") + fmt.Printf(" URLs: %v\n", cfg.Swarm.NATS.URLs) + 
fmt.Printf(" Embedded: %v\n", cfg.Swarm.NATS.Embedded) + fmt.Printf(" Heartbeat: %s\n", cfg.Swarm.NATS.HeartbeatInterval) + fmt.Printf(" Node Timeout: %s\n", cfg.Swarm.NATS.NodeTimeout) + fmt.Println("\nTemporal:") + fmt.Printf(" Host: %s\n", cfg.Swarm.Temporal.Host) + fmt.Printf(" Namespace: %s\n", cfg.Swarm.Temporal.Namespace) + fmt.Printf(" Task Queue: %s\n", cfg.Swarm.Temporal.TaskQueue) } -func authHelp() { - fmt.Println("\nAuth commands:") - fmt.Println(" login Login via OAuth or paste token") - fmt.Println(" logout Remove stored credentials") - fmt.Println(" status Show current auth status") - fmt.Println() - fmt.Println("Login options:") - fmt.Println(" --provider Provider to login with (openai, anthropic)") - fmt.Println(" --device-code Use device code flow (for headless environments)") - fmt.Println() - fmt.Println("Examples:") - fmt.Println(" picoclaw auth login --provider openai") - fmt.Println(" picoclaw auth login --provider openai --device-code") - fmt.Println(" picoclaw auth login --provider anthropic") - fmt.Println(" picoclaw auth logout --provider openai") - fmt.Println(" picoclaw auth status") -} - -func authLoginCmd() { - provider := "" - useDeviceCode := false - - args := os.Args[3:] - for i := 0; i < len(args); i++ { - switch args[i] { - case "--provider", "-p": - if i+1 < len(args) { - provider = args[i+1] - i++ - } - case "--device-code": - useDeviceCode = true - } - } - - if provider == "" { - fmt.Println("Error: --provider is required") - fmt.Println("Supported providers: openai, anthropic") +func swarmNodesCmd() { + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) return } - switch provider { - case "openai": - authLoginOpenAI(useDeviceCode) - case "anthropic": - authLoginPasteToken(provider) - default: - fmt.Printf("Unsupported provider: %s\n", provider) - fmt.Println("Supported providers: openai, anthropic") + // Get NATS URL from config or use default + natsURL := "nats://localhost:4222" + if 
len(cfg.Swarm.NATS.URLs) > 0 { + natsURL = cfg.Swarm.NATS.URLs[0] } -} -func authLoginOpenAI(useDeviceCode bool) { - cfg := auth.OpenAIOAuthConfig() - - var cred *auth.AuthCredential - var err error - - if useDeviceCode { - cred, err = auth.LoginDeviceCode(cfg) - } else { - cred, err = auth.LoginBrowser(cfg) - } + // Get HID for filtering + hid := cfg.Swarm.HID + // Connect to NATS + nc, err := nats.Connect(natsURL, + nats.Timeout(5*time.Second), + nats.ReconnectWait(100*time.Millisecond), + nats.MaxReconnects(2), + ) if err != nil { - fmt.Printf("Login failed: %v\n", err) - os.Exit(1) + fmt.Printf("%s Swarm Nodes\n\n", logo) + fmt.Printf("Failed to connect to NATS at %s\n", natsURL) + fmt.Printf("Error: %v\n\n", err) + fmt.Println("Make sure swarm nodes are running:") + fmt.Println(" pm2 status") + return } + defer nc.Close() - if err := auth.SetCredential("openai", cred); err != nil { - fmt.Printf("Failed to save credentials: %v\n", err) - os.Exit(1) - } + // Wait a bit for connection to be fully established + time.Sleep(100 * time.Millisecond) - appCfg, err := loadConfig() - if err == nil { - appCfg.Providers.OpenAI.AuthMethod = "oauth" - if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { - fmt.Printf("Warning: could not update config: %v\n", err) - } + // Node info structures matching the swarm package + type Heartbeat struct { + NodeID string `json:"node_id"` + Timestamp int64 `json:"timestamp"` + Load float64 `json:"load"` + TasksRunning int `json:"tasks_running"` + Status string `json:"status"` + Capabilities []string `json:"capabilities"` } - fmt.Println("Login successful!") - if cred.AccountID != "" { - fmt.Printf("Account: %s\n", cred.AccountID) + type NodeInfo struct { + ID string `json:"id"` + NodeID string `json:"node_id"` + Role string `json:"role"` + Capabilities []string `json:"capabilities"` + Status string `json:"status"` + Load float64 `json:"load"` + TasksRunning int `json:"tasks_running"` + MaxTasks int `json:"max_tasks"` + 
Model string `json:"model"` + HID string `json:"hid"` + SID string `json:"sid"` } -} -func authLoginPasteToken(provider string) { - cred, err := auth.LoginPasteToken(provider, os.Stdin) - if err != nil { - fmt.Printf("Login failed: %v\n", err) - os.Exit(1) + type DiscoveryAnnounce struct { + Node NodeInfo `json:"node"` + Timestamp int64 `json:"timestamp"` } - if err := auth.SetCredential(provider, cred); err != nil { - fmt.Printf("Failed to save credentials: %v\n", err) - os.Exit(1) - } + nodes := make(map[string]NodeInfo) + var mu sync.Mutex - appCfg, err := loadConfig() - if err == nil { - switch provider { - case "anthropic": - appCfg.Providers.Anthropic.AuthMethod = "token" - case "openai": - appCfg.Providers.OpenAI.AuthMethod = "token" - } - if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { - fmt.Printf("Warning: could not update config: %v\n", err) + // Subscribe to heartbeat messages (picoclaw.swarm.heartbeat.>) + sub1, err := nc.Subscribe("picoclaw.swarm.heartbeat.>", func(msg *nats.Msg) { + var hb Heartbeat + if err := json.Unmarshal(msg.Data, &hb); err != nil { + return } - } - - fmt.Printf("Token saved for %s!\n", provider) -} - -func authLogoutCmd() { - provider := "" - - args := os.Args[3:] - for i := 0; i < len(args); i++ { - switch args[i] { - case "--provider", "-p": - if i+1 < len(args) { - provider = args[i+1] - i++ + mu.Lock() + // Update existing node or add new one + if node, ok := nodes[hb.NodeID]; ok { + node.Load = hb.Load + node.TasksRunning = hb.TasksRunning + node.Status = hb.Status + if len(hb.Capabilities) > 0 { + node.Capabilities = hb.Capabilities } + nodes[hb.NodeID] = node } + mu.Unlock() + }) + if err == nil { + defer sub1.Unsubscribe() } - if provider != "" { - if err := auth.DeleteCredential(provider); err != nil { - fmt.Printf("Failed to remove credentials: %v\n", err) - os.Exit(1) + // Subscribe to discovery announce messages (picoclaw.swarm.discovery.announce) + sub2, err := 
nc.Subscribe("picoclaw.swarm.discovery.announce", func(msg *nats.Msg) { + var announce DiscoveryAnnounce + if err := json.Unmarshal(msg.Data, &announce); err != nil { + return } + node := announce.Node - appCfg, err := loadConfig() - if err == nil { - switch provider { - case "openai": - appCfg.Providers.OpenAI.AuthMethod = "" - case "anthropic": - appCfg.Providers.Anthropic.AuthMethod = "" - } - config.SaveConfig(getConfigPath(), appCfg) + // Use ID or NodeID field + nodeID := node.ID + if nodeID == "" { + nodeID = node.NodeID } - fmt.Printf("Logged out from %s\n", provider) - } else { - if err := auth.DeleteAllCredentials(); err != nil { - fmt.Printf("Failed to remove credentials: %v\n", err) - os.Exit(1) + // Filter by HID if specified + if hid != "" && node.HID != hid { + return } - appCfg, err := loadConfig() - if err == nil { - appCfg.Providers.OpenAI.AuthMethod = "" - appCfg.Providers.Anthropic.AuthMethod = "" - config.SaveConfig(getConfigPath(), appCfg) + // Ensure NodeID is set for lookup + if node.NodeID == "" { + node.NodeID = nodeID } - fmt.Println("Logged out from all providers") - } -} - -func authStatusCmd() { - store, err := auth.LoadStore() - if err != nil { - fmt.Printf("Error loading auth store: %v\n", err) - return - } - - if len(store.Credentials) == 0 { - fmt.Println("No authenticated providers.") - fmt.Println("Run: picoclaw auth login --provider ") - return - } - - fmt.Println("\nAuthenticated Providers:") - fmt.Println("------------------------") - for provider, cred := range store.Credentials { - status := "active" - if cred.IsExpired() { - status = "expired" - } else if cred.NeedsRefresh() { - status = "needs refresh" - } - - fmt.Printf(" %s:\n", provider) - fmt.Printf(" Method: %s\n", cred.AuthMethod) - fmt.Printf(" Status: %s\n", status) - if cred.AccountID != "" { - fmt.Printf(" Account: %s\n", cred.AccountID) - } - if !cred.ExpiresAt.IsZero() { - fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04")) + mu.Lock() + // 
Merge with existing node info if any + if existing, ok := nodes[nodeID]; ok { + // Keep heartbeat-updated fields + if existing.Load > 0 { + node.Load = existing.Load + } + if existing.TasksRunning > 0 { + node.TasksRunning = existing.TasksRunning + } } - } -} - -func getConfigPath() string { - home, _ := os.UserHomeDir() - return filepath.Join(home, ".picoclaw", "config.json") -} - -func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string) *cron.CronService { - cronStorePath := filepath.Join(workspace, "cron", "jobs.json") - - // Create cron service - cronService := cron.NewCronService(cronStorePath, nil) - - // Create and register CronTool - cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace) - agentLoop.RegisterTool(cronTool) - - // Set the onJob handler - cronService.SetOnJob(func(job *cron.CronJob) (string, error) { - result := cronTool.ExecuteJob(context.Background(), job) - return result, nil + nodes[nodeID] = node + mu.Unlock() }) - - return cronService -} - -func loadConfig() (*config.Config, error) { - return config.LoadConfig(getConfigPath()) -} - -func cronCmd() { - if len(os.Args) < 3 { - cronHelp() - return - } - - subcommand := os.Args[2] - - // Load config to get workspace path - cfg, err := loadConfig() - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - return - } - - cronStorePath := filepath.Join(cfg.WorkspacePath(), "cron", "jobs.json") - - switch subcommand { - case "list": - cronListCmd(cronStorePath) - case "add": - cronAddCmd(cronStorePath) - case "remove": - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw cron remove ") + if err == nil { + defer sub2.Unsubscribe() + } + + // Also subscribe to discovery response (reply to query) + sub3, err := nc.Subscribe("picoclaw.swarm.discovery.>", func(msg *nats.Msg) { + var announce DiscoveryAnnounce + if err := json.Unmarshal(msg.Data, &announce); err != nil { + // Try direct NodeInfo + var node NodeInfo + if err2 := 
json.Unmarshal(msg.Data, &node); err2 == nil { + mu.Lock() + nodeID := node.ID + if nodeID == "" { + nodeID = node.NodeID + } + if node.NodeID == "" { + node.NodeID = nodeID + } + if hid == "" || node.HID == hid { + nodes[nodeID] = node + } + mu.Unlock() + } return } - cronRemoveCmd(cronStorePath, os.Args[3]) - case "enable": - cronEnableCmd(cronStorePath, false) - case "disable": - cronEnableCmd(cronStorePath, true) - default: - fmt.Printf("Unknown cron command: %s\n", subcommand) - cronHelp() - } -} - -func cronHelp() { - fmt.Println("\nCron commands:") - fmt.Println(" list List all scheduled jobs") - fmt.Println(" add Add a new scheduled job") - fmt.Println(" remove Remove a job by ID") - fmt.Println(" enable Enable a job") - fmt.Println(" disable Disable a job") - fmt.Println() - fmt.Println("Add options:") - fmt.Println(" -n, --name Job name") - fmt.Println(" -m, --message Message for agent") - fmt.Println(" -e, --every Run every N seconds") - fmt.Println(" -c, --cron Cron expression (e.g. 
'0 9 * * *')") - fmt.Println(" -d, --deliver Deliver response to channel") - fmt.Println(" --to Recipient for delivery") - fmt.Println(" --channel Channel for delivery") -} - -func cronListCmd(storePath string) { - cs := cron.NewCronService(storePath, nil) - jobs := cs.ListJobs(true) // Show all jobs, including disabled - - if len(jobs) == 0 { - fmt.Println("No scheduled jobs.") - return - } - - fmt.Println("\nScheduled Jobs:") - fmt.Println("----------------") - for _, job := range jobs { - var schedule string - if job.Schedule.Kind == "every" && job.Schedule.EveryMS != nil { - schedule = fmt.Sprintf("every %ds", *job.Schedule.EveryMS/1000) - } else if job.Schedule.Kind == "cron" { - schedule = job.Schedule.Expr - } else { - schedule = "one-time" + node := announce.Node + nodeID := node.ID + if nodeID == "" { + nodeID = node.NodeID } - - nextRun := "scheduled" - if job.State.NextRunAtMS != nil { - nextTime := time.UnixMilli(*job.State.NextRunAtMS) - nextRun = nextTime.Format("2006-01-02 15:04") + if node.NodeID == "" { + node.NodeID = nodeID } - - status := "enabled" - if !job.Enabled { - status = "disabled" + mu.Lock() + if hid == "" || node.HID == hid { + nodes[nodeID] = node } - - fmt.Printf(" %s (%s)\n", job.Name, job.ID) - fmt.Printf(" Schedule: %s\n", schedule) - fmt.Printf(" Status: %s\n", status) - fmt.Printf(" Next run: %s\n", nextRun) + mu.Unlock() + }) + if err == nil { + defer sub3.Unsubscribe() } -} -func cronAddCmd(storePath string) { - name := "" - message := "" - var everySec *int64 - cronExpr := "" - deliver := false - channel := "" - to := "" + // Publish a discovery query to prompt nodes to respond + queryMsg := map[string]interface{}{ + "requester_id": "picoclaw-cli-query", + "timestamp": time.Now().UnixMilli(), + } + queryData, _ := json.Marshal(queryMsg) - args := os.Args[3:] - for i := 0; i < len(args); i++ { - switch args[i] { - case "-n", "--name": - if i+1 < len(args) { - name = args[i+1] - i++ - } - case "-m", "--message": - if i+1 < 
len(args) { - message = args[i+1] - i++ - } - case "-e", "--every": - if i+1 < len(args) { - var sec int64 - fmt.Sscanf(args[i+1], "%d", &sec) - everySec = &sec - i++ + // Use PublishRequest to allow nodes to respond via msg.Respond() + inbox := nats.NewInbox() + responseSub, _ := nc.Subscribe(inbox, func(msg *nats.Msg) { + var node NodeInfo + if err := json.Unmarshal(msg.Data, &node); err == nil { + nodeID := node.ID + if nodeID == "" { + nodeID = node.NodeID } - case "-c", "--cron": - if i+1 < len(args) { - cronExpr = args[i+1] - i++ - } - case "-d", "--deliver": - deliver = true - case "--to": - if i+1 < len(args) { - to = args[i+1] - i++ + if nodeID != "" { + mu.Lock() + if hid == "" || node.HID == hid { + if existing, ok := nodes[nodeID]; ok { + // Update with more complete info + if node.Role != "" && existing.Role == "" { + existing.Role = node.Role + } + if node.Status != "" && existing.Status == "" { + existing.Status = node.Status + } + if len(node.Capabilities) > 0 && len(existing.Capabilities) == 0 { + existing.Capabilities = node.Capabilities + } + if node.MaxTasks > 0 && existing.MaxTasks == 0 { + existing.MaxTasks = node.MaxTasks + } + if node.Model != "" && existing.Model == "" { + existing.Model = node.Model + } + if node.HID != "" { + existing.HID = node.HID + } + if node.SID != "" { + existing.SID = node.SID + } + nodes[nodeID] = existing + } else { + // Ensure NodeID is set + if node.NodeID == "" { + node.NodeID = nodeID + } + nodes[nodeID] = node + } + } + mu.Unlock() } - case "--channel": - if i+1 < len(args) { - channel = args[i+1] - i++ + } + }) + defer responseSub.Unsubscribe() + + nc.PublishRequest("picoclaw.swarm.discovery.query", inbox, queryData) + + // Also try wildcard subscription to catch heartbeat messages + debugSub, _ := nc.Subscribe("picoclaw.swarm.heartbeat.>", func(msg *nats.Msg) { + var raw map[string]interface{} + if err := json.Unmarshal(msg.Data, &raw); err == nil { + if nodeID, ok := raw["node_id"].(string); ok { + 
mu.Lock() + if existing, ok := nodes[nodeID]; ok { + // Update heartbeat fields + if load, ok := raw["load"].(float64); ok { + existing.Load = load + } + if tasksRunning, ok := raw["tasks_running"].(float64); ok { + existing.TasksRunning = int(tasksRunning) + } + if status, ok := raw["status"].(string); ok { + existing.Status = status + } + if caps, ok := raw["capabilities"].([]interface{}); ok && len(existing.Capabilities) == 0 { + for _, c := range caps { + if cs, ok := c.(string); ok { + existing.Capabilities = append(existing.Capabilities, cs) + } + } + } + nodes[nodeID] = existing + } else { + // Create new node from heartbeat (for nodes that only send heartbeats) + newNode := NodeInfo{ + ID: nodeID, + NodeID: nodeID, + Status: "online", + Role: "worker", // Default to worker if not specified + } + if load, ok := raw["load"].(float64); ok { + newNode.Load = load + } + if tasksRunning, ok := raw["tasks_running"].(float64); ok { + newNode.TasksRunning = int(tasksRunning) + } + if status, ok := raw["status"].(string); ok { + newNode.Status = status + } + if caps, ok := raw["capabilities"].([]interface{}); ok { + for _, c := range caps { + if cs, ok := c.(string); ok { + newNode.Capabilities = append(newNode.Capabilities, cs) + } + } + } + nodes[nodeID] = newNode + } + mu.Unlock() } } - } - - if name == "" { - fmt.Println("Error: --name is required") - return - } - - if message == "" { - fmt.Println("Error: --message is required") - return - } - - if everySec == nil && cronExpr == "" { - fmt.Println("Error: Either --every or --cron must be specified") - return - } - - var schedule cron.CronSchedule - if everySec != nil { - everyMS := *everySec * 1000 - schedule = cron.CronSchedule{ - Kind: "every", - EveryMS: &everyMS, + }) + defer debugSub.Unsubscribe() + + // Wait longer for responses and heartbeats (heartbeat interval is 10s) + // Wait at least 20 seconds to capture at least 2 heartbeat cycles from all nodes + time.Sleep(20 * time.Second) + + mu.Lock() + 
nodeList := make([]NodeInfo, 0, len(nodes)) + for _, node := range nodes { + // Only include nodes with valid IDs + nodeID := node.ID + if nodeID == "" { + nodeID = node.NodeID } - } else { - schedule = cron.CronSchedule{ - Kind: "cron", - Expr: cronExpr, + if nodeID != "" { + nodeList = append(nodeList, node) } } + mu.Unlock() - cs := cron.NewCronService(storePath, nil) - job, err := cs.AddJob(name, schedule, message, deliver, channel, to) - if err != nil { - fmt.Printf("Error adding job: %v\n", err) - return - } + // Display results + fmt.Printf("%s Swarm Nodes\n\n", logo) - fmt.Printf("✓ Added job '%s' (%s)\n", job.Name, job.ID) -} - -func cronRemoveCmd(storePath, jobID string) { - cs := cron.NewCronService(storePath, nil) - if cs.RemoveJob(jobID) { - fmt.Printf("✓ Removed job %s\n", jobID) - } else { - fmt.Printf("✗ Job %s not found\n", jobID) - } -} - -func cronEnableCmd(storePath string, disable bool) { - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw cron enable/disable ") + if len(nodeList) == 0 { + fmt.Println("No nodes discovered.") + fmt.Println("\nMake sure swarm nodes are running:") + fmt.Println(" pm2 status") + fmt.Println("\nOr start a swarm node:") + fmt.Println(" picoclaw swarm start --role coordinator --embedded") return } - jobID := os.Args[3] - cs := cron.NewCronService(storePath, nil) - enabled := !disable + // Count by role + coordinators := 0 + workers := 0 + specialists := 0 - job := cs.EnableJob(jobID, enabled) - if job != nil { - status := "enabled" - if disable { - status = "disabled" + for _, node := range nodeList { + switch node.Role { + case "coordinator": + coordinators++ + case "worker": + workers++ + case "specialist": + specialists++ } - fmt.Printf("✓ Job '%s' %s\n", job.Name, status) - } else { - fmt.Printf("✗ Job %s not found\n", jobID) } -} -func skillsCmd() { - if len(os.Args) < 3 { - skillsHelp() - return - } + fmt.Printf("Total: %d node(s) found\n\n", len(nodeList)) + fmt.Printf(" • Coordinators: %d\n", coordinators) 
+ fmt.Printf(" • Workers: %d\n", workers) + fmt.Printf(" • Specialists: %d\n", specialists) + fmt.Println("\nNodes:") - subcommand := os.Args[2] - - cfg, err := loadConfig() - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - os.Exit(1) - } - - workspace := cfg.WorkspacePath() - installer := skills.NewSkillInstaller(workspace) - // 获取全局配置目录和内置 skills 目录 - globalDir := filepath.Dir(getConfigPath()) - globalSkillsDir := filepath.Join(globalDir, "skills") - builtinSkillsDir := filepath.Join(globalDir, "picoclaw", "skills") - skillsLoader := skills.NewSkillsLoader(workspace, globalSkillsDir, builtinSkillsDir) - - switch subcommand { - case "list": - skillsListCmd(skillsLoader) - case "install": - skillsInstallCmd(installer) - case "remove", "uninstall": - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw skills remove ") - return + for _, node := range nodeList { + statusIcon := "●" + if node.Status != "online" { + statusIcon = "○" } - skillsRemoveCmd(installer, os.Args[3]) - case "search": - skillsSearchCmd(installer) - case "show": - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw skills show ") - return + roleIcon := "C" + if node.Role == "worker" { + roleIcon = "W" + } else if node.Role == "specialist" { + roleIcon = "S" } - skillsShowCmd(skillsLoader, os.Args[3]) - default: - fmt.Printf("Unknown skills command: %s\n", subcommand) - skillsHelp() - } -} -func skillsHelp() { - fmt.Println("\nSkills commands:") - fmt.Println(" list List installed skills") - fmt.Println(" install Install skill from GitHub") - fmt.Println(" install-builtin Install all builtin skills to workspace") - fmt.Println(" list-builtin List available builtin skills") - fmt.Println(" remove Remove installed skill") - fmt.Println(" search Search available skills") - fmt.Println(" show Show skill details") - fmt.Println() - fmt.Println("Examples:") - fmt.Println(" picoclaw skills list") - fmt.Println(" picoclaw skills install sipeed/picoclaw-skills/weather") - fmt.Println(" 
picoclaw skills install-builtin") - fmt.Println(" picoclaw skills list-builtin") - fmt.Println(" picoclaw skills remove weather") -} - -func skillsListCmd(loader *skills.SkillsLoader) { - allSkills := loader.ListSkills() - - if len(allSkills) == 0 { - fmt.Println("No skills installed.") - return - } + loadPercent := int(node.Load * 100) - fmt.Println("\nInstalled Skills:") - fmt.Println("------------------") - for _, skill := range allSkills { - fmt.Printf(" ✓ %s (%s)\n", skill.Name, skill.Source) - if skill.Description != "" { - fmt.Printf(" %s\n", skill.Description) + // Use ID or NodeID for display + displayID := node.ID + if displayID == "" { + displayID = node.NodeID } - } -} - -func skillsInstallCmd(installer *skills.SkillInstaller) { - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw skills install ") - fmt.Println("Example: picoclaw skills install sipeed/picoclaw-skills/weather") - return - } - - repo := os.Args[3] - fmt.Printf("Installing skill from %s...\n", repo) - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - if err := installer.InstallFromGitHub(ctx, repo); err != nil { - fmt.Printf("✗ Failed to install skill: %v\n", err) - os.Exit(1) - } - - fmt.Printf("✓ Skill '%s' installed successfully!\n", filepath.Base(repo)) -} - -func skillsRemoveCmd(installer *skills.SkillInstaller, skillName string) { - fmt.Printf("Removing skill '%s'...\n", skillName) - - if err := installer.Uninstall(skillName); err != nil { - fmt.Printf("✗ Failed to remove skill: %v\n", err) - os.Exit(1) - } - - fmt.Printf("✓ Skill '%s' removed successfully!\n", skillName) -} - -func skillsInstallBuiltinCmd(workspace string) { - builtinSkillsDir := "./picoclaw/skills" - workspaceSkillsDir := filepath.Join(workspace, "skills") - - fmt.Printf("Copying builtin skills to workspace...\n") - - skillsToInstall := []string{ - "weather", - "news", - "stock", - "calculator", - } - - for _, skillName := range skillsToInstall { - builtinPath := 
filepath.Join(builtinSkillsDir, skillName) - workspacePath := filepath.Join(workspaceSkillsDir, skillName) - - if _, err := os.Stat(builtinPath); err != nil { - fmt.Printf("⊘ Builtin skill '%s' not found: %v\n", skillName, err) - continue + if len(displayID) > 20 { + displayID = displayID[:17] + "..." } - if err := os.MkdirAll(workspacePath, 0755); err != nil { - fmt.Printf("✗ Failed to create directory for %s: %v\n", skillName, err) - continue - } + fmt.Printf(" %s %s %-20s [%2s] %s (load: %d%%, tasks: %d/%d)\n", + statusIcon, roleIcon, displayID, node.Role, + node.Status, loadPercent, node.TasksRunning, node.MaxTasks) - if err := copyDirectory(builtinPath, workspacePath); err != nil { - fmt.Printf("✗ Failed to copy %s: %v\n", skillName, err) + if len(node.Capabilities) > 0 { + fmt.Printf(" Capabilities: %s\n", strings.Join(node.Capabilities, ", ")) } - } - - fmt.Println("\n✓ All builtin skills installed!") - fmt.Println("Now you can use them in your workspace.") -} - -func skillsListBuiltinCmd() { - cfg, err := loadConfig() - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - return - } - builtinSkillsDir := filepath.Join(filepath.Dir(cfg.WorkspacePath()), "picoclaw", "skills") - - fmt.Println("\nAvailable Builtin Skills:") - fmt.Println("-----------------------") - - entries, err := os.ReadDir(builtinSkillsDir) - if err != nil { - fmt.Printf("Error reading builtin skills: %v\n", err) - return - } - - if len(entries) == 0 { - fmt.Println("No builtin skills available.") - return - } - - for _, entry := range entries { - if entry.IsDir() { - skillName := entry.Name() - skillFile := filepath.Join(builtinSkillsDir, skillName, "SKILL.md") - - description := "No description" - if _, err := os.Stat(skillFile); err == nil { - data, err := os.ReadFile(skillFile) - if err == nil { - content := string(data) - if idx := strings.Index(content, "\n"); idx > 0 { - firstLine := content[:idx] - if strings.Contains(firstLine, "description:") { - descLine := 
strings.Index(content[idx:], "\n") - if descLine > 0 { - description = strings.TrimSpace(content[idx+descLine : idx+descLine]) - } - } - } - } - } - status := "✓" - fmt.Printf(" %s %s\n", status, entry.Name()) - if description != "" { - fmt.Printf(" %s\n", description) - } + if node.SID != "" { + fmt.Printf(" SID: %s\n", node.SID) } } -} - -func skillsSearchCmd(installer *skills.SkillInstaller) { - fmt.Println("Searching for available skills...") - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - availableSkills, err := installer.ListAvailableSkills(ctx) - if err != nil { - fmt.Printf("✗ Failed to fetch skills list: %v\n", err) - return + fmt.Printf("\nNATS: %s\n", natsURL) + if hid != "" { + fmt.Printf("HID: %s (filtered)\n", hid) } - - if len(availableSkills) == 0 { - fmt.Println("No skills available.") - return - } - - fmt.Printf("\nAvailable Skills (%d):\n", len(availableSkills)) - fmt.Println("--------------------") - for _, skill := range availableSkills { - fmt.Printf(" 📦 %s\n", skill.Name) - fmt.Printf(" %s\n", skill.Description) - fmt.Printf(" Repo: %s\n", skill.Repository) - if skill.Author != "" { - fmt.Printf(" Author: %s\n", skill.Author) - } - if len(skill.Tags) > 0 { - fmt.Printf(" Tags: %v\n", skill.Tags) - } - fmt.Println() - } -} - -func skillsShowCmd(loader *skills.SkillsLoader, skillName string) { - content, ok := loader.LoadSkill(skillName) - if !ok { - fmt.Printf("✗ Skill '%s' not found\n", skillName) - return - } - - fmt.Printf("\n📦 Skill: %s\n", skillName) - fmt.Println("----------------------") - fmt.Println(content) } diff --git a/config/config.example.json b/config/config.example.json index c71587a04..77a8c0683 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -3,22 +3,67 @@ "defaults": { "workspace": "~/.picoclaw/workspace", "restrict_to_workspace": true, - "model": "glm-4.7", + "model": "gpt4", "max_tokens": 8192, "temperature": 0.7, "max_tool_iterations": 20 } }, + 
"model_list": [ + { + "model_name": "gpt4", + "model": "openai/gpt-5.2", + "api_key": "sk-your-openai-key", + "api_base": "https://api.openai.com/v1" + }, + { + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "api_key": "sk-ant-your-key", + "api_base": "https://api.anthropic.com/v1" + }, + { + "model_name": "gemini", + "model": "antigravity/gemini-2.0-flash", + "auth_method": "oauth" + }, + { + "model_name": "deepseek", + "model": "deepseek/deepseek-chat", + "api_key": "sk-your-deepseek-key" + }, + { + "model_name": "loadbalanced-gpt4", + "model": "openai/gpt-5.2", + "api_key": "sk-key1", + "api_base": "https://api1.example.com/v1" + }, + { + "model_name": "loadbalanced-gpt4", + "model": "openai/gpt-5.2", + "api_key": "sk-key2", + "api_base": "https://api2.example.com/v1" + } + ], "channels": { "telegram": { "enabled": false, "token": "YOUR_TELEGRAM_BOT_TOKEN", "proxy": "", - "allow_from": ["YOUR_USER_ID"] + "allow_from": [ + "YOUR_USER_ID" + ] }, "discord": { "enabled": false, "token": "YOUR_DISCORD_BOT_TOKEN", + "allow_from": [], + "mention_only": false + }, + "qq": { + "enabled": false, + "app_id": "YOUR_QQ_APP_ID", + "app_secret": "YOUR_QQ_APP_SECRET", "allow_from": [] }, "maixcam": { @@ -51,16 +96,61 @@ "bot_token": "xoxb-YOUR-BOT-TOKEN", "app_token": "xapp-YOUR-APP-TOKEN", "allow_from": [] + }, + "line": { + "enabled": false, + "channel_secret": "YOUR_LINE_CHANNEL_SECRET", + "channel_access_token": "YOUR_LINE_CHANNEL_ACCESS_TOKEN", + "webhook_host": "0.0.0.0", + "webhook_port": 18791, + "webhook_path": "/webhook/line", + "allow_from": [] + }, + "onebot": { + "enabled": false, + "ws_url": "ws://127.0.0.1:3001", + "access_token": "", + "reconnect_interval": 5, + "group_trigger_prefix": [], + "allow_from": [] + }, + "wecom": { + "_comment": "WeCom Bot (智能机器人) - Easier setup, supports group chats", + "enabled": false, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", + "webhook_url": 
"https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY", + "webhook_host": "0.0.0.0", + "webhook_port": 18793, + "webhook_path": "/webhook/wecom", + "allow_from": [], + "reply_timeout": 5 + }, + "wecom_app": { + "_comment": "WeCom App (自建应用) - More features, proactive messaging, private chat only. See docs/wecom-app-configuration.md", + "enabled": false, + "corp_id": "YOUR_CORP_ID", + "corp_secret": "YOUR_CORP_SECRET", + "agent_id": 1000002, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", + "webhook_host": "0.0.0.0", + "webhook_port": 18792, + "webhook_path": "/webhook/wecom-app", + "allow_from": [], + "reply_timeout": 5 } }, "providers": { + "_comment": "DEPRECATED: Use model_list instead. This will be removed in a future version", "anthropic": { "api_key": "", "api_base": "" }, "openai": { "api_key": "", - "api_base": "" + "api_base": "", + "web_search": true }, "openrouter": { "api_key": "sk-or-v1-xxx", @@ -90,13 +180,57 @@ "moonshot": { "api_key": "sk-xxx", "api_base": "" + }, + "qwen": { + "api_key": "sk-xxx", + "api_base": "" + }, + "ollama": { + "api_key": "", + "api_base": "http://localhost:11434/v1" + }, + "cerebras": { + "api_key": "", + "api_base": "" + }, + "volcengine": { + "api_key": "", + "api_base": "" } }, "tools": { "web": { - "search": { + "brave": { + "enabled": false, "api_key": "YOUR_BRAVE_API_KEY", "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + }, + "perplexity": { + "enabled": false, + "api_key": "pplx-xxx", + "max_results": 5 + } + }, + "cron": { + "exec_timeout_minutes": 5 + }, + "exec": { + "enable_deny_patterns": false, + "custom_deny_patterns": [] + }, + "skills": { + "registries": { + "clawhub": { + "enabled": true, + "base_url": "https://clawhub.ai", + "search_path": "/api/v1/search", + "skills_path": "/api/v1/skills", + "download_path": "/api/v1/download" + } } } }, @@ -104,6 +238,10 @@ "enabled": true, "interval": 30 }, + "devices": { + "enabled": false, + 
"monitor_usb": true + }, "gateway": { "host": "0.0.0.0", "port": 18790 diff --git a/config/config.openrouter.json b/config/config.openrouter.json deleted file mode 100644 index 4aca883d4..000000000 --- a/config/config.openrouter.json +++ /dev/null @@ -1,86 +0,0 @@ -{ - "agents": { - "defaults": { - "workspace": "~/.picoclaw/workspace", - "model": "arcee-ai/trinity-large-preview:free", - "max_tokens": 8192, - "temperature": 0.7, - "max_tool_iterations": 20 - } - }, - "channels": { - "telegram": { - "enabled": false, - "token": "YOUR_TELEGRAM_BOT_TOKEN", - "allow_from": [ - "YOUR_USER_ID" - ] - }, - "discord": { - "enabled": true, - "token": "YOUR_DISCORD_BOT_TOKEN", - "allow_from": [] - }, - "maixcam": { - "enabled": false, - "host": "0.0.0.0", - "port": 18790, - "allow_from": [] - }, - "whatsapp": { - "enabled": false, - "bridge_url": "ws://localhost:3001", - "allow_from": [] - }, - "feishu": { - "enabled": false, - "app_id": "", - "app_secret": "", - "encrypt_key": "", - "verification_token": "", - "allow_from": [] - } - }, - "providers": { - "anthropic": { - "api_key": "", - "api_base": "" - }, - "openai": { - "api_key": "", - "api_base": "" - }, - "openrouter": { - "api_key": "sk-or-v1-xxx", - "api_base": "" - }, - "groq": { - "api_key": "gsk_xxx", - "api_base": "" - }, - "zhipu": { - "api_key": "YOUR_ZHIPU_API_KEY", - "api_base": "" - }, - "gemini": { - "api_key": "", - "api_base": "" - }, - "vllm": { - "api_key": "", - "api_base": "" - } - }, - "tools": { - "web": { - "search": { - "api_key": "YOUR_BRAVE_API_KEY", - "max_results": 5 - } - } - }, - "gateway": { - "host": "0.0.0.0", - "port": 18790 - } -} \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 48769627c..c268b01cd 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,9 +10,12 @@ services: container_name: picoclaw-agent profiles: - agent + # Uncomment to access host network; leave commented unless needed. 
+ #extra_hosts: + # - "host.docker.internal:host-gateway" volumes: - - ./config/config.json:/root/.picoclaw/config.json:ro - - picoclaw-workspace:/root/.picoclaw/workspace + - ./config/config.json:/home/picoclaw/.picoclaw/config.json:ro + - picoclaw-workspace:/home/picoclaw/.picoclaw/workspace entrypoint: ["picoclaw", "agent"] stdin_open: true tty: true @@ -29,11 +32,14 @@ services: restart: unless-stopped profiles: - gateway + # Uncomment to access host network; leave commented unless needed. + #extra_hosts: + # - "host.docker.internal:host-gateway" volumes: # Configuration file - - ./config/config.json:/root/.picoclaw/config.json:ro + - ./config/config.json:/home/picoclaw/.picoclaw/config.json:ro # Persistent workspace (sessions, memory, logs) - - picoclaw-workspace:/root/.picoclaw/workspace + - picoclaw-workspace:/home/picoclaw/.picoclaw/workspace command: ["gateway"] volumes: diff --git a/docs/ANTIGRAVITY_AUTH.md b/docs/ANTIGRAVITY_AUTH.md new file mode 100644 index 000000000..89261d899 --- /dev/null +++ b/docs/ANTIGRAVITY_AUTH.md @@ -0,0 +1,807 @@ +# Antigravity Authentication & Integration Guide + +## Overview + +**Antigravity** (Google Cloud Code Assist) is a Google-backed AI model provider that offers access to models like Claude Opus 4.6 and Gemini through Google's Cloud infrastructure. This document provides a complete guide on how authentication works, how to fetch models, and how to implement a new provider in PicoClaw. + +--- + +## Table of Contents + +1. [Authentication Flow](#authentication-flow) +2. [OAuth Implementation Details](#oauth-implementation-details) +3. [Token Management](#token-management) +4. [Models List Fetching](#models-list-fetching) +5. [Usage Tracking](#usage-tracking) +6. [Provider Plugin Structure](#provider-plugin-structure) +7. [Integration Requirements](#integration-requirements) +8. [API Endpoints](#api-endpoints) +9. [Configuration](#configuration) +10. 
[Creating a New Provider in PicoClaw](#creating-a-new-provider-in-picoclaw) + +--- + +## Authentication Flow + +### 1. OAuth 2.0 with PKCE + +Antigravity uses **OAuth 2.0 with PKCE (Proof Key for Code Exchange)** for secure authentication: + +``` +┌─────────────┐ ┌─────────────────┐ +│ Client │ ───(1) Generate PKCE Pair────────> │ │ +│ │ ───(2) Open Auth URL─────────────> │ Google OAuth │ +│ │ │ Server │ +│ │ <──(3) Redirect with Code───────── │ │ +│ │ └─────────────────┘ +│ │ ───(4) Exchange Code for Tokens──> │ Token URL │ +│ │ │ │ +│ │ <──(5) Access + Refresh Tokens──── │ │ +└─────────────┘ └─────────────────┘ +``` + +### 2. Detailed Steps + +#### Step 1: Generate PKCE Parameters +```typescript +function generatePkce(): { verifier: string; challenge: string } { + const verifier = randomBytes(32).toString("hex"); + const challenge = createHash("sha256").update(verifier).digest("base64url"); + return { verifier, challenge }; +} +``` + +#### Step 2: Build Authorization URL +```typescript +const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"; +const REDIRECT_URI = "http://localhost:51121/oauth-callback"; + +function buildAuthUrl(params: { challenge: string; state: string }): string { + const url = new URL(AUTH_URL); + url.searchParams.set("client_id", CLIENT_ID); + url.searchParams.set("response_type", "code"); + url.searchParams.set("redirect_uri", REDIRECT_URI); + url.searchParams.set("scope", SCOPES.join(" ")); + url.searchParams.set("code_challenge", params.challenge); + url.searchParams.set("code_challenge_method", "S256"); + url.searchParams.set("state", params.state); + url.searchParams.set("access_type", "offline"); + url.searchParams.set("prompt", "consent"); + return url.toString(); +} +``` + +**Required Scopes:** +```typescript +const SCOPES = [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/cclog", + 
"https://www.googleapis.com/auth/experimentsandconfigs", +]; +``` + +#### Step 3: Handle OAuth Callback + +**Automatic Mode (Local Development):** +- Start a local HTTP server on port 51121 +- Wait for the redirect from Google +- Extract the authorization code from the query parameters + +**Manual Mode (Remote/Headless):** +- Display the authorization URL to the user +- User completes authentication in their browser +- User pastes the full redirect URL back into the terminal +- Parse the code from the pasted URL + +#### Step 4: Exchange Code for Tokens +```typescript +const TOKEN_URL = "https://oauth2.googleapis.com/token"; + +async function exchangeCode(params: { + code: string; + verifier: string; +}): Promise<{ access: string; refresh: string; expires: number }> { + const response = await fetch(TOKEN_URL, { + method: "POST", + headers: { "Content-Type": "application/x-www-form-urlencoded" }, + body: new URLSearchParams({ + client_id: CLIENT_ID, + client_secret: CLIENT_SECRET, + code: params.code, + grant_type: "authorization_code", + redirect_uri: REDIRECT_URI, + code_verifier: params.verifier, + }), + }); + + const data = await response.json(); + + return { + access: data.access_token, + refresh: data.refresh_token, + expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000, // 5 min buffer + }; +} +``` + +#### Step 5: Fetch Additional User Data + +**User Email:** +```typescript +async function fetchUserEmail(accessToken: string): Promise { + const response = await fetch( + "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", + { headers: { Authorization: `Bearer ${accessToken}` } } + ); + const data = await response.json(); + return data.email; +} +``` + +**Project ID (Required for API calls):** +```typescript +async function fetchProjectId(accessToken: string): Promise { + const headers = { + Authorization: `Bearer ${accessToken}`, + "Content-Type": "application/json", + "User-Agent": "google-api-nodejs-client/9.15.1", + "X-Goog-Api-Client": 
"google-cloud-sdk vscode_cloudshelleditor/0.1", + "Client-Metadata": JSON.stringify({ + ideType: "IDE_UNSPECIFIED", + platform: "PLATFORM_UNSPECIFIED", + pluginType: "GEMINI", + }), + }; + + const response = await fetch( + "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist", + { + method: "POST", + headers, + body: JSON.stringify({ + metadata: { + ideType: "IDE_UNSPECIFIED", + platform: "PLATFORM_UNSPECIFIED", + pluginType: "GEMINI", + }, + }), + } + ); + + const data = await response.json(); + return data.cloudaicompanionProject || "rising-fact-p41fc"; // Default fallback +} +``` + +--- + +## OAuth Implementation Details + +### Client Credentials + +**Important:** These are base64-encoded in the source code for sync with pi-ai: + +```typescript +const decode = (s: string) => Buffer.from(s, "base64").toString(); + +const CLIENT_ID = decode( + "MTA3MTAwNjA2MDU5MS10bWhzc2luMmgyMWxjcmUyMzV2dG9sb2poNGc0MDNlcC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbQ==" +); +const CLIENT_SECRET = decode("R09DU1BYLUs1OEZXUjQ4NkxkTEoxbUxCOHNYQzR6NnFEQWY="); +``` + +### OAuth Flow Modes + +1. **Automatic Flow** (Local machines with browser): + - Opens browser automatically + - Local callback server captures redirect + - No user interaction required after initial auth + +2. 
**Manual Flow** (Remote/headless/WSL2): + - URL displayed for manual copy-paste + - User completes auth in external browser + - User pastes full redirect URL back + +```typescript +function shouldUseManualOAuthFlow(isRemote: boolean): boolean { + return isRemote || isWSL2Sync(); +} +``` + +--- + +## Token Management + +### Auth Profile Structure + +```typescript +type OAuthCredential = { + type: "oauth"; + provider: "google-antigravity"; + access: string; // Access token + refresh: string; // Refresh token + expires: number; // Expiration timestamp (ms since epoch) + email?: string; // User email + projectId?: string; // Google Cloud project ID +}; +``` + +### Token Refresh + +The credential includes a refresh token that can be used to obtain new access tokens when the current one expires. The expiration is set with a 5-minute buffer to prevent race conditions. + +--- + +## Models List Fetching + +### Fetch Available Models + +```typescript +const BASE_URL = "https://cloudcode-pa.googleapis.com"; + +async function fetchAvailableModels( + accessToken: string, + projectId: string +): Promise { + const headers = { + Authorization: `Bearer ${accessToken}`, + "Content-Type": "application/json", + "User-Agent": "antigravity", + "X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1", + }; + + const response = await fetch( + `${BASE_URL}/v1internal:fetchAvailableModels`, + { + method: "POST", + headers, + body: JSON.stringify({ project: projectId }), + } + ); + + const data = await response.json(); + + // Returns models with quota information + return Object.entries(data.models).map(([modelId, modelInfo]) => ({ + id: modelId, + displayName: modelInfo.displayName, + quotaInfo: { + remainingFraction: modelInfo.quotaInfo?.remainingFraction, + resetTime: modelInfo.quotaInfo?.resetTime, + isExhausted: modelInfo.quotaInfo?.isExhausted, + }, + })); +} +``` + +### Response Format + +```typescript +type FetchAvailableModelsResponse = { + models?: Record; +}; +``` + 
+--- + +## Usage Tracking + +### Fetch Usage Data + +```typescript +export async function fetchAntigravityUsage( + token: string, + timeoutMs: number +): Promise { + // 1. Fetch credits and plan info + const loadCodeAssistRes = await fetch( + `${BASE_URL}/v1internal:loadCodeAssist`, + { + method: "POST", + headers: { + Authorization: `Bearer ${token}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + metadata: { + ideType: "ANTIGRAVITY", + platform: "PLATFORM_UNSPECIFIED", + pluginType: "GEMINI", + }, + }), + } + ); + + // Extract credits info + const { availablePromptCredits, planInfo, currentTier } = data; + + // 2. Fetch model quotas + const modelsRes = await fetch( + `${BASE_URL}/v1internal:fetchAvailableModels`, + { + method: "POST", + headers: { Authorization: `Bearer ${token}` }, + body: JSON.stringify({ project: projectId }), + } + ); + + // Build usage windows + return { + provider: "google-antigravity", + displayName: "Google Antigravity", + windows: [ + { label: "Credits", usedPercent: calculateUsedPercent(available, monthly) }, + // Individual model quotas... 
+ ], + plan: currentTier?.name || planType, + }; +} +``` + +### Usage Response Structure + +```typescript +type ProviderUsageSnapshot = { + provider: "google-antigravity"; + displayName: string; + windows: UsageWindow[]; + plan?: string; + error?: string; +}; + +type UsageWindow = { + label: string; // "Credits" or model ID + usedPercent: number; // 0-100 + resetAt?: number; // Timestamp when quota resets +}; +``` + +--- + +## Provider Plugin Structure + +### Plugin Definition + +```typescript +const antigravityPlugin = { + id: "google-antigravity-auth", + name: "Google Antigravity Auth", + description: "OAuth flow for Google Antigravity (Cloud Code Assist)", + configSchema: emptyPluginConfigSchema(), + + register(api: PicoClawPluginApi) { + api.registerProvider({ + id: "google-antigravity", + label: "Google Antigravity", + docsPath: "/providers/models", + aliases: ["antigravity"], + + auth: [ + { + id: "oauth", + label: "Google OAuth", + hint: "PKCE + localhost callback", + kind: "oauth", + run: async (ctx: ProviderAuthContext) => { + // OAuth implementation here + }, + }, + ], + }); + }, +}; +``` + +### ProviderAuthContext + +```typescript +type ProviderAuthContext = { + config: PicoClawConfig; + agentDir?: string; + workspaceDir?: string; + prompter: WizardPrompter; // UI prompts/notifications + runtime: RuntimeEnv; // Logging, etc. + isRemote: boolean; // Whether running remotely + openUrl: (url: string) => Promise; // Browser opener + oauth: { + createVpsAwareHandlers: Function; + }; +}; +``` + +### ProviderAuthResult + +```typescript +type ProviderAuthResult = { + profiles: Array<{ + profileId: string; + credential: AuthProfileCredential; + }>; + configPatch?: Partial; + defaultModel?: string; + notes?: string[]; +}; +``` + +--- + +## Integration Requirements + +### 1. Required Environment/Dependencies + +- Go ≥ 1.21 +- PicoClaw codebase (`pkg/providers/` and `pkg/auth/`) +- `crypto` and `net/http` standard library packages + +### 2. 
Required Headers for API Calls + +```typescript +const REQUIRED_HEADERS = { + "Authorization": `Bearer ${accessToken}`, + "Content-Type": "application/json", + "User-Agent": "antigravity", // or "google-api-nodejs-client/9.15.1" + "X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1", +}; + +// For loadCodeAssist calls, also include: +const CLIENT_METADATA = { + ideType: "ANTIGRAVITY", // or "IDE_UNSPECIFIED" + platform: "PLATFORM_UNSPECIFIED", + pluginType: "GEMINI", +}; +``` + +### 3. Model Schema Sanitization + +Antigravity uses Gemini-compatible models, so tool schemas must be sanitized: + +```typescript +const GOOGLE_SCHEMA_UNSUPPORTED_KEYWORDS = new Set([ + "patternProperties", + "additionalProperties", + "$schema", + "$id", + "$ref", + "$defs", + "definitions", + "examples", + "minLength", + "maxLength", + "minimum", + "maximum", + "multipleOf", + "pattern", + "format", + "minItems", + "maxItems", + "uniqueItems", + "minProperties", + "maxProperties", +]); + +// Clean schema before sending +function cleanToolSchemaForGemini(schema: Record): unknown { + // Remove unsupported keywords + // Ensure top-level has type: "object" + // Flatten anyOf/oneOf unions +} +``` + +### 4. 
Thinking Block Handling (Claude Models) + +For Antigravity Claude models, thinking blocks require special handling: + +```typescript +const ANTIGRAVITY_SIGNATURE_RE = /^[A-Za-z0-9+/]+={0,2}$/; + +export function sanitizeAntigravityThinkingBlocks( + messages: AgentMessage[] +): AgentMessage[] { + // Validate thinking signatures + // Normalize signature fields + // Discard unsigned thinking blocks +} +``` + +--- + +## API Endpoints + +### Authentication Endpoints + +| Endpoint | Method | Purpose | +|----------|--------|---------| +| `https://accounts.google.com/o/oauth2/v2/auth` | GET | OAuth authorization | +| `https://oauth2.googleapis.com/token` | POST | Token exchange | +| `https://www.googleapis.com/oauth2/v1/userinfo` | GET | User info (email) | + +### Cloud Code Assist Endpoints + +| Endpoint | Method | Purpose | +|----------|--------|---------| +| `https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist` | POST | Load project info, credits, plan | +| `https://cloudcode-pa.googleapis.com/v1internal:fetchAvailableModels` | POST | List available models with quotas | +| `https://cloudcode-pa.googleapis.com/v1internal:streamGenerateContent?alt=sse` | POST | Chat streaming endpoint | + +**API Request Format (Chat):** +The `v1internal:streamGenerateContent` endpoint expects an envelope wrapping the standard Gemini request: + +```json +{ + "project": "your-project-id", + "model": "model-id", + "request": { + "contents": [...], + "systemInstruction": {...}, + "generationConfig": {...}, + "tools": [...] + }, + "requestType": "agent", + "userAgent": "antigravity", + "requestId": "agent-timestamp-random" +} +``` + +**API Response Format (SSE):** +Each SSE message (`data: {...}`) is wrapped in a `response` field: + +```json +{ + "response": { + "candidates": [...], + "usageMetadata": {...}, + "modelVersion": "...", + "responseId": "..." 
+ }, + "traceId": "...", + "metadata": {} +} +``` + +--- + +## Configuration + +### config.json Configuration + +```json +{ + "model_list": [ + { + "model_name": "gemini-flash", + "model": "antigravity/gemini-3-flash", + "auth_method": "oauth" + } + ], + "agents": { + "defaults": { + "model": "gemini-flash" + } + } +} +``` + +### Auth Profile Storage + +Auth profiles are stored in `~/.picoclaw/auth.json`: + +```json +{ + "credentials": { + "google-antigravity": { + "access_token": "ya29...", + "refresh_token": "1//...", + "expires_at": "2026-01-01T00:00:00Z", + "provider": "google-antigravity", + "auth_method": "oauth", + "email": "user@example.com", + "project_id": "my-project-id" + } + } +} +``` + +--- + +## Creating a New Provider in PicoClaw + +PicoClaw providers are implemented as Go packages under `pkg/providers/`. To add a new provider: + +### Step-by-Step Implementation + +#### 1. Create Provider File + +Create a new Go file in `pkg/providers/`: + +``` +pkg/providers/ +└── your_provider.go +``` + +#### 2. Implement the Provider Interface + +Your provider must implement the `Provider` interface defined in `pkg/providers/types.go`: + +```go +package providers + +type YourProvider struct { + apiKey string + apiBase string +} + +func NewYourProvider(apiKey, apiBase, proxy string) *YourProvider { + if apiBase == "" { + apiBase = "https://api.your-provider.com/v1" + } + return &YourProvider{apiKey: apiKey, apiBase: apiBase} +} + +func (p *YourProvider) Chat(ctx context.Context, messages []Message, tools []Tool, cb StreamCallback) error { + // Implement chat completion with streaming +} +``` + +#### 3. Register in the Factory + +Add your provider to the protocol switch in `pkg/providers/factory.go`: + +```go +case "your-provider": + return NewYourProvider(sel.apiKey, sel.apiBase, sel.proxy), nil +``` + +#### 4. 
Add Default Config (Optional) + +Add a default entry in `pkg/config/defaults.go`: + +```go +{ + ModelName: "your-model", + Model: "your-provider/model-name", + APIKey: "", +}, +``` + +#### 5. Add Auth Support (Optional) + +If your provider requires OAuth or special authentication, add a case to `cmd/picoclaw/cmd_auth.go`: + +```go +case "your-provider": + authLoginYourProvider() +``` + +#### 6. Configure via `config.json` + +```json +{ + "model_list": [ + { + "model_name": "your-model", + "model": "your-provider/model-name", + "api_key": "your-api-key", + "api_base": "https://api.your-provider.com/v1" + } + ] +} +``` + +--- + +## Testing Your Implementation + +### CLI Commands + +```bash +# Authenticate with a provider +picoclaw auth login --provider your-provider + +# List models (for Antigravity) +picoclaw auth models + +# Start the gateway +picoclaw gateway + +# Run an agent with a specific model +picoclaw agent -m "Hello" --model your-model +``` + +### Environment Variables for Testing + +```bash +# Override default model +export PICOCLAW_AGENTS_DEFAULTS_MODEL=your-model + +# Override provider settings +export PICOCLAW_MODEL_LIST='[{"model_name":"your-model","model":"your-provider/model-name","api_key":"..."}]' +``` + +--- + +## References + +- **Source Files:** + - `pkg/providers/antigravity_provider.go` - Antigravity provider implementation + - `pkg/auth/oauth.go` - OAuth flow implementation + - `pkg/auth/store.go` - Auth credential storage (`~/.picoclaw/auth.json`) + - `pkg/providers/factory.go` - Provider factory and protocol routing + - `pkg/providers/types.go` - Provider interface definitions + - `cmd/picoclaw/cmd_auth.go` - Auth CLI commands + +- **Documentation:** + - `docs/ANTIGRAVITY_USAGE.md` - Antigravity usage guide + - `docs/migration/model-list-migration.md` - Migration guide + +--- + +## Notes + +1. **Google Cloud Project:** Antigravity requires Gemini for Google Cloud to be enabled on your Google Cloud project +2. 
**Quotas:** Uses Google Cloud project quotas (not separate billing) +3. **Model Access:** Available models depend on your Google Cloud project configuration +4. **Thinking Blocks:** Claude models via Antigravity require special handling of thinking blocks with signatures +5. **Schema Sanitization:** Tool schemas must be sanitized to remove unsupported JSON Schema keywords + +--- + +--- + +## Common Error Handling + +### 1. Rate Limiting (HTTP 429) + +Antigravity returns a 429 error when project/model quotas are exhausted. The error response often contains a `quotaResetDelay` in the `details` field. + +**Example 429 Error:** +```json +{ + "error": { + "code": 429, + "message": "You have exhausted your capacity on this model. Your quota will reset after 4h30m28s.", + "status": "RESOURCE_EXHAUSTED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "metadata": { + "quotaResetDelay": "4h30m28.060903746s" + } + } + ] + } +} +``` + +### 2. Empty Responses (Restricted Models) + +Some models might show up in the available models list but return an empty response (200 OK but empty SSE stream). This usually happens for preview or restricted models that the current project doesn't have permission to use. + +**Treatment:** Treat empty responses as errors informing the user that the model might be restricted or invalid for their project. 
+ +--- + +## Troubleshooting + +### "Token expired" +- Refresh OAuth tokens: `picoclaw auth login --provider antigravity` + +### "Gemini for Google Cloud is not enabled" +- Enable the API in your Google Cloud Console + +### "Project not found" +- Ensure your Google Cloud project has the necessary APIs enabled +- Check that the project ID is correctly fetched during authentication + +### Models not appearing in list +- Verify OAuth authentication completed successfully +- Check auth profile storage: `~/.picoclaw/auth.json` +- Re-run `picoclaw auth login --provider antigravity` diff --git a/docs/ANTIGRAVITY_USAGE.md b/docs/ANTIGRAVITY_USAGE.md new file mode 100644 index 000000000..e8194b6bc --- /dev/null +++ b/docs/ANTIGRAVITY_USAGE.md @@ -0,0 +1,70 @@ +# Using Antigravity Provider in PicoClaw + +This guide explains how to set up and use the **Antigravity** (Google Cloud Code Assist) provider in PicoClaw. + +## Prerequisites + +1. A Google account. +2. Google Cloud Code Assist enabled (usually available via the "Gemini for Google Cloud" onboarding). + +## 1. Authentication + +To authenticate with Antigravity, run the following command: + +```bash +picoclaw auth login --provider antigravity +``` + +### Manual Authentication (Headless/VPS) +If you are running on a server (Coolify/Docker) and cannot reach `localhost`, follow these steps: +1. Run the command above. +2. Copy the URL provided and open it in your local browser. +3. Complete the login. +4. Your browser will redirect to a `localhost:51121` URL (which will fail to load). +5. **Copy that final URL** from your browser's address bar. +6. **Paste it back into the terminal** where PicoClaw is waiting. + +PicoClaw will extract the authorization code and complete the process automatically. + +## 2. 
Managing Models + +### List Available Models +To see which models your project has access to and check their quotas: + +```bash +picoclaw auth models +``` + +### Switch Models +You can change the default model in `~/.picoclaw/config.json` or override it via the CLI: + +```bash +# Override for a single command +picoclaw agent -m "Hello" --model claude-opus-4-6-thinking +``` + +## 3. Real-world Usage (Coolify/Docker) + +If you are deploying via Coolify or Docker, follow these steps to test: + +1. **Environment Variables**: + * `PICOCLAW_AGENTS_DEFAULTS_MODEL=gemini-flash` +2. **Authentication persistence**: + If you've logged in locally, you can copy your credentials to the server: + ```bash + scp ~/.picoclaw/auth.json user@your-server:~/.picoclaw/ + ``` + *Alternatively*, run the `auth login` command once on the server if you have terminal access. + +## 4. Troubleshooting + +* **Empty Response**: If a model returns an empty reply, it may be restricted for your project. Try `gemini-3-flash` or `claude-opus-4-6-thinking`. +* **429 Rate Limit**: Antigravity has strict quotas. PicoClaw will display the "reset time" in the error message if you hit a limit. +* **404 Not Found**: Ensure you are using a model ID from the `picoclaw auth models` list. Use the short ID (e.g., `gemini-3-flash`) not the full path. + +## 5. 
Summary of Working Models + +Based on testing, the following models are most reliable: +* `gemini-3-flash` (Fast, highly available) +* `gemini-2.5-flash-lite` (Lightweight) +* `claude-opus-4-6-thinking` (Powerful, includes reasoning) diff --git a/docs/channels/dingtalk/README.zh.md b/docs/channels/dingtalk/README.zh.md new file mode 100644 index 000000000..1e445d0b0 --- /dev/null +++ b/docs/channels/dingtalk/README.zh.md @@ -0,0 +1,33 @@ +# 钉钉 + +钉钉是阿里巴巴的企业通讯平台,在中国职场中广受欢迎。它采用流式 SDK 来维持持久连接。 + +## 配置 + +```json +{ + "channels": { + "dingtalk": { + "enabled": true, + "client_id": "YOUR_CLIENT_ID", + "client_secret": "YOUR_CLIENT_SECRET", + "allow_from": [] + } + } +} +``` + +| 字段 | 类型 | 必填 | 描述 | +| ------------- | ------ | ---- | -------------------------------- | +| enabled | bool | 是 | 是否启用钉钉频道 | +| client_id | string | 是 | 钉钉应用的 Client ID | +| client_secret | string | 是 | 钉钉应用的 Client Secret | +| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 | + +## 设置流程 + +1. 前往 [钉钉开放平台](https://open.dingtalk.com/) +2. 创建一个企业内部应用 +3. 从应用设置中获取 Client ID 和 Client Secret +4. 配置OAuth和事件订阅(如需要) +5. 将 Client ID 和 Client Secret 填入配置文件中 diff --git a/docs/channels/discord/README.zh.md b/docs/channels/discord/README.zh.md new file mode 100644 index 000000000..5b597eced --- /dev/null +++ b/docs/channels/discord/README.zh.md @@ -0,0 +1,35 @@ +# Discord + +Discord 是一个专为社区设计的免费语音、视频和文本聊天应用。PicoClaw 通过 Discord Bot API 连接到 Discord 服务器,支持接收和发送消息。 + +## 配置 + +```json +{ + "channels": { + "discord": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allow_from": ["YOUR_USER_ID"], + "mention_only": false + } + } +} +``` + +| 字段 | 类型 | 必填 | 描述 | +| ------------ | ------ | ---- | -------------------------------- | +| enabled | bool | 是 | 是否启用 Discord 频道 | +| token | string | 是 | Discord 机器人 Token | +| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 | +| mention_only | bool | 否 | 是否仅响应提及机器人的消息 | + +## 设置流程 + +1. 前往 [Discord 开发者门户](https://discord.com/developers/applications) 创建一个新的应用 +2. 
启用 Intents: + - Message Content Intent + - Server Members Intent +3. 获取 Bot Token +4. 将 Bot Token 填入配置文件中 +5. 邀请机器人加入服务器并授予必要权限(例如发送消息、读取消息历史等) diff --git a/docs/channels/feishu/README.zh.md b/docs/channels/feishu/README.zh.md new file mode 100644 index 000000000..310827723 --- /dev/null +++ b/docs/channels/feishu/README.zh.md @@ -0,0 +1,37 @@ +# 飞书 + +飞书(国际版名称:Lark)是字节跳动旗下的企业协作平台。它通过事件驱动的 Webhook 同时支持中国和全球市场。 + +## 配置 + +```json +{ + "channels": { + "feishu": { + "enabled": true, + "app_id": "cli_xxx", + "app_secret": "xxx", + "encrypt_key": "", + "verification_token": "", + "allow_from": [] + } + } +} +``` + +| 字段 | 类型 | 必填 | 描述 | +| ------------------ | ------ | ---- | -------------------------------- | +| enabled | bool | 是 | 是否启用飞书频道 | +| app_id | string | 是 | 飞书应用的 App ID(以cli\_开头) | +| app_secret | string | 是 | 飞书应用的 App Secret | +| encrypt_key | string | 否 | 事件回调加密密钥 | +| verification_token | string | 否 | 用于Webhook事件验证的Token | +| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 | + +## 设置流程 + +1. 前往 [飞书开放平台](https://open.feishu.cn/)创建应用程序 +2. 获取 App ID 和 App Secret +3. 配置事件订阅和Webhook URL +4. 设置加密(可选,生产环境建议启用) +5. 
将 App ID 和 App Secret 填入配置文件中;若启用了加密,还需填入 Encrypt Key 和 Verification Token
将 Channel Secret 和 Channel Access Token 填入配置文件中 diff --git a/docs/channels/maixcam/README.zh.md b/docs/channels/maixcam/README.zh.md new file mode 100644 index 000000000..8d53d4bef --- /dev/null +++ b/docs/channels/maixcam/README.zh.md @@ -0,0 +1,31 @@ +# MaixCam + +MaixCam 是专用于连接矽速科技 MaixCAM 与 MaixCAM2 AI 摄像设备的通道。它采用 TCP 套接字实现双向通信,支持边缘 AI 部署场景。 + +## 配置 + +```json +{ + "channels": { + "maixcam": { + "enabled": true, + "server_address": "0.0.0.0:8899", + "allow_from": [] + } + } +} +``` + +| 字段 | 类型 | 必填 | 描述 | +| -------------- | ------ | ---- | -------------------------------- | +| enabled | bool | 是 | 是否启用 MaixCam 频道 | +| server_address | string | 是 | TCP 服务器监听地址和端口 | +| allow_from | array | 否 | 设备ID白名单,空表示允许所有设备 | + +## 使用场景 + +MaixCam 通道使 PicoClaw 能够作为边缘设备的 AI 后端运行: + +- **智能监控** :MaixCAM 发送图像帧,PicoClaw 通过视觉模型进行分析 +- **物联网控制** :设备发送传感器数据,PicoClaw 协调响应 +- **离线AI** :在本地网络部署 PicoClaw 实现低延迟推理 diff --git a/docs/channels/onebot/README.zh.md b/docs/channels/onebot/README.zh.md new file mode 100644 index 000000000..6195f1c98 --- /dev/null +++ b/docs/channels/onebot/README.zh.md @@ -0,0 +1,31 @@ +# OneBot + +OneBot 是一个面向 QQ 机器人的开放协议标准,为多种 QQ 机器人实现(例如 go-cqhttp、Mirai)提供了统一的接口。它使用 WebSocket 进行通信。 + +## 配置 + +```json +{ + "channels": { + "onebot": { + "enabled": true, + "ws_url": "ws://localhost:8080", + "access_token": "", + "allow_from": [] + } + } +} +``` + +| 字段 | 类型 | 必填 | 描述 | +| ------------ | ------ | ---- | -------------------------------- | +| enabled | bool | 是 | 是否启用 OneBot 频道 | +| ws_url | string | 是 | OneBot 服务器的 WebSocket URL | +| access_token | string | 否 | 连接 OneBot 服务器的访问令牌 | +| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 | + +## 设置流程 + +1. 部署一个 OneBot 兼容的实现(例如napcat) +2. 配置 OneBot 实现以启用 WebSocket 服务并设置访问令牌(如果需要) +3. 
将 WebSocket URL 和访问令牌填入配置文件中 diff --git a/docs/channels/qq/README.zh.md b/docs/channels/qq/README.zh.md new file mode 100644 index 000000000..bd774960f --- /dev/null +++ b/docs/channels/qq/README.zh.md @@ -0,0 +1,32 @@ +# QQ + +PicoClaw 通过 QQ 开放平台的官方机器人 API 提供对 QQ 的支持。 + +## 配置 + +```json +{ + "channels": { + "qq": { + "enabled": true, + "app_id": "YOUR_APP_ID", + "app_secret": "YOUR_APP_SECRET", + "allow_from": [] + } + } +} +``` + +| 字段 | 类型 | 必填 | 描述 | +| ---------- | ------ | ---- | -------------------------------- | +| enabled | bool | 是 | 是否启用 QQ Channel | +| app_id | string | 是 | QQ 机器人应用的 App ID | +| app_secret | string | 是 | QQ 机器人应用的 App Secret | +| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 | + +## 设置流程 + +1. 前往 [QQ 开放平台](https://q.qq.com/) 创建一个机器人 +2. 通过仪表盘获取 App ID 和 App Secret +3. 开启机器人沙箱模式, 将用户和群添加到沙箱中 +4. 将 App ID 和 App Secret 填入配置文件中 diff --git a/docs/channels/slack/README.zh.md b/docs/channels/slack/README.zh.md new file mode 100644 index 000000000..58ebcb566 --- /dev/null +++ b/docs/channels/slack/README.zh.md @@ -0,0 +1,33 @@ +# Slack + +Slack 是全球领先的企业级即时通讯平台。PicoClaw 采用 Slack 的 Socket Mode 实现实时双向通信,无需配置公开的 Webhook 端点。 + +## 配置 + +```json +{ + "channels": { + "slack": { + "enabled": true, + "bot_token": "xoxb-...", + "app_token": "xapp-...", + "allow_from": [] + } + } +} +``` + +| 字段 | 类型 | 必填 | 描述 | +| ---------- | ------ | ---- | -------------------------------------------------------- | +| enabled | bool | 是 | 是否启用 Slack 频道 | +| bot_token | string | 是 | Slack 机器人的 Bot User OAuth Token (以 xoxb- 开头) | +| app_token | string | 是 | Slack 应用的 Socket Mode App Level Token (以 xapp- 开头) | +| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 | + +## 设置流程 + +1. 前往 [Slack API](https://api.slack.com/) 创建一个新的 Slack 应用 +2. 启用 Socket Mode 并获取 App Level Token +3. 添加 Bot Token Scopes(例如`chat:write`、`im:history`等) +4. 安装应用到工作区并获取 Bot User OAuth Token +5. 
将 Bot Token 和 App Token 填入配置文件中 diff --git a/docs/channels/telegram/README.zh.md b/docs/channels/telegram/README.zh.md new file mode 100644 index 000000000..d453c68fa --- /dev/null +++ b/docs/channels/telegram/README.zh.md @@ -0,0 +1,33 @@ +# Telegram + +Telegram Channel 通过 Telegram 机器人 API 使用长轮询实现基于机器人的通信。它支持文本消息、媒体附件(照片、语音、音频、文档)、通过 Groq Whisper 进行语音转录以及内置命令处理器。 + +## 配置 + +```json +{ + "channels": { + "telegram": { + "enabled": true, + "token": "123456789:ABCdefGHIjklMNOpqrsTUVwxyz", + "allow_from": ["123456789"], + "proxy": "" + } + } +} +``` + +| 字段 | 类型 | 必填 | 描述 | +| ---------- | ------ | ---- | --------------------------------------------------------- | +| enabled | bool | 是 | 是否启用 Telegram 频道 | +| token | string | 是 | Telegram 机器人 API Token | +| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 | +| proxy | string | 否 | 连接 Telegram API 的代理 URL (例如 http://127.0.0.1:7890) | + +## 设置流程 + +1. 在 Telegram 中搜索 `@BotFather` +2. 发送 `/newbot` 命令并按照提示创建新机器人 +3. 获取 HTTP API Token +4. 将 Token 填入配置文件中 +5. 
(可选) 配置 `allow_from` 以限制允许互动的用户 ID (可通过 `@userinfobot` 获取 ID) diff --git a/docs/channels/wecom/wecom_app/README.zh.md b/docs/channels/wecom/wecom_app/README.zh.md new file mode 100644 index 000000000..1e6a0e2b3 --- /dev/null +++ b/docs/channels/wecom/wecom_app/README.zh.md @@ -0,0 +1,47 @@ +# 企业微信自建应用 + +企业微信自建应用是指企业在企业微信中创建的应用,主要用于企业内部使用。通过企业微信自建应用,企业可以实现与员工的高效沟通和协作,提高工作效率。 + +## 配置 + +```json +{ + "channels": { + "wecom_app": { + "enabled": true, + "corp_id": "wwxxxxxxxxxxxxxxxx", + "corp_secret": "YOUR_CORP_SECRET", + "agent_id": 1000002, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_ENCODING_AES_KEY", + "webhook_host": "0.0.0.0", + "webhook_port": 18792, + "webhook_path": "/webhook/wecom-app", + "allow_from": [], + "reply_timeout": 5 + } + } +} +``` + +| 字段 | 类型 | 必填 | 描述 | +| ---------------- | ------ | ---- | ---------------------------------------- | +| corp_id | string | 是 | 企业 ID | +| corp_secret | string | 是 | 应用程序密钥 | +| agent_id | int | 是 | 应用程序代理 ID | +| token | string | 是 | 回调验证令牌 | +| encoding_aes_key | string | 是 | 43 字符 AES 密钥 | +| webhook_host | string | 否 | HTTP 服务器绑定地址 | +| webhook_port | int | 否 | HTTP 服务器端口(默认:18792) | +| webhook_path | string | 否 | Webhook 路径(默认:/webhook/wecom-app) | +| allow_from | array | 否 | 用户 ID 白名单 | +| reply_timeout | int | 否 | 回复超时时间(秒) | + +## 设置流程 + +1. 登录 [企业微信管理后台](https://work.weixin.qq.com/) +2. 进入“应用管理” -> “创建应用” +3. 获取企业 ID (CorpID) 和应用 Secret +4. 在应用设置中配置“接收消息”,获取 Token 和 EncodingAESKey +5. 设置回调 URL 为 `http://:/webhook/wecom-app` +6. 
将 CorpID, Secret, AgentID 等信息填入配置文件 diff --git a/docs/channels/wecom/wecom_bot/README.zh.md b/docs/channels/wecom/wecom_bot/README.zh.md new file mode 100644 index 000000000..c4bb1c87e --- /dev/null +++ b/docs/channels/wecom/wecom_bot/README.zh.md @@ -0,0 +1,41 @@ +# 企业微信机器人 + +企业微信机器人是企业微信提供的一种快速接入方式,可以通过 Webhook URL 接收消息。 + +## 配置 + +```json +{ + "channels": { + "wecom": { + "enabled": true, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_ENCODING_AES_KEY", + "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY", + "webhook_host": "0.0.0.0", + "webhook_port": 18793, + "webhook_path": "/webhook/wecom", + "allow_from": [], + "reply_timeout": 5 + } + } +} +``` + +| 字段 | 类型 | 必填 | 描述 | +| ---------------- | ------ | ---- | -------------------------------------------- | +| token | string | 是 | 签名验证代币 | +| encoding_aes_key | string | 是 | 用于解密的 43 字符 AES 密钥 | +| webhook_url | string | 是 | 用于发送回复的企业微信群聊机器人 Webhook URL | +| webhook_host | string | 否 | HTTP 服务器绑定地址(默认:0.0.0.0) | +| webhook_port | int | 否 | HTTP 服务器端口(默认:18793) | +| webhook_path | string | 否 | Webhook 端点路径(默认:/webhook/wecom) | +| allow_from | array | 否 | 用户 ID 白名单(空值 = 允许所有用户) | +| reply_timeout | int | 否 | 回复超时时间(单位:秒,默认值:5) | + +## 设置流程 + +1. 在企业微信群中添加机器人 +2. 获取 Webhook URL +3. (如需接收消息) 在机器人配置页面设置接收消息的 API 地址(回调地址)以及 Token 和 EncodingAESKey +4. 将相关信息填入配置文件 diff --git a/docs/design/provider-refactoring-tests.md b/docs/design/provider-refactoring-tests.md new file mode 100644 index 000000000..060be9ba8 --- /dev/null +++ b/docs/design/provider-refactoring-tests.md @@ -0,0 +1,174 @@ +# Provider Architecture Refactoring - Test Suite Summary + +This document summarizes the complete test suite designed for the Provider architecture refactoring. 
+ +## Test File Structure + +``` +pkg/ +├── config/ +│ ├── model_config_test.go # US-001, US-002: ModelConfig struct and GetModelConfig tests +│ └── migration_test.go # US-003: Backward compatibility and migration tests +├── providers/ +│ ├── factory_test.go # US-004, US-005: Provider factory tests +│ └── factory_provider_test.go # Factory provider integration tests +``` + +--- + +## Test Case Checklist + +### 1. `pkg/config/model_config_test.go` - Configuration Parsing Tests + +| Test Name | Purpose | PRD Reference | +|-----------|---------|---------------| +| `TestModelConfig_Parsing` | Verify ModelConfig JSON parsing | US-001 | +| `TestModelConfig_ModelListInConfig` | Verify model_list parsing in Config | US-001 | +| `TestModelConfig_Validation` | Verify required field validation | US-001 | +| `TestConfig_GetModelConfig_Found` | Verify GetModelConfig finds model | US-002 | +| `TestConfig_GetModelConfig_NotFound` | Verify GetModelConfig returns error | US-002 | +| `TestConfig_GetModelConfig_EmptyModelList` | Verify empty model_list handling | US-002 | +| `TestConfig_BackwardCompatibility_ProvidersToModelList` | Verify old config conversion | US-003 | +| `TestConfig_DeprecationWarning` | Verify deprecation warning | US-003 | +| `TestModelConfig_ProtocolExtraction` | Verify protocol prefix extraction | US-004 | +| `TestConfig_ModelNameUniqueness` | Verify model_name uniqueness | US-001 | + +### 2. 
`pkg/config/migration_test.go` - Migration Tests + +| Test Name | Purpose | PRD Reference | +|-----------|---------|---------------| +| `TestConvertProvidersToModelList_OpenAI` | OpenAI config conversion | US-003 | +| `TestConvertProvidersToModelList_Anthropic` | Anthropic config conversion | US-003 | +| `TestConvertProvidersToModelList_MultipleProviders` | Multiple provider conversion | US-003 | +| `TestConvertProvidersToModelList_EmptyProviders` | Empty providers handling | US-003 | +| `TestConvertProvidersToModelList_GitHubCopilot` | GitHub Copilot conversion | US-003 | +| `TestConvertProvidersToModelList_Antigravity` | Antigravity conversion | US-003 | +| `TestGenerateModelName_*` | Model name generation | US-003 | +| `TestHasProvidersConfig_*` | Detect old config existence | US-003 | +| `TestValidateMigration_*` | Migration validation | US-003 | +| `TestMigrateConfig_DryRun` | Dry run migration | US-003 | +| `TestMigrateConfig_Actual` | Actual migration | US-003 | + +### 3. `pkg/providers/registry_test.go` - Load Balancing Tests + +| Test Name | Purpose | PRD Reference | +|-----------|---------|---------------| +| `TestModelRegistry_SingleConfig` | Single config returns same result | US-006 | +| `TestModelRegistry_RoundRobinSelection` | 3-config round-robin selection | US-006 | +| `TestModelRegistry_RoundRobinTwoConfigs` | 2-config round-robin selection | US-006 | +| `TestModelRegistry_ConcurrentAccess` | Concurrent access thread safety | US-006 | +| `TestModelRegistry_RaceDetection` | Data race detection | US-006 | +| `TestModelRegistry_ModelNotFound` | Model not found error | US-006 | +| `TestModelRegistry_EmptyRegistry` | Empty registry handling | US-006 | +| `TestModelRegistry_MultipleModels` | Multiple model registration | US-006 | +| `TestModelRegistry_MixedSingleAndMultiple` | Single/multiple config mix | US-006 | +| `TestModelRegistry_CaseSensitiveModelNames` | Case sensitivity | US-006 | + +### 4. 
`pkg/providers/factory/factory_test.go` - Provider Factory Tests + +| Test Name | Purpose | PRD Reference | +|-----------|---------|---------------| +| `TestCreateProviderFromConfig_OpenAI` | Create OpenAI provider | US-004 | +| `TestCreateProviderFromConfig_OpenAIDefault` | Default openai protocol | US-004 | +| `TestCreateProviderFromConfig_Anthropic` | Create Anthropic provider | US-004 | +| `TestCreateProviderFromConfig_Antigravity` | Create Antigravity provider | US-004 | +| `TestCreateProviderFromConfig_ClaudeCLI` | Create Claude CLI provider | US-004 | +| `TestCreateProviderFromConfig_CodexCLI` | Create Codex CLI provider | US-004 | +| `TestCreateProviderFromConfig_GitHubCopilot` | Create GitHub Copilot provider | US-004 | +| `TestCreateProviderFromConfig_UnknownProtocol` | Unknown protocol error handling | US-004 | +| `TestCreateProviderFromConfig_MissingAPIKey` | Missing API key error | US-004 | +| `TestExtractProtocol` | Protocol prefix extraction | US-004 | +| `TestCreateProvider_UsesModelList` | Create using model_list | US-005 | +| `TestCreateProvider_FallbackToProviders` | Fallback to providers | US-005 | +| `TestCreateProvider_PriorityModelListOverProviders` | model_list priority | US-005 | + +### 5. 
`pkg/providers/integration_test.go` - E2E Integration Tests + +| Test Name | Purpose | PRD Reference | +|-----------|---------|---------------| +| `TestE2E_OpenAICompatibleProvider_NoCodeChange` | Zero-code provider addition | Goal | +| `TestE2E_LoadBalancing_RoundRobin` | Load balancing actual effect | US-006 | +| `TestE2E_BackwardCompatibility_OldProvidersConfig` | Old config compatibility | US-003 | +| `TestE2E_ErrorHandling_ModelNotFound` | Model not found | FR-30 | +| `TestE2E_ErrorHandling_MissingAPIKey` | Missing API key | FR-31 | +| `TestE2E_ErrorHandling_InvalidAPIBase` | Invalid API base | FR-30 | +| `TestE2E_ToolCalls_OpenAICompatible` | Tool call support | - | +| `TestE2E_AntigravityProvider` | Antigravity provider | US-004 | +| `TestE2E_ClaudeCLIProvider` | Claude CLI provider | US-004 | + +### 6. Performance Tests + +| Test Name | Purpose | +|-----------|---------| +| `BenchmarkCreateProviderFromConfig` | Provider creation performance | +| `BenchmarkGetModelConfig` | Model lookup performance | +| `BenchmarkGetModelConfigParallel` | Concurrent lookup performance | + +--- + +## Running Tests + +```bash +# Run all tests +go test ./pkg/... -v + +# Run with data race detection +go test ./pkg/... -race + +# Run specific package tests +go test ./pkg/config -v +go test ./pkg/providers -v + +# Run E2E tests +go test ./pkg/providers -run TestE2E -v + +# Run performance tests +go test ./pkg/providers -bench=. 
-benchmem +``` + +--- + +## PRD Acceptance Criteria Mapping + +| PRD Acceptance Criteria | Test Cases | +|------------------------|------------| +| US-001: Add ModelConfig struct | `TestModelConfig_Parsing`, `TestModelConfig_Validation` | +| US-001: model_name unique | `TestConfig_ModelNameUniqueness` | +| US-002: GetModelConfig method | `TestConfig_GetModelConfig_*` | +| US-003: Auto-convert providers | `TestConvertProvidersToModelList_*` | +| US-003: Deprecation warning | `TestConfig_DeprecationWarning` | +| US-003: Existing tests pass | (existing test files unchanged) | +| US-004: Protocol prefix factory | `TestExtractProtocol`, `TestCreateProviderFromConfig_*` | +| US-004: Default prefix openai | `TestCreateProviderFromConfig_OpenAIDefault` | +| US-005: CreateProvider uses factory | `TestCreateProvider_*` | +| US-006: Round-robin selection | `TestModelRegistry_RoundRobin*` | +| US-006: Thread-safe atomic | `TestModelRegistry_RaceDetection` | + +--- + +## Recommended Implementation Order + +1. **Phase 1: Configuration Structure** (US-001, US-002) + - Implement `ModelConfig` struct + - Implement `GetModelConfig` method + - Run `model_config_test.go` + +2. **Phase 2: Protocol Factory** (US-004) + - Implement `CreateProviderFromConfig` + - Implement `ExtractProtocol` + - Run `factory_test.go` + +3. **Phase 3: Load Balancing** (US-006) + - Implement `ModelRegistry` + - Implement round-robin selection + - Run `registry_test.go` (with `-race`) + +4. **Phase 4: Backward Compatibility** (US-003, US-005) + - Implement `ConvertProvidersToModelList` + - Refactor `CreateProvider` + - Run `migration_test.go` + - Verify existing tests pass + +5. 
**Phase 5: E2E Verification** + - Run `integration_test.go` + - Manual testing with `config.example.json` diff --git a/docs/design/provider-refactoring.md b/docs/design/provider-refactoring.md new file mode 100644 index 000000000..a214d9857 --- /dev/null +++ b/docs/design/provider-refactoring.md @@ -0,0 +1,334 @@ +# Provider Architecture Refactoring Design + +> Issue: #283 +> Discussion: #122 +> Branch: feat/refactor-provider-by-protocol + +## 1. Current Problems + +### 1.1 Configuration Structure Issues + +**Current State**: Each Provider requires a predefined field in `ProvidersConfig` + +```go +type ProvidersConfig struct { + Anthropic ProviderConfig `json:"anthropic"` + OpenAI ProviderConfig `json:"openai"` + DeepSeek ProviderConfig `json:"deepseek"` + Qwen ProviderConfig `json:"qwen"` + Cerebras ProviderConfig `json:"cerebras"` + VolcEngine ProviderConfig `json:"volcengine"` + // ... every new provider requires changes here +} +``` + +**Problems**: +- Adding a new Provider requires modifying Go code (struct definition) +- `CreateProvider` function in `http_provider.go` has 200+ lines of switch-case +- Most Providers are OpenAI-compatible, but code is duplicated + +### 1.2 Code Bloat Trend + +Recent PRs demonstrate this issue: + +| PR | Provider | Code Changes | +|----|----------|--------------| +| #365 | Qwen | +17 lines to http_provider.go | +| #333 | Cerebras | +17 lines to http_provider.go | +| #368 | Volcengine | +18 lines to http_provider.go | + +Each OpenAI-compatible Provider requires: +1. Modify `config.go` to add configuration field +2. Modify `http_provider.go` to add switch case +3. Update documentation + +### 1.3 Agent-Provider Coupling + +```json +{ + "agents": { + "defaults": { + "provider": "deepseek", // need to know provider name + "model": "deepseek-chat" + } + } +} +``` + +Problem: Agent needs to know both `provider` and `model`, adding complexity. + +--- + +## 2. 
New Approach: model_list + +### 2.1 Core Principles + +Inspired by [LiteLLM](https://docs.litellm.ai/docs/proxy/configs) design: + +1. **Model-centric**: Users care about models, not providers +2. **Protocol prefix**: Use `protocol/model_name` format, e.g., `openai/gpt-5.2`, `anthropic/claude-sonnet-4.6` +3. **Configuration-driven**: Adding new Providers only requires config changes, no code changes + +### 2.2 New Configuration Structure + +```json +{ + "model_list": [ + { + "model_name": "deepseek-chat", + "model": "openai/deepseek-chat", + "api_base": "https://api.deepseek.com/v1", + "api_key": "sk-xxx" + }, + { + "model_name": "gpt-5.2", + "model": "openai/gpt-5.2", + "api_key": "sk-xxx" + }, + { + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "api_key": "sk-xxx" + }, + { + "model_name": "gemini-3-flash", + "model": "antigravity/gemini-3-flash", + "auth_method": "oauth" + }, + { + "model_name": "my-company-llm", + "model": "openai/company-model-v1", + "api_base": "https://llm.company.com/v1", + "api_key": "xxx" + } + ], + + "agents": { + "defaults": { + "model": "deepseek-chat", + "max_tokens": 8192, + "temperature": 0.7 + } + } +} +``` + +### 2.3 Go Struct Definition + +```go +type Config struct { + ModelList []ModelConfig `json:"model_list"` // new + Providers ProvidersConfig `json:"providers"` // old, deprecated + + Agents AgentsConfig `json:"agents"` + Channels ChannelsConfig `json:"channels"` + // ... 
+} + +type ModelConfig struct { + // Required + ModelName string `json:"model_name"` // user-facing name (alias) + Model string `json:"model"` // protocol/model, e.g., openai/gpt-5.2 + + // Common config + APIBase string `json:"api_base,omitempty"` + APIKey string `json:"api_key,omitempty"` + Proxy string `json:"proxy,omitempty"` + + // Special provider config + AuthMethod string `json:"auth_method,omitempty"` // oauth, token + ConnectMode string `json:"connect_mode,omitempty"` // stdio, grpc + + // Optional optimizations + RPM int `json:"rpm,omitempty"` // rate limit + MaxTokensField string `json:"max_tokens_field,omitempty"` // max_tokens or max_completion_tokens +} +``` + +### 2.4 Protocol Recognition + +Identify protocol via prefix in `model` field: + +| Prefix | Protocol | Description | +|--------|----------|-------------| +| `openai/` | OpenAI-compatible | Most common, includes DeepSeek, Qwen, Groq, etc. | +| `anthropic/` | Anthropic | Claude series specific | +| `antigravity/` | Antigravity | Google Cloud Code Assist | +| `gemini/` | Gemini | Google Gemini native API (if needed) | + +--- + +## 3. Design Rationale + +### 3.1 Problems Solved + +| Problem | Old Approach | New Approach | +|---------|--------------|--------------| +| Add OpenAI-compatible Provider | Change 3 code locations | Add one config entry | +| Agent specifies model | Need provider + model | Only need model | +| Code duplication | Each Provider duplicates logic | Share protocol implementation | +| Multi-Agent support | Complex | Naturally compatible | + +### 3.2 Multi-Agent Compatibility + +```json +{ + "model_list": [...], + + "agents": { + "defaults": { + "model": "deepseek-chat" + }, + "coder": { + "model": "gpt-5.2", + "system_prompt": "You are a coding assistant..." + }, + "translator": { + "model": "claude-sonnet-4.6" + } + } +} +``` + +Each Agent only needs to specify `model` (corresponds to `model_name` in `model_list`). 
+ +### 3.3 Industry Comparison + +**LiteLLM** (most mature open-source LLM Proxy) uses similar design: + +```yaml +model_list: + - model_name: gpt-4o + litellm_params: + model: openai/gpt-5.2 + api_key: xxx + - model_name: my-custom + litellm_params: + model: openai/custom-model + api_base: https://my-api.com/v1 +``` + +--- + +## 4. Migration Plan + +### 4.1 Phase 1: Compatibility Period (v1.x) + +Support both `providers` and `model_list`: + +```go +func (c *Config) GetModelConfig(modelName string) (*ModelConfig, error) { + // Prefer new config + if len(c.ModelList) > 0 { + return c.findModelByName(modelName) + } + + // Backward compatibility with old config + if !c.Providers.IsEmpty() { + logger.Warn("'providers' config is deprecated, please migrate to 'model_list'") + return c.convertFromProviders(modelName) + } + + return nil, fmt.Errorf("model %s not found", modelName) +} +``` + +### 4.2 Phase 2: Warning Period (late v1.x) + +- Print more prominent warnings at startup +- Provide automatic migration script +- Mark `providers` as deprecated in documentation + +### 4.3 Phase 3: Removal Period (v2.0) + +- Completely remove `providers` support +- Remove `agents.defaults.provider` field +- Only support `model_list` + +### 4.4 Configuration Migration Example + +**Old Config**: +```json +{ + "providers": { + "deepseek": { + "api_key": "sk-xxx", + "api_base": "https://api.deepseek.com/v1" + } + }, + "agents": { + "defaults": { + "provider": "deepseek", + "model": "deepseek-chat" + } + } +} +``` + +**New Config**: +```json +{ + "model_list": [ + { + "model_name": "deepseek-chat", + "model": "openai/deepseek-chat", + "api_base": "https://api.deepseek.com/v1", + "api_key": "sk-xxx" + } + ], + "agents": { + "defaults": { + "model": "deepseek-chat" + } + } +} +``` + +--- + +## 5. 
Implementation Checklist + +### 5.1 Configuration Layer + +- [ ] Add `ModelConfig` struct +- [ ] Add `Config.ModelList` field +- [ ] Implement `GetModelConfig(modelName)` method +- [ ] Implement old config compatibility conversion +- [ ] Add `model_name` uniqueness validation + +### 5.2 Provider Layer + +- [ ] Create `pkg/providers/factory/` directory +- [ ] Implement `CreateProviderFromModelConfig()` +- [ ] Refactor `http_provider.go` to `openai/provider.go` +- [ ] Maintain backward compatibility for old `CreateProvider()` + +### 5.3 Testing + +- [ ] New config unit tests +- [ ] Old config compatibility tests +- [ ] Integration tests + +### 5.4 Documentation + +- [ ] Update README +- [ ] Update config.example.json +- [ ] Write migration guide + +--- + +## 6. Risks and Mitigations + +| Risk | Mitigation | +|------|------------| +| Breaking existing configs | Compatibility period keeps old config working | +| User migration cost | Provide automatic migration script | +| Special Provider incompatibility | Keep `auth_method` and other extension fields | + +--- + +## 7. References + +- [LiteLLM Config Documentation](https://docs.litellm.ai/docs/proxy/configs) +- [One-API GitHub](https://github.com/songquanpeng/one-api) +- Discussion #122: Refactor Provider Architecture diff --git a/docs/migration/model-list-migration.md b/docs/migration/model-list-migration.md new file mode 100644 index 000000000..589dfc043 --- /dev/null +++ b/docs/migration/model-list-migration.md @@ -0,0 +1,219 @@ +# Migration Guide: From `providers` to `model_list` + +This guide explains how to migrate from the legacy `providers` configuration to the new `model_list` format. + +## Why Migrate? + +The new `model_list` configuration offers several advantages: + +- **Zero-code provider addition**: Add OpenAI-compatible providers with configuration only +- **Load balancing**: Configure multiple endpoints for the same model +- **Protocol-based routing**: Use prefixes like `openai/`, `anthropic/`, etc. 
+- **Cleaner configuration**: Model-centric instead of vendor-centric + +## Timeline + +| Version | Status | +|---------|--------| +| v1.x | `model_list` introduced, `providers` deprecated but functional | +| v1.x+1 | Prominent deprecation warnings, migration tool available | +| v2.0 | `providers` configuration removed | + +## Before and After + +### Before: Legacy `providers` Configuration + +```json +{ + "providers": { + "openai": { + "api_key": "sk-your-openai-key", + "api_base": "https://api.openai.com/v1" + }, + "anthropic": { + "api_key": "sk-ant-your-key" + }, + "deepseek": { + "api_key": "sk-your-deepseek-key" + } + }, + "agents": { + "defaults": { + "provider": "openai", + "model": "gpt-5.2" + } + } +} +``` + +### After: New `model_list` Configuration + +```json +{ + "model_list": [ + { + "model_name": "gpt4", + "model": "openai/gpt-5.2", + "api_key": "sk-your-openai-key", + "api_base": "https://api.openai.com/v1" + }, + { + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "api_key": "sk-ant-your-key" + }, + { + "model_name": "deepseek", + "model": "deepseek/deepseek-chat", + "api_key": "sk-your-deepseek-key" + } + ], + "agents": { + "defaults": { + "model": "gpt4" + } + } +} +``` + +## Protocol Prefixes + +The `model` field uses a protocol prefix format: `[protocol/]model-identifier` + +| Prefix | Description | Example | +|--------|-------------|---------| +| `openai/` | OpenAI API (default) | `openai/gpt-5.2` | +| `anthropic/` | Anthropic API | `anthropic/claude-opus-4` | +| `antigravity/` | Google via Antigravity OAuth | `antigravity/gemini-2.0-flash` | +| `gemini/` | Google Gemini API | `gemini/gemini-2.0-flash-exp` | +| `claude-cli/` | Claude CLI (local) | `claude-cli/claude-sonnet-4.6` | +| `codex-cli/` | Codex CLI (local) | `codex-cli/codex-4` | +| `github-copilot/` | GitHub Copilot | `github-copilot/gpt-4o` | +| `openrouter/` | OpenRouter | `openrouter/anthropic/claude-sonnet-4.6` | +| `groq/` | Groq API | 
`groq/llama-3.1-70b` | +| `deepseek/` | DeepSeek API | `deepseek/deepseek-chat` | +| `cerebras/` | Cerebras API | `cerebras/llama-3.3-70b` | +| `qwen/` | Alibaba Qwen | `qwen/qwen-max` | +| `zhipu/` | Zhipu AI | `zhipu/glm-4` | +| `nvidia/` | NVIDIA NIM | `nvidia/llama-3.1-nemotron-70b` | +| `ollama/` | Ollama (local) | `ollama/llama3` | +| `vllm/` | vLLM (local) | `vllm/my-model` | +| `moonshot/` | Moonshot AI | `moonshot/moonshot-v1-8k` | +| `shengsuanyun/` | ShengSuanYun | `shengsuanyun/deepseek-v3` | +| `volcengine/` | Volcengine | `volcengine/doubao-pro-32k` | + +**Note**: If no prefix is specified, `openai/` is used as the default. + +## ModelConfig Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `model_name` | Yes | User-facing alias for the model | +| `model` | Yes | Protocol and model identifier (e.g., `openai/gpt-5.2`) | +| `api_base` | No | API endpoint URL | +| `api_key` | No* | API authentication key | +| `proxy` | No | HTTP proxy URL | +| `auth_method` | No | Authentication method: `oauth`, `token` | +| `connect_mode` | No | Connection mode for CLI providers: `stdio`, `grpc` | +| `rpm` | No | Requests per minute limit | +| `max_tokens_field` | No | Field name for max tokens | + +*`api_key` is required for HTTP-based protocols unless `api_base` points to a local server. + +## Load Balancing + +Configure multiple endpoints for the same model to distribute load: + +```json +{ + "model_list": [ + { + "model_name": "gpt4", + "model": "openai/gpt-5.2", + "api_key": "sk-key1", + "api_base": "https://api1.example.com/v1" + }, + { + "model_name": "gpt4", + "model": "openai/gpt-5.2", + "api_key": "sk-key2", + "api_base": "https://api2.example.com/v1" + }, + { + "model_name": "gpt4", + "model": "openai/gpt-5.2", + "api_key": "sk-key3", + "api_base": "https://api3.example.com/v1" + } + ] +} +``` + +When you request model `gpt4`, requests will be distributed across all three endpoints using round-robin selection. 
+ +## Adding a New OpenAI-Compatible Provider + +With `model_list`, adding a new provider requires zero code changes: + +```json +{ + "model_list": [ + { + "model_name": "my-custom-llm", + "model": "openai/my-model-v1", + "api_key": "your-api-key", + "api_base": "https://api.your-provider.com/v1" + } + ] +} +``` + +Just specify `openai/` as the protocol (or omit it for the default), and provide your provider's API base URL. + +## Backward Compatibility + +During the migration period, your existing `providers` configuration will continue to work: + +1. If `model_list` is empty and `providers` has data, the system auto-converts internally +2. A deprecation warning is logged: `"providers config is deprecated, please migrate to model_list"` +3. All existing functionality remains unchanged + +## Migration Checklist + +- [ ] Identify all providers you're currently using +- [ ] Create `model_list` entries for each provider +- [ ] Use appropriate protocol prefixes +- [ ] Update `agents.defaults.model` to reference the new `model_name` +- [ ] Test that all models work correctly +- [ ] Remove or comment out the old `providers` section + +## Troubleshooting + +### Model not found error + +``` +model "xxx" not found in model_list or providers +``` + +**Solution**: Ensure the `model_name` in `model_list` matches the value in `agents.defaults.model`. + +### Unknown protocol error + +``` +unknown protocol "xxx" in model "xxx/model-name" +``` + +**Solution**: Use a supported protocol prefix. See the [Protocol Prefixes](#protocol-prefixes) table above. + +### Missing API key error + +``` +api_key or api_base is required for HTTP-based protocol "xxx" +``` + +**Solution**: Provide `api_key` and/or `api_base` for HTTP-based providers. + +## Need Help? 
+ +- [GitHub Issues](https://github.com/sipeed/picoclaw/issues) +- [Discussion #122](https://github.com/sipeed/picoclaw/discussions/122): Original proposal diff --git a/docs/picoclaw_community_roadmap_260216.md b/docs/picoclaw_community_roadmap_260216.md new file mode 100644 index 000000000..95de768c6 --- /dev/null +++ b/docs/picoclaw_community_roadmap_260216.md @@ -0,0 +1,112 @@ +## 🚀 Join the PicoClaw Journey: Call for Community Volunteers & Roadmap Reveal + +**Hello, PicoClaw Community!** + +First, a massive thank you to everyone for your enthusiasm and PR contributions. It is because of you that PicoClaw continues to iterate and evolve so rapidly. Thanks to the simplicity and accessibility of the **Go language**, we’ve seen a non-stop stream of high-quality PRs! + +PicoClaw is growing much faster than we anticipated. As we are currently in the midst of the **Chinese New Year holiday**, we are looking to recruit community volunteers to help us maintain this incredible momentum. + +This document outlines the specific volunteer roles we need right now and provides a look at our upcoming **Roadmap**. + +### 🎁 Community Perks + +To show our appreciation, developers who officially join our community operations will receive: + +* **Exclusive AI Hardware:** Our upcoming, unreleased AI device. +* **Token Discounts:** Potential discounts on LLM tokens (currently in negotiations with major providers). + +### 🎥 Calling All Content Creators! + +Not a developer? You can still help! We welcome users to post **PicoClaw reviews or tutorials**. + +* **Twitter:** Use the tag **#picoclaw** and mention **@SipeedIO**. +* **Bilibili:** Mention **@Sipeed矽速科技** or send us a DM. +We will be rewarding high-quality content creators with the same perks as our community developers! + +--- + +## 🛠️ Urgent Volunteer Roles + +We are looking for experts in the following areas: + +1. 
**Issue/PR Reviewers** +* **The Mission:** With PRs and Issues exploding in volume, we need help with initial triage, evaluation, and merging. +* **Focus:** Preliminary merging and community health. Efficiency optimization and security audits will be handled by specialized roles. + + +2. **Resource Optimization Experts** +* **The Mission:** Rapid growth has introduced dependencies that are making PicoClaw a bit "heavy." We want to keep it lean. +* **Focus:** Analyzing resource growth between releases and trimming redundancy. +* **Priority:** **RAM usage optimization** > Binary size reduction. + + +3. **Security Audit & Bug Fixes** +* **The Mission:** Due to the "vibe coding" nature of our early stages, we need a thorough review of network security and AI permission management. +* **Focus:** Auditing the codebase for vulnerabilities and implementing robust fixes. + + +4. **Documentation & DX (Developer Experience)** +* **The Mission:** Our current README is a bit outdated. We need "step-by-step" guides that even beginners can follow. +* **Focus:** Creating clear, user-friendly documentation for both setup and development. + + +5. **AI-Powered CI/CD Optimization** +* **The Mission:** PicoClaw started as a "vibe coding" experiment; now we want to use AI to manage it. +* **Focus:** Automating builds with AI and exploring AI-driven issue resolution. + +**How to Apply:** > If you are interested in any of the roles above, please send an email to support@sipeed.com with the subject line: [Apply: PicoClaw Expert Volunteer] + Your Desired Role. +Please include a brief introduction and any relevant experience or portfolio links. We will review all applications and grant project permissions to selected contributors! + +--- + +## 📍 The Roadmap + +Interested in a specific feature? 
You can "claim" these tasks and start building: + +### +* **Provider:** + * **Provider Refactor:** Currently being handled by **@Daming** (ETA: 5 days) + * You can still submit code; Daming will merge it into the new implementation. +* **Channels:** + * Support for OneBot, additional platforms + * attachments (images, audio, video, files). +* **Skills:** + * Implementing `find_skill` to discover tools via [ClawHub](https://clawhub.ai) and other platforms. +* **Operations:** * MCP Support. + * Android operations (e.g., botdrop). + * Browser automation via CDP or ActionBook. + + +* **Multi-Agent Ecosystem:** + * **Basic Model-Agent** + * **Model Routing:** Small models for easy tasks, large models for hard ones (to save tokens). + * **Swarm Mode.** + * **AIEOS Integration.** + + +* **Branding:** + * **Logo**: We need a cute logo! We’re leaning toward a **Mantis Shrimp**—small, but packs a legendary punch! + + +We have officially created these tasks as GitHub Issues, all marked with the roadmap tag. +This list will be updated continuously as we progress. +If you would like to claim a task, please feel free to start a conversation by commenting directly on the corresponding issue! + +--- + +## 🤝 How to Join + +**Everything is open to your creativity!** If you have a wild idea, just PR it. + +1. **The Fast Track:** Once you have at least **one merged PR**, you are eligible to join our **Developer Discord** to help plan the future of PicoClaw. +2. **The Application Track:** If you haven’t submitted a PR yet but want to dive in, email **support@sipeed.com** with the subject: +> `[Apply Join PicoClaw Dev Group] + Your GitHub Account` +> Include the role you're interested in and any evidence of your development experience. + + + +### Looking Ahead + +Powered by PicoClaw, we are crafting a Swarm AI Assistant to transform your environment into a seamless network of personal stewards. 
By automating the friction of daily life, we empower you to transcend the ordinary and freely explore your creative potential. + +**Finally, Happy Chinese New Year to everyone!** May PicoClaw gallop forward in this **Year of the Horse!** 🐎 diff --git a/docs/swarm/API.md b/docs/swarm/API.md new file mode 100644 index 000000000..62533655e --- /dev/null +++ b/docs/swarm/API.md @@ -0,0 +1,516 @@ +# PicoClaw Swarm API Reference + +> PicoClaw Swarm Mode | API Reference +> Last Updated: 2026-02-20 + +--- + +## Go API + +### Manager + +The `Manager` is the main entry point for swarm functionality. + +```go +package swarm + +// NewManager creates a new swarm manager +func NewManager( + cfg *config.Config, + agentLoop *agent.AgentLoop, + provider providers.LLMProvider, + localBus *bus.MessageBus, +) *Manager + +// Start initializes and starts all swarm components +func (m *Manager) Start(ctx context.Context) error + +// Stop gracefully stops all swarm components +func (m *Manager) Stop() + +// GetNodeInfo returns this node's information +func (m *Manager) GetNodeInfo() *NodeInfo + +// GetDiscoveredNodes returns all discovered nodes +func (m *Manager) GetDiscoveredNodes() []*NodeInfo + +// IsNATSConnected returns true if connected to NATS +func (m *Manager) IsNATSConnected() bool + +// IsTemporalConnected returns true if connected to Temporal +func (m *Manager) IsTemporalConnected() bool +``` + +### Dashboard + +```go +// NewDashboard creates a new dashboard +func NewDashboard(manager *Manager) *Dashboard + +// Start begins the dashboard update loop +func (d *Dashboard) Start(ctx context.Context) error + +// Stop stops the dashboard +func (d *Dashboard) Stop() + +// Render returns a formatted string representation +func (d *Dashboard) Render() string + +// RenderCompact returns a one-line status +func (d *Dashboard) RenderCompact() string + +// GetState returns the current dashboard state +func (d *Dashboard) GetState() *DashboardState +``` + +### Discovery + +```go +// 
NewDiscovery creates a new discovery service +func NewDiscovery( + bridge *NATSBridge, + nodeInfo *NodeInfo, + cfg *config.SwarmConfig, +) *Discovery + +// Start begins the discovery service +func (d *Discovery) Start(ctx context.Context) error + +// Stop stops the discovery service +func (d *Discovery) Stop() + +// GetNodes returns all known nodes, optionally filtered +func (d *Discovery) GetNodes(role, capability string) []*NodeInfo + +// SelectWorker selects the best worker for a task +func (d *Discovery) SelectWorker(task *SwarmTask) (*NodeInfo, error) + +// SelectWorkerWithPriority selects worker considering priority +func (d *Discovery) SelectWorkerWithPriority( + task *SwarmTask, + priority int, +) (*NodeInfo, error) +``` + +### Worker + +```go +// NewWorker creates a new worker +func NewWorker( + cfg *config.SwarmConfig, + bridge *NATSBridge, + temporal *TemporalClient, + agentLoop *agent.AgentLoop, + provider providers.LLMProvider, + nodeInfo *NodeInfo, +) *Worker + +// Start begins accepting tasks +func (w *Worker) Start(ctx context.Context) error + +// Stop stops accepting new tasks +func (w *Worker) Stop() + +// ExecuteTask runs a single task +func (w *Worker) ExecuteTask(ctx context.Context, task *SwarmTask) error + +// GetLoad returns current load (0-1) +func (w *Worker) GetLoad() float64 +``` + +### Coordinator + +```go +// NewCoordinator creates a new coordinator +func NewCoordinator( + cfg *config.SwarmConfig, + bridge *NATSBridge, + temporal *TemporalClient, + discovery *Discovery, + agentLoop *agent.AgentLoop, + provider providers.LLMProvider, + localBus *bus.MessageBus, +) *Coordinator + +// Start begins accepting requests +func (c *Coordinator) Start(ctx context.Context) error + +// Stop stops the coordinator +func (c *Coordinator) Stop() + +// SubmitTask submits a task for distributed execution +func (c *Coordinator) SubmitTask( + ctx context.Context, + task *SwarmTask, +) (*TaskResult, error) +``` + +### Types + +```go +// NodeInfo represents 
a node in the swarm +type NodeInfo struct { + ID string + Role NodeRole // "coordinator", "worker", "specialist" + Status NodeStatus // "online", "busy", "offline", "suspicious" + Capabilities []string + Model string + Load float64 + TasksRunning int + MaxTasks int + LastSeen int64 + StartedAt int64 + Metadata map[string]string +} + +// SwarmTask represents a task to be executed +type SwarmTask struct { + ID string + Type string + Prompt string + Priority int // 0-3 + Capabilities []string + ParentID string + Dependencies []string + Context map[string]interface{} +} + +// TaskResult represents the result of a task +type TaskResult struct { + TaskID string + NodeID string + Success bool + Content string + Error string + Duration time.Duration + Metadata map[string]interface{} +} +``` + +--- + +## NATS Subjects + +### Heartbeat + +``` +picoclaw.swarm.heartbeat.{node_id} +``` + +### Discovery + +``` +picoclaw.swarm.discovery.announce +picoclaw.swarm.discovery.query +``` + +### Task Assignment + +``` +picoclaw.swarm.task.assign.{node_id} +picoclaw.swarm.task.broadcast.{capability} +``` + +### Task Results + +``` +picoclaw.swarm.task.result.{task_id} +picoclaw.swarm.task.progress.{task_id} +``` + +### System + +``` +picoclaw.swarm.system.shutdown.{node_id} +``` + +### Cross-H-id Communication + +``` +picoclaw.x.{from_hid}.{to_hid} +``` + +--- + +## Configuration API + +### Config Structure + +```go +type SwarmConfig struct { + Enabled bool + Role string + Capabilities []string + MaxConcurrent int + HID string + SID string + NATS NATSConfig + Temporal TemporalConfig +} + +type NATSConfig struct { + URLs []string + Embedded bool + EmbeddedPort int + Credentials string + HeartbeatInterval string + NodeTimeout string +} + +type TemporalConfig struct { + Address string + Namespace string + TaskQueue string +} +``` + +### Environment Variables + +```go +// LoadFromEnv loads configuration from environment variables +func (c *SwarmConfig) LoadFromEnv() error + +// Validate 
validates the configuration +func (c *SwarmConfig) Validate() error +``` + +--- + +## CLI Flags + +```bash +# Swarm enablement +--swarm.enabled bool Enable swarm mode +--swarm.role string Node role (coordinator/worker/specialist) + +# Identity +--swarm.hid string Human identity +--swarm.sid string Session identity + +# NATS +--swarm.nats.urls strings NATS server URLs +--swarm.nats.embedded bool Use embedded NATS +--swarm.nats.embedded-port int Embedded NATS port +--swarm.nats.credentials string NATS credentials file + +# Capabilities +--swarm.capabilities strings Node capabilities +--swarm.max-concurrent int Max concurrent tasks + +# Temporal +--swarm.temporal.address string Temporal server address +--swarm.temporal.task-queue string Temporal task queue + +# Monitoring +--swarm.dashboard bool Enable dashboard output +--swarm.heartbeat-interval string Heartbeat interval +--swarm.node-timeout string Node timeout +``` + +--- + +## Example: Programmatic Usage + +### Basic Swarm Setup + +```go +package main + +import ( + "context" + "github.com/sipeed/picoclaw/pkg/swarm" + "github.com/sipeed/picoclaw/pkg/config" +) + +func main() { + ctx := context.Background() + + // Load config + cfg := config.Load() + + // Create manager + manager := swarm.NewManager(cfg, agentLoop, provider, bus) + + // Start swarm + if err := manager.Start(ctx); err != nil { + log.Fatal(err) + } + defer manager.Stop() + + // Create dashboard + dashboard := swarm.NewDashboard(manager) + dashboard.Start(ctx) + defer dashboard.Stop() + + // Run... 
+ select {} +} +``` + +### Submit a Task + +```go +// Assuming coordinator is set up +task := &swarm.SwarmTask{ + ID: "task-001", + Type: "code_review", + Prompt: "Review this code for security issues...", + Priority: 2, // High priority + Capabilities: []string{"security", "audit"}, +} + +result, err := coordinator.SubmitTask(ctx, task) +if err != nil { + log.Fatal(err) +} + +fmt.Printf("Result: %s\n", result.Content) +``` + +### Monitor Swarm Status + +```go +// Get all nodes +nodes := manager.GetDiscoveredNodes() +for _, node := range nodes { + fmt.Printf("%s: %s (%.0f%% load)\n", + node.ID, node.Status, node.Load*100) +} + +// Or use dashboard +fmt.Println(dashboard.Render()) +``` + +--- + +## Events and Callbacks + +### Node Events + +```go +// Register callbacks for node join/leave +discovery.OnNodeJoin(func(node *NodeInfo) { + log.Printf("Node joined: %s", node.ID) +}) + +discovery.OnNodeLeave(func(nodeID string) { + log.Printf("Node left: %s", nodeID) +}) +``` + +### Task Events + +```go +// Register callbacks for task events +worker.OnTaskAssigned(func(task *SwarmTask) { + log.Printf("Task assigned: %s", task.ID) +}) + +worker.OnTaskComplete(func(result *TaskResult) { + log.Printf("Task complete: %s -> %v", + result.TaskID, result.Success) +}) +``` + +### Election Events + +```go +// Register callbacks for leader changes +electionMgr.OnBecameLeader(func() { + log.Printf("Became leader!") + // Promote to coordinator role +}) + +electionMgr.OnLostLeadership(func() { + log.Printf("Lost leadership!") + // Demote to worker role +}) +``` + +--- + +## Error Handling + +### Common Errors + +```go +// NATS connection failed +if err := manager.Start(ctx); err != nil { + if strings.Contains(err.Error(), "NATS") { + // Check NATS server availability + log.Fatal("Cannot connect to NATS") + } +} + +// Task timeout +if result, err := coordinator.SubmitTask(ctx, task); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + // Handle timeout + } +} + +// No workers 
available +if _, err := discovery.SelectWorker(task); err != nil { + if err == ErrNoWorkersAvailable { + // No matching workers for this task + } +} +``` + +--- + +## Testing API + +### Test Helpers + +```go +// Create test NATS environment +tn := swarm.StartTestNATS(t) +defer tn.Stop() + +// Create test bridge +bridge := swarm.ConnectTestBridge(t, tn.URL, nodeInfo) +defer bridge.Stop() + +// Create test node info +nodeInfo := swarm.CreateTestNodeInfo( + "test-node", + "worker", + []string{"test"}, +) +``` + +--- + +## Performance Considerations + +### Connection Pooling + +NATS connections are automatically pooled and reused. The `NATSBridge` handles connection management. + +### Batch Operations + +For multiple task submissions, consider using the workflow API: + +```go +// Submit multiple tasks as a workflow +workflow := &swarm.ParallelWorkflow{ + Tasks: []*SwarmTask{task1, task2, task3}, +} + +result, err := coordinator.ExecuteWorkflow(ctx, workflow) +``` + +### Memory Management + +- Nodes maintain a cache of discovered peers +- Old entries are pruned based on heartbeat timeout +- Adjust `--swarm.node-timeout` for your environment + +--- + +## See Also + +- [DEPLOYMENT.md](./DEPLOYMENT.md) - Deployment guide +- [CONFIG.md](./CONFIG.md) - Configuration reference +- [EXAMPLES.md](./EXAMPLES.md) - Code examples diff --git a/docs/swarm/DEPLOYMENT.md b/docs/swarm/DEPLOYMENT.md new file mode 100644 index 000000000..d128928f4 --- /dev/null +++ b/docs/swarm/DEPLOYMENT.md @@ -0,0 +1,485 @@ +# PicoClaw Swarm Deployment Guide + +> PicoClaw Swarm Mode | Deployment Guide +> Last Updated: 2026-02-20 + +--- + +## Overview + +PicoClaw Swarm enables multiple AI agent instances to work together collaboratively. This guide covers deployment scenarios from local development to production distributed clusters. 
+ +--- + +## Quick Start + +### Single Node (Local Development) + +```bash +# Start embedded NATS, run as coordinator +picoclaw --swarm.enabled --swarm.role coordinator --swarm.nats.embedded + +# In another terminal, start a worker +picoclaw --swarm.enabled --swarm.role worker --swarm.nats.embedded +``` + +### Multi-Node (Docker Compose) + +```yaml +# docker-compose.yml +version: '3.8' +services: + nats: + image: nats:latest + command: "-js" + ports: + - "4222:4222" + + coordinator: + image: picoclaw:latest + environment: + - PICOCLAW_SWARM_ENABLED=true + - PICOCLAW_SWARM_ROLE=coordinator + - PICOCLAW_SWARM_NATS_URLS=nats://nats:4222 + + worker: + image: picoclaw:latest + environment: + - PICOCLAW_SWARM_ENABLED=true + - PICOCLAW_SWARM_ROLE=worker + - PICOCLAW_SWARM_CAPABILITIES=code,research + - PICOCLAW_SWARM_NATS_URLS=nats://nats:4222 + deploy: + replicas: 3 +``` + +```bash +docker-compose up -d +``` + +--- + +## Architecture + +``` +┌─────────────────────────────────────────────────────────┐ +│ PicoClaw Swarm │ +├─────────────────────────────────────────────────────────┤ +│ │ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ Coordinator │◄───────►│ Worker │ │ +│ │ │ │ │ │ +│ │ • Task Decomp│ │ • Execute │ │ +│ │ • Scheduling │ │ • Report │ │ +│ │ • Synthesis │ │ │ │ +│ └───────┬──────┘ └──────▲───────┘ │ +│ │ │ │ +│ └────────────────────────┘ │ +│ │ │ +│ ▼─────▼ │ +│ ┌──────────┐ │ +│ │ NATS │ │ +│ │ JetStream│ │ +│ └──────────┘ │ +│ │ +└─────────────────────────────────────────────────────────┘ +``` + +--- + +## Configuration + +### Environment Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `PICOCLAW_SWARM_ENABLED` | Enable swarm mode | `false` | +| `PICOCLAW_SWARM_ROLE` | Node role | `worker` | +| `PICOCLAW_SWARM_NATS_URLS` | NATS servers | `nats://localhost:4222` | +| `PICOCLAW_SWARM_CAPABILITIES` | Node capabilities | `general` | +| `PICOCLAW_SWARM_MAX_CONCURRENT` | Max parallel tasks | `5` | +| `PICOCLAW_SWARM_HID` | 
Human identity | auto-generated | +| `PICOCLAW_SWARM_SID` | Session identity | auto-generated | + +### Config File + +```yaml +# ~/.picoclaw/config.yaml +swarm: + enabled: true + role: coordinator + capabilities: ["coordination", "scheduling"] + max_concurrent: 10 + + nats: + urls: + - nats://localhost:4222 + embedded: false + heartbeat_interval: 10s + node_timeout: 60s + + temporal: + address: localhost:7233 + task_queue: picoclaw-swarm +``` + +--- + +## Deployment Scenarios + +### 1. Local Development + +**Embedded NATS** (no external dependencies): + +```bash +# Terminal 1: Coordinator +picoclaw --swarm.enabled \ + --swarm.role coordinator \ + --swarm.nats.embedded + +# Terminal 2-3: Workers +picoclaw --swarm.enabled \ + --swarm.role worker \ + --swarm.capabilities code \ + --swarm.nats.embedded +``` + +### 2. Single Machine Swarm + +**Shared NATS** on the same machine: + +```bash +# Start NATS server +nats-server -js -p 4222 + +# Start nodes +picoclaw --swarm.enabled --swarm.role coordinator \ + --swarm.nats.urls nats://localhost:4222 + +picoclaw --swarm.enabled --swarm.role worker \ + --swarm.capabilities code,research \ + --swarm.nats.urls nats://localhost:4222 +``` + +### 3. Multi-Machine Swarm + +**Distributed NATS cluster**: + +```bash +# Machine 1 (NATS + Coordinator) +nats-server -cluster -js -p 4222 \ + -routes nats://machine2:6222,nats://machine3:6222 + +picoclaw --swarm.enabled --swarm.role coordinator \ + --swarm.nats.urls nats://localhost:4222 + +# Machine 2-3 (Workers) +picoclaw --swarm.enabled --swarm.role worker \ + --swarm.nats.urls nats://machine1:4222 +``` + +### 4. 
Kubernetes Deployment + +```yaml +# k8s/coordinator.yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: picoclaw-coordinator +spec: + replicas: 1 + selector: + matchLabels: + app: picoclaw-coordinator + template: + metadata: + labels: + app: picoclaw-coordinator + spec: + containers: + - name: picoclaw + image: picoclaw:latest + env: + - name: PICOCLAW_SWARM_ENABLED + value: "true" + - name: PICOCLAW_SWARM_ROLE + value: "coordinator" + - name: PICOCLAW_SWARM_NATS_URLS + value: "nats://nats:4222" +--- +# k8s/worker.yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: picoclaw-worker +spec: + replicas: 5 + selector: + matchLabels: + app: picoclaw-worker + template: + metadata: + labels: + app: picoclaw-worker + spec: + containers: + - name: picoclaw + image: picoclaw:latest + env: + - name: PICOCLAW_SWARM_ENABLED + value: "true" + - name: PICOCLAW_SWARM_ROLE + value: "worker" + - name: PICOCLAW_SWARM_NATS_URLS + value: "nats://nats:4222" + - name: PICOCLAW_SWARM_CAPABILITIES + value: "code,research,writing" +``` + +--- + +## Node Roles + +### Coordinator + +**Responsibilities:** +- Receive user requests +- Decompose tasks into sub-tasks +- Route tasks to appropriate workers +- Synthesize results from multiple workers +- Manage workflow state + +**Configuration:** +```bash +picoclaw --swarm.role coordinator \ + --swarm.capabilities coordination,scheduling +``` + +### Worker + +**Responsibilities:** +- Execute assigned tasks +- Report progress and results +- Advertise capabilities +- Handle task queues + +**Configuration:** +```bash +picoclaw --swarm.role worker \ + --swarm.capabilities code,research \ + --swarm.max_concurrent 5 +``` + +### Specialist + +**Responsibilities:** +- Handle domain-specific tasks +- Deep expertise in specific areas +- Targeted capability routing + +**Configuration:** +```bash +picoclaw --swarm.role specialist \ + --swarm.capabilities rust,embedded +``` + +--- + +## Monitoring + +### Dashboard View + +```bash +picoclaw 
--swarm.dashboard +``` + +Output: +``` +╔════════════════════════════════════════════════════════════╗ +║ PicoClaw Swarm Status Dashboard ║ +╚════════════════════════════════════════════════════════════╝ + +【This Node】 + ID: claw-a1b2c3d4 + Role: ⚙️ worker + Status: ● online + Load: [████░░░░░] 40% (2/5) + Uptime: 2h15m + +【Connections】 + NATS: ✓ Yes nats://localhost:4222 + Temporal: ✓ Yes + +【Swarm Statistics】 + Nodes: 5 total, 5 online, 0 offline + Roles: 1 coordinator(s), 4 worker(s), 0 specialist(s) + Capacity: 8/25 tasks used + +【Discovered Nodes】 + ● claw-coord-01 C online [██░░░░░░░] 20% (1/5) + ● claw-a1b2c3d4 W online [████░░░░░] 40% (2/5) + ● claw-e5f6g7h8 W online [████░░░░░] 40% (2/5) + ● claw-i9j0k1l2 W busy [█████░░░░] 60% (3/5) + ● claw-m3n4o5p6 W online [██░░░░░░░] 20% (1/5) +``` + +### Logs + +```bash +# View swarm logs +picoclaw --swarm.enabled --log.level debug + +# Key log patterns: +# - "Node joined swarm" - New node discovered +# - "Task assigned" - Task routed to worker +# - "Task completed" - Worker finished task +# - "Node marked offline" - Node failure detected +``` + +--- + +## Troubleshooting + +### Nodes Not Discovering Each Other + +**Problem:** Workers don't appear in coordinator's node list. + +**Solutions:** +1. Check NATS connectivity: `picoclaw --swarm.check-nats` +2. Verify H-id matches (all nodes in same swarm need same H-id) +3. Check firewall rules (NATS port 4222 must be open) + +### Tasks Not Being Assigned + +**Problem:** Coordinator creates tasks but workers don't receive them. + +**Solutions:** +1. Verify worker capabilities match task requirements +2. Check worker load (`--swarm.max_concurrent` limit) +3. Ensure NATS JetStream is enabled (`-js` flag) + +### High Memory Usage + +**Problem:** Nodes consuming excessive memory. + +**Solutions:** +1. Reduce `--swarm.max_concurrent` +2. Decrease NATS subscription buffer size +3. 
Enable Temporal for long-running workflows (offloads state) + +### Temporal Connection Failed + +**Problem:** "Temporal connection failed (workflows disabled)" warning. + +**Impact:** Tasks still work, but without workflow persistence. + +**Solutions:** +1. Start Temporal server: `temporal server start-dev` +2. Or disable Temporal requirement: `--swarm.temporal.address=""` + +--- + +## Security Considerations + +### For Development + +- Embedded NATS is fine for local testing +- No authentication needed +- Use `--swarm.nats.embedded` + +### For Production + +1. **Enable NATS Authentication** + ```bash + nats-server -js --auth picoclaw_secret + picoclaw --swarm.nats.credentials /path/to/creds + ``` + +2. **Enable TLS** + ```bash + nats-server -js -tls + picoclaw --swarm.nats.urls tls://localhost:4222 + ``` + +3. **Isolate Swarms by H-id** + - Different production environments = different H-ids + - Prevents cross-environment communication + +4. **Network Segmentation** + - Keep NATS ports internal + - Use VPN for multi-cloud deployments + +--- + +## Performance Tuning + +### Small Swarm (2-5 nodes) + +```yaml +swarm: + max_concurrent: 5 + nats: + heartbeat_interval: 10s + node_timeout: 30s +``` + +### Medium Swarm (5-20 nodes) + +```yaml +swarm: + max_concurrent: 10 + nats: + heartbeat_interval: 5s + node_timeout: 20s +``` + +### Large Swarm (20+ nodes) + +```yaml +swarm: + max_concurrent: 20 + nats: + heartbeat_interval: 3s + node_timeout: 15s + temporal: + enabled: true # Required for workflow persistence +``` + +--- + +## Example: Distributed Code Review Swarm + +```bash +# Terminal 1: Coordinator +picoclaw --swarm.role coordinator \ + --swarm.capabilities coordination + +# Terminal 2: Code specialist +picoclaw --swarm.role specialist \ + --swarm.capabilities rust,go,python + +# Terminal 3: Security specialist +picoclaw --swarm.role specialist \ + --swarm.capabilities security,audit + +# Terminal 4: Documentation specialist +picoclaw --swarm.role specialist \ + 
--swarm.capabilities docs,writing + +# Terminal 5: Test specialist +picoclaw --swarm.role specialist \ + --swarm.capabilities testing,qa +``` + +When you send a code review request: +1. Coordinator decomposes into: code, security, docs, testing reviews +2. Each specialist handles their area in parallel +3. Coordinator synthesizes a comprehensive review + +--- + +## Next Steps + +- See [API.md](./API.md) for programmatic usage +- See [CONFIG.md](./CONFIG.md) for all configuration options +- See [EXAMPLES.md](./EXAMPLES.md) for more deployment examples diff --git a/docs/tools_configuration.md b/docs/tools_configuration.md new file mode 100644 index 000000000..8aba1aa91 --- /dev/null +++ b/docs/tools_configuration.md @@ -0,0 +1,143 @@ +# Tools Configuration + +PicoClaw's tools configuration is located in the `tools` field of `config.json`. + +## Directory Structure + +```json +{ + "tools": { + "web": { ... }, + "exec": { ... }, + "cron": { ... }, + "skills": { ... } + } +} +``` + +## Web Tools + +Web tools are used for web search and fetching. + +### Brave + +| Config | Type | Default | Description | +|--------|------|---------|-------------| +| `enabled` | bool | false | Enable Brave search | +| `api_key` | string | - | Brave Search API key | +| `max_results` | int | 5 | Maximum number of results | + +### DuckDuckGo + +| Config | Type | Default | Description | +|--------|------|---------|-------------| +| `enabled` | bool | true | Enable DuckDuckGo search | +| `max_results` | int | 5 | Maximum number of results | + +### Perplexity + +| Config | Type | Default | Description | +|--------|------|---------|-------------| +| `enabled` | bool | false | Enable Perplexity search | +| `api_key` | string | - | Perplexity API key | +| `max_results` | int | 5 | Maximum number of results | + +## Exec Tool + +The exec tool is used to execute shell commands. 
+ +| Config | Type | Default | Description | +|--------|------|---------|-------------| +| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking | +| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) | + +### Functionality + +- **`enable_deny_patterns`**: Set to `false` to completely disable the default dangerous command blocking patterns +- **`custom_deny_patterns`**: Add custom deny regex patterns; commands matching these will be blocked + +### Default Blocked Command Patterns + +By default, PicoClaw blocks the following dangerous commands: + +- Delete commands: `rm -rf`, `del /f/q`, `rmdir /s` +- Disk operations: `format`, `mkfs`, `diskpart`, `dd if=`, writing to `/dev/sd*` +- System operations: `shutdown`, `reboot`, `poweroff` +- Command substitution: `$()`, `${}`, backticks +- Pipe to shell: `| sh`, `| bash` +- Privilege escalation: `sudo`, `chmod`, `chown` +- Process control: `pkill`, `killall`, `kill -9` +- Remote operations: `curl | sh`, `wget | sh`, `ssh` +- Package management: `apt`, `yum`, `dnf`, `npm install -g`, `pip install --user` +- Containers: `docker run`, `docker exec` +- Git: `git push`, `git force` +- Other: `eval`, `source *.sh` + +### Configuration Example + +```json +{ + "tools": { + "exec": { + "enable_deny_patterns": true, + "custom_deny_patterns": [ + "\\brm\\s+-r\\b", + "\\bkillall\\s+python" + ] + } + } +} +``` + +## Cron Tool + +The cron tool is used for scheduling periodic tasks. + +| Config | Type | Default | Description | +|--------|------|---------|-------------| +| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit | + +## Skills Tool + +The skills tool configures skill discovery and installation via registries like ClawHub. 
+ +### Registries + +| Config | Type | Default | Description | +|--------|------|---------|-------------| +| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry | +| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL | +| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path | +| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path | +| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path | + +### Configuration Example + +```json +{ + "tools": { + "skills": { + "registries": { + "clawhub": { + "enabled": true, + "base_url": "https://clawhub.ai", + "search_path": "/api/v1/search", + "skills_path": "/api/v1/skills", + "download_path": "/api/v1/download" + } + } + } + } +} +``` + +## Environment Variables + +All configuration options can be overridden via environment variables with the format `PICOCLAW_TOOLS_
<SECTION>_<KEY>`: + +For example: +- `PICOCLAW_TOOLS_WEB_BRAVE_ENABLED=true` +- `PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS=false` +- `PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES=10` + +Note: Array-type environment variables are not currently supported and must be set via the config file. diff --git a/docs/wecom-app-configuration.md b/docs/wecom-app-configuration.md new file mode 100644 index 000000000..3b17d37a7 --- /dev/null +++ b/docs/wecom-app-configuration.md @@ -0,0 +1,117 @@ +# 企业微信自建应用 (WeCom App) 配置指南 + +本文档介绍如何在 PicoClaw 中配置企业微信自建应用 (wecom-app) 通道。 + +## 功能特性 + +| 功能 | 支持状态 | +|------|---------| +| 被动接收消息 | ✅ | +| 主动发送消息 | ✅ | +| 私聊 | ✅ | +| 群聊 | ❌ | + +## 配置步骤 + +### 1. 企业微信后台配置 + +1. 登录 [企业微信管理后台](https://work.weixin.qq.com/wework_admin) +2. 进入"应用管理" → 选择自建应用 +3. 记录以下信息: + - **AgentId**: 应用详情页显示 + - **Secret**: 点击"查看"获取 +4. 进入"我的企业"页面,记录 **企业ID** (CorpID) + +### 2. 接收消息配置 + +1. 在应用详情页,点击"接收消息"的"设置API接收" +2. 填写以下信息: + - **URL**: `http://your-server:18792/webhook/wecom-app` + - **Token**: 随机生成或自定义(用于签名验证) + - **EncodingAESKey**: 点击"随机生成"生成43字符的密钥 +3. 点击"保存"时,企业微信会发送验证请求 + +### 3. PicoClaw 配置 + +在 `config.json` 中添加以下配置: + +```json +{ + "channels": { + "wecom_app": { + "enabled": true, + "corp_id": "wwxxxxxxxxxxxxxxxx", // 企业ID + "corp_secret": "xxxxxxxxxxxxxxxxxxxxxxxx", // 应用Secret + "agent_id": 1000002, // 应用AgentId + "token": "your_token", // 接收消息配置的Token + "encoding_aes_key": "your_encoding_aes_key", // 接收消息配置的EncodingAESKey + "webhook_host": "0.0.0.0", + "webhook_port": 18792, + "webhook_path": "/webhook/wecom-app", + "allow_from": [], + "reply_timeout": 5 + } + } +} +``` + +## 常见问题 + +### 1. 回调URL验证失败 + +**症状**: 企业微信保存API接收消息时提示验证失败 + +**检查项**: +- 确认服务器防火墙已开放 18792 端口 +- 确认 `corp_id`、`token`、`encoding_aes_key` 配置正确 +- 查看 PicoClaw 日志是否有请求到达 + +### 2. 中文消息解密失败 + +**症状**: 发送中文消息时出现 `invalid padding size` 错误 + +**原因**: 企业微信使用非标准的 PKCS7 填充(32字节块大小) + +**解决**: 确保使用最新版本的 PicoClaw,已修复此问题。 + +### 3. 
端口冲突 + +**症状**: 启动时提示端口已被占用 + +**解决**: 修改 `webhook_port` 为其他端口,如 18794 + +## 技术细节 + +### 加密算法 + +- **算法**: AES-256-CBC +- **密钥**: EncodingAESKey Base64解码后的32字节 +- **IV**: AESKey的前16字节 +- **填充**: PKCS7(块大小为32字节,非标准16字节) +- **消息格式**: XML + +### 消息结构 + +解密后的消息格式: +``` +random(16B) + msg_len(4B) + msg + receiveid +``` + +其中 `receiveid` 对于自建应用是 `corp_id`。 + +## 调试 + +启用调试模式查看详细日志: + +```bash +picoclaw gateway --debug +``` + +关键日志标识: +- `wecom_app`: WeCom App 通道相关日志 +- `wecom_common`: 加密解密相关日志 + +## 参考文档 + +- [企业微信官方文档 - 接收消息](https://developer.work.weixin.qq.com/document/path/96211) +- [企业微信官方加解密库](https://github.com/sbzhu/weworkapi_golang) diff --git a/go.mod b/go.mod index f4c233ea8..000dfb2e3 100644 --- a/go.mod +++ b/go.mod @@ -12,24 +12,50 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/mymmrac/telego v1.6.0 + github.com/nats-io/nats-server/v2 v2.12.4 + github.com/nats-io/nats.go v1.48.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 - github.com/openai/openai-go/v3 v3.21.0 + github.com/openai/openai-go/v3 v3.22.0 github.com/slack-go/slack v0.17.3 + github.com/stretchr/testify v1.11.1 github.com/tencent-connect/botgo v0.2.1 + go.temporal.io/sdk v1.40.0 golang.org/x/oauth2 v0.35.0 ) +require ( + github.com/antithesishq/antithesis-sdk-go v0.6.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/go-tpm v0.9.8 // indirect + github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.3 // indirect + github.com/oklog/ulid/v2 v2.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + require ( github.com/andybalholm/brotli v1.2.0 // indirect github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic v1.15.0 // indirect github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect - github.com/go-resty/resty/v2 v2.17.1 // indirect + github.com/facebookgo/clock 
v0.0.0-20150410010913-600d898af40a // indirect + github.com/github/copilot-sdk/go v0.1.23 + github.com/go-resty/resty/v2 v2.17.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/mock v1.6.0 // indirect + github.com/google/jsonschema-go v0.4.2 // indirect github.com/grbit/go-json v0.11.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.8 // indirect github.com/klauspost/compress v1.18.4 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76 // indirect + github.com/nats-io/jwt/v2 v2.8.0 // indirect + github.com/nats-io/nkeys v0.4.15 // indirect + github.com/nats-io/nuid v1.0.1 // indirect + github.com/nexus-rpc/sdk-go v0.5.1 // indirect + github.com/robfig/cron v1.2.0 // indirect + github.com/stretchr/objx v0.5.3 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect @@ -38,9 +64,16 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.69.0 // indirect github.com/valyala/fastjson v1.6.7 // indirect + go.temporal.io/api v1.62.2 // indirect golang.org/x/arch v0.24.0 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/net v0.50.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.41.0 // indirect + golang.org/x/text v0.34.0 // indirect + golang.org/x/time v0.14.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect + google.golang.org/grpc v1.79.1 // indirect + google.golang.org/protobuf v1.36.11 // indirect ) diff --git a/go.sum b/go.sum index 9174d2889..9d44a5be1 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,8 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo github.com/andybalholm/brotli v1.2.0/go.mod 
h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0= github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE= +github.com/antithesishq/antithesis-sdk-go v0.6.0 h1:v/YViLhFYkZOEEof4AXjD5AgGnGM84YHF4RqEwp6I2g= +github.com/antithesishq/antithesis-sdk-go v0.6.0/go.mod h1:IUpT2DPAKh6i/YhSbt6Gl3v2yvUZjmKncl7U91fup7E= github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= @@ -17,6 +19,8 @@ github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5m github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= @@ -30,17 +34,27 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a 
h1:yDWHCSQ40h88yih2JAcL6Ls/kVkSE8GFACTGVnMPruw= +github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a/go.mod h1:7Ga40egUymuWXxAe151lTNnCv97MddSOVsjpPPkityA= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/github/copilot-sdk/go v0.1.23 h1:uExtO/inZQndCZMiSAA1hvXINiz9tqo/MZgQzFzurxw= +github.com/github/copilot-sdk/go v0.1.23/go.mod h1:GdwwBfMbm9AABLEM3x5IZKw4ZfwCYxZ1BgyytmZenQ0= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-redis/redis/v8 v8.11.4/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w= github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w/BIH7cC3Q= -github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4= -github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= +github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk= +github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod 
h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= @@ -50,12 +64,20 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-tpm v0.9.8 h1:slArAR9Ft+1ybZu0lBwpSmpwhRXaa85hWtMinMyRAWo= +github.com/google/go-tpm v0.9.8/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= +github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= +github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/uuid v1.3.0/go.mod 
h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -65,6 +87,10 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc= github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek= +github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.3 h1:B+8ClL/kCQkRiU82d9xajRPKYMrB7E0MbtzWVi1K4ns= +github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.3/go.mod h1:NbCUVmiS4foBGBHOYlCT25+YmGpJ32dZPi75pGEUpj4= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.8 h1:NpbJl/eVbvrGE0MJ6X16X9SAifesl6Fwxg/YmCvubRI= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.8/go.mod h1:mi7YA+gCzVem12exXy46ZespvGtX/lZmD/RLnQhVW7U= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -75,15 +101,34 @@ github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 
h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk= github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI= +github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76 h1:KGuD/pM2JpL9FAYvBrnBBeENKZNh6eNtjqytV6TYjnk= +github.com/minio/highwayhash v1.0.4-0.20251030100505-070ab1a87a76/go.mod h1:GGYsuwP/fPD6Y9hMiXuapVvlIUEhFhMTh0rxU3ik1LQ= github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0= github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y= +github.com/nats-io/jwt/v2 v2.8.0 h1:K7uzyz50+yGZDO5o772eRE7atlcSEENpL7P+b74JV1g= +github.com/nats-io/jwt/v2 v2.8.0/go.mod h1:me11pOkwObtcBNR8AiMrUbtVOUGkqYjMQZ6jnSdVUIA= +github.com/nats-io/nats-server/v2 v2.12.4 h1:ZnT10v2LU2Xcoiy8ek9X6Se4YG8EuMfIfvAEuFVx1Ts= +github.com/nats-io/nats-server/v2 v2.12.4/go.mod h1:5MCp/pqm5SEfsvVZ31ll1088ZTwEUdvRX1Hmh/mTTDg= +github.com/nats-io/nats.go v1.48.0 h1:pSFyXApG+yWU/TgbKCjmm5K4wrHu86231/w84qRVR+U= +github.com/nats-io/nats.go v1.48.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g= +github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4= +github.com/nats-io/nkeys v0.4.15/go.mod h1:CpMchTXC9fxA5zrMo4KpySxNjiDVvr8ANOSZdiNfUrs= +github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= +github.com/nexus-rpc/sdk-go v0.5.1 h1:UFYYfoHlQc+Pn9gQpmn9QE7xluewAn2AO1OSkAh7YFU= +github.com/nexus-rpc/sdk-go v0.5.1/go.mod h1:FHdPfVQwRuJFZFTF0Y2GOAxCrbIBNrcPna9slkGKPYk= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/oklog/ulid/v2 v2.1.1 
h1:suPZ4ARWLOJLegGFiZZ1dFAkqzhMjL3J1TzI+5wHz8s= +github.com/oklog/ulid/v2 v2.1.1/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= @@ -92,19 +137,26 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU= -github.com/openai/openai-go/v3 v3.21.0 h1:3GpIR/W4q/v1uUOVuK3zYtQiF3DnRrZag/sxbtvEdtc= -github.com/openai/openai-go/v3 v3.21.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= +github.com/openai/openai-go/v3 v3.22.0 h1:6MEoNoV8sbjOVmXdvhmuX3BjVbVdcExbVyGixiyJ8ys= +github.com/openai/openai-go/v3 v3.22.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= +github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/robfig/cron v1.2.0 h1:ZjScXvvxeQ63Dbyxy76Fj3AT3Ut0aKsyd2/tl3DTMuQ= +github.com/robfig/cron v1.2.0/go.mod h1:JGuDeoQd7Z6yL4zQhZ3OPEVHB7fL6Ka6skscFHfmt2k= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.11.0 
h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g= github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/objx v0.5.3 h1:jmXUvGomnU1o3W/V5h2VEradbpJDwGrzugQQvL0POH4= +github.com/stretchr/objx v0.5.3/go.mod h1:rDQraq+vQZU7Fde9LOZLr8Tax6zZvy4kuNKF+QYS+U0= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -139,7 +191,24 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= +go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= 
+go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= +go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= +go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= +go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= +go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= +go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= +go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= +go.temporal.io/api v1.62.2 h1:jFhIzlqNyJsJZTiCRQmTIMv6OTQ5BZ57z8gbgLGMaoo= +go.temporal.io/api v1.62.2/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= +go.temporal.io/sdk v1.40.0 h1:n9JN3ezVpWBxLzz5xViCo0sKxp7kVVhr1Su0bcMRNNs= +go.temporal.io/sdk v1.40.0/go.mod h1:tauxVfN174F0bdEs27+i0h8UPD7xBb6Py2SPHo7f1C0= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y= @@ -154,6 +223,7 @@ golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 
@@ -178,6 +248,7 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= @@ -194,6 +265,7 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -201,6 +273,7 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 
+golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -215,19 +288,30 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= -golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors 
v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 h1:JLQynH/LBHfCTSbDWl+py8C+Rg/k1OVH3xfcaiANuF0= +google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:kSJwQxqmFXeo79zOmbrALdflXQeAYcUbgS7PbpMknCY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 h1:mWPCjDEyshlQYzBpMNHaEof6UX1PmHcaUODUywQ0uac= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= +google.golang.org/grpc v1.79.1 h1:zGhSi45ODB9/p3VAawt9a+O/MULLl9dpizzNNpq7flY= +google.golang.org/grpc v1.79.1/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -236,8 +320,11 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= 
+google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= diff --git a/pkg/agent/context.go b/pkg/agent/context.go index e32e456f9..47a7a6184 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -19,6 +19,8 @@ type ContextBuilder struct { skillsLoader *skills.SkillsLoader memory *MemoryStore tools *tools.ToolRegistry // Direct reference to tool registry + hid string // H-id (tenant/cluster identity) + sid string // S-id (instance identity) } func getGlobalConfigDir() string { @@ -48,6 +50,12 @@ func (cb *ContextBuilder) SetToolsRegistry(registry *tools.ToolRegistry) { cb.tools = registry } +// SetIdentity sets the swarm identity (H-id and S-id) +func (cb *ContextBuilder) SetIdentity(hid, sid string) { + cb.hid = hid + cb.sid = sid +} + func (cb *ContextBuilder) getIdentity() string { now := time.Now().Format("2006-01-02 15:04 (Monday)") workspacePath, _ := filepath.Abs(filepath.Join(cb.workspace)) @@ -56,6 +64,18 @@ func (cb *ContextBuilder) getIdentity() string { // Build tools section dynamically toolsSection := cb.buildToolsSection() + // Build identity section if set + identitySection := "" + if cb.hid != "" || cb.sid != "" { + identitySection = "\n## Swarm Identity\n" + if cb.hid != "" { + identitySection += fmt.Sprintf("- **H-id (Tenant)**: %s\n", 
cb.hid) + } + if cb.sid != "" { + identitySection += fmt.Sprintf("- **S-id (Instance)**: %s\n", cb.sid) + } + } + return fmt.Sprintf(`# picoclaw 🦞 You are picoclaw, a helpful AI assistant. @@ -72,6 +92,7 @@ Your workspace is at: %s - Daily Notes: %s/memory/YYYYMM/YYYYMMDD.md - Skills: %s/skills/{skill-name}/SKILL.md +%s %s ## Important Rules @@ -80,8 +101,8 @@ Your workspace is at: %s 2. **Be helpful and accurate** - When using tools, briefly explain what you're doing. -3. **Memory** - When remembering something, write to %s/memory/MEMORY.md`, - now, runtime, workspacePath, workspacePath, workspacePath, workspacePath, toolsSection, workspacePath) +3. **Memory** - When interacting with me if something seems memorable, update %s/memory/MEMORY.md`, + now, runtime, workspacePath, workspacePath, workspacePath, workspacePath, toolsSection, identitySection, workspacePath) } func (cb *ContextBuilder) buildToolsSection() string { @@ -96,7 +117,9 @@ func (cb *ContextBuilder) buildToolsSection() string { var sb strings.Builder sb.WriteString("## Available Tools\n\n") - sb.WriteString("**CRITICAL**: You MUST use tools to perform actions. Do NOT pretend to execute commands or schedule tasks.\n\n") + sb.WriteString( + "**CRITICAL**: You MUST use tools to perform actions. 
Do NOT pretend to execute commands or schedule tasks.\n\n", + ) sb.WriteString("You have access to the following tools:\n\n") for _, s := range summaries { sb.WriteString(s) @@ -146,18 +169,24 @@ func (cb *ContextBuilder) LoadBootstrapFiles() string { "IDENTITY.md", } - var result string + var sb strings.Builder for _, filename := range bootstrapFiles { filePath := filepath.Join(cb.workspace, filename) if data, err := os.ReadFile(filePath); err == nil { - result += fmt.Sprintf("## %s\n\n%s\n\n", filename, string(data)) + fmt.Fprintf(&sb, "## %s\n\n%s\n\n", filename, data) } } - return result + return sb.String() } -func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary string, currentMessage string, media []string, channel, chatID string) []providers.Message { +func (cb *ContextBuilder) BuildMessages( + history []providers.Message, + summary string, + currentMessage string, + media []string, + channel, chatID string, +) []providers.Message { messages := []providers.Message{} systemPrompt := cb.BuildSystemPrompt() @@ -169,9 +198,9 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str // Log system prompt summary for debugging (debug mode only) logger.DebugCF("agent", "System prompt built", - map[string]interface{}{ - "total_chars": len(systemPrompt), - "total_lines": strings.Count(systemPrompt, "\n") + 1, + map[string]any{ + "total_chars": len(systemPrompt), + "total_lines": strings.Count(systemPrompt, "\n") + 1, "section_count": strings.Count(systemPrompt, "\n\n---\n\n") + 1, }) @@ -181,7 +210,7 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str preview = preview[:500] + "... 
(truncated)" } logger.DebugCF("agent", "System prompt preview", - map[string]interface{}{ + map[string]any{ "preview": preview, }) @@ -189,16 +218,7 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str systemPrompt += "\n\n## Summary of Previous Conversation\n\n" + summary } - //This fix prevents the session memory from LLM failure due to elimination of toolu_IDs required from LLM - // --- INICIO DEL FIX --- - //Diegox-17 - for len(history) > 0 && (history[0].Role == "tool") { - logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error", - map[string]interface{}{"role": history[0].Role}) - history = history[1:] - } - //Diegox-17 - // --- FIN DEL FIX --- + history = sanitizeHistoryForProvider(history) messages = append(messages, providers.Message{ Role: "system", @@ -207,15 +227,66 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str messages = append(messages, history...) - messages = append(messages, providers.Message{ - Role: "user", - Content: currentMessage, - }) + if strings.TrimSpace(currentMessage) != "" { + messages = append(messages, providers.Message{ + Role: "user", + Content: currentMessage, + }) + } return messages } -func (cb *ContextBuilder) AddToolResult(messages []providers.Message, toolCallID, toolName, result string) []providers.Message { +func sanitizeHistoryForProvider(history []providers.Message) []providers.Message { + if len(history) == 0 { + return history + } + + sanitized := make([]providers.Message, 0, len(history)) + for _, msg := range history { + switch msg.Role { + case "tool": + if len(sanitized) == 0 { + logger.DebugCF("agent", "Dropping orphaned leading tool message", map[string]any{}) + continue + } + last := sanitized[len(sanitized)-1] + if last.Role != "assistant" || len(last.ToolCalls) == 0 { + logger.DebugCF("agent", "Dropping orphaned tool message", map[string]any{}) + continue + } + sanitized = append(sanitized, msg) + + case 
"assistant": + if len(msg.ToolCalls) > 0 { + if len(sanitized) == 0 { + logger.DebugCF("agent", "Dropping assistant tool-call turn at history start", map[string]any{}) + continue + } + prev := sanitized[len(sanitized)-1] + if prev.Role != "user" && prev.Role != "tool" { + logger.DebugCF( + "agent", + "Dropping assistant tool-call turn with invalid predecessor", + map[string]any{"prev_role": prev.Role}, + ) + continue + } + } + sanitized = append(sanitized, msg) + + default: + sanitized = append(sanitized, msg) + } + } + + return sanitized +} + +func (cb *ContextBuilder) AddToolResult( + messages []providers.Message, + toolCallID, toolName, result string, +) []providers.Message { messages = append(messages, providers.Message{ Role: "tool", Content: result, @@ -224,7 +295,11 @@ func (cb *ContextBuilder) AddToolResult(messages []providers.Message, toolCallID return messages } -func (cb *ContextBuilder) AddAssistantMessage(messages []providers.Message, content string, toolCalls []map[string]interface{}) []providers.Message { +func (cb *ContextBuilder) AddAssistantMessage( + messages []providers.Message, + content string, + toolCalls []map[string]any, +) []providers.Message { msg := providers.Message{ Role: "assistant", Content: content, @@ -254,13 +329,13 @@ func (cb *ContextBuilder) loadSkills() string { } // GetSkillsInfo returns information about loaded skills. 
-func (cb *ContextBuilder) GetSkillsInfo() map[string]interface{} { +func (cb *ContextBuilder) GetSkillsInfo() map[string]any { allSkills := cb.skillsLoader.ListSkills() skillNames := make([]string, 0, len(allSkills)) for _, s := range allSkills { skillNames = append(skillNames, s.Name) } - return map[string]interface{}{ + return map[string]any{ "total": len(allSkills), "available": len(allSkills), "names": skillNames, diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go new file mode 100644 index 000000000..dfbef9fbc --- /dev/null +++ b/pkg/agent/instance.go @@ -0,0 +1,159 @@ +package agent + +import ( + "os" + "path/filepath" + "strings" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/session" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// AgentInstance represents a fully configured agent with its own workspace, +// session manager, context builder, and tool registry. +type AgentInstance struct { + ID string + Name string + Model string + Fallbacks []string + Workspace string + MaxIterations int + MaxTokens int + Temperature float64 + ContextWindow int + Provider providers.LLMProvider + Sessions *session.SessionManager + ContextBuilder *ContextBuilder + Tools *tools.ToolRegistry + Subagents *config.SubagentsConfig + SkillsFilter []string + Candidates []providers.FallbackCandidate +} + +// NewAgentInstance creates an agent instance from config. 
+func NewAgentInstance( + agentCfg *config.AgentConfig, + defaults *config.AgentDefaults, + cfg *config.Config, + provider providers.LLMProvider, +) *AgentInstance { + workspace := resolveAgentWorkspace(agentCfg, defaults) + os.MkdirAll(workspace, 0o755) + + model := resolveAgentModel(agentCfg, defaults) + fallbacks := resolveAgentFallbacks(agentCfg, defaults) + + restrict := defaults.RestrictToWorkspace + toolsRegistry := tools.NewToolRegistry() + toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewListDirTool(workspace, restrict)) + toolsRegistry.Register(tools.NewExecToolWithConfig(workspace, restrict, cfg)) + toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict)) + + sessionsDir := filepath.Join(workspace, "sessions") + sessionsManager := session.NewSessionManager(sessionsDir) + + contextBuilder := NewContextBuilder(workspace) + contextBuilder.SetToolsRegistry(toolsRegistry) + + agentID := routing.DefaultAgentID + agentName := "" + var subagents *config.SubagentsConfig + var skillsFilter []string + + if agentCfg != nil { + agentID = routing.NormalizeAgentID(agentCfg.ID) + agentName = agentCfg.Name + subagents = agentCfg.Subagents + skillsFilter = agentCfg.Skills + } + + maxIter := defaults.MaxToolIterations + if maxIter == 0 { + maxIter = 20 + } + + maxTokens := defaults.MaxTokens + if maxTokens == 0 { + maxTokens = 8192 + } + + temperature := 0.7 + if defaults.Temperature != nil { + temperature = *defaults.Temperature + } + + // Resolve fallback candidates + modelCfg := providers.ModelConfig{ + Primary: model, + Fallbacks: fallbacks, + } + candidates := providers.ResolveCandidates(modelCfg, defaults.Provider) + + return &AgentInstance{ + ID: agentID, + Name: agentName, + Model: model, + Fallbacks: fallbacks, + Workspace: workspace, + MaxIterations: maxIter, + 
MaxTokens: maxTokens, + Temperature: temperature, + ContextWindow: maxTokens, + Provider: provider, + Sessions: sessionsManager, + ContextBuilder: contextBuilder, + Tools: toolsRegistry, + Subagents: subagents, + SkillsFilter: skillsFilter, + Candidates: candidates, + } +} + +// resolveAgentWorkspace determines the workspace directory for an agent. +func resolveAgentWorkspace(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string { + if agentCfg != nil && strings.TrimSpace(agentCfg.Workspace) != "" { + return expandHome(strings.TrimSpace(agentCfg.Workspace)) + } + if agentCfg == nil || agentCfg.Default || agentCfg.ID == "" || routing.NormalizeAgentID(agentCfg.ID) == "main" { + return expandHome(defaults.Workspace) + } + home, _ := os.UserHomeDir() + id := routing.NormalizeAgentID(agentCfg.ID) + return filepath.Join(home, ".picoclaw", "workspace-"+id) +} + +// resolveAgentModel resolves the primary model for an agent. +func resolveAgentModel(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string { + if agentCfg != nil && agentCfg.Model != nil && strings.TrimSpace(agentCfg.Model.Primary) != "" { + return strings.TrimSpace(agentCfg.Model.Primary) + } + return defaults.Model +} + +// resolveAgentFallbacks resolves the fallback models for an agent. 
+func resolveAgentFallbacks(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) []string { + if agentCfg != nil && agentCfg.Model != nil && agentCfg.Model.Fallbacks != nil { + return agentCfg.Model.Fallbacks + } + return defaults.ModelFallbacks +} + +func expandHome(path string) string { + if path == "" { + return path + } + if path[0] == '~' { + home, _ := os.UserHomeDir() + if len(path) > 1 && path[1] == '/' { + return home + path[1:] + } + return home + } + return path +} diff --git a/pkg/agent/instance_test.go b/pkg/agent/instance_test.go new file mode 100644 index 000000000..fcc8e9bea --- /dev/null +++ b/pkg/agent/instance_test.go @@ -0,0 +1,95 @@ +package agent + +import ( + "os" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-instance-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 1234, + MaxToolIterations: 5, + }, + }, + } + + configuredTemp := 1.0 + cfg.Agents.Defaults.Temperature = &configuredTemp + + provider := &mockProvider{} + agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider) + + if agent.MaxTokens != 1234 { + t.Fatalf("MaxTokens = %d, want %d", agent.MaxTokens, 1234) + } + if agent.Temperature != 1.0 { + t.Fatalf("Temperature = %f, want %f", agent.Temperature, 1.0) + } +} + +func TestNewAgentInstance_DefaultsTemperatureWhenZero(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-instance-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 1234, + MaxToolIterations: 5, + }, + }, 
+ } + + configuredTemp := 0.0 + cfg.Agents.Defaults.Temperature = &configuredTemp + + provider := &mockProvider{} + agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider) + + if agent.Temperature != 0.0 { + t.Fatalf("Temperature = %f, want %f", agent.Temperature, 0.0) + } +} + +func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-instance-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 1234, + MaxToolIterations: 5, + }, + }, + } + + provider := &mockProvider{} + agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider) + + if agent.Temperature != 0.7 { + t.Fatalf("Temperature = %f, want %f", agent.Temperature, 0.7) + } +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 90e665960..0de1a71d1 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -10,19 +10,20 @@ import ( "context" "encoding/json" "fmt" - "os" - "path/filepath" "strings" "sync" "sync/atomic" "time" + "unicode/utf8" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" - "github.com/sipeed/picoclaw/pkg/session" + "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/skills" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/utils" @@ -30,17 +31,22 @@ import ( type AgentLoop struct { bus *bus.MessageBus - provider providers.LLMProvider - workspace string - model string - contextWindow int // Maximum context window size in tokens - maxIterations int - sessions *session.SessionManager + cfg *config.Config + registry *AgentRegistry state 
*state.Manager - contextBuilder *ContextBuilder - tools *tools.ToolRegistry running atomic.Bool - summarizing sync.Map // Tracks which sessions are currently being summarized + summarizing sync.Map // Tracks which sessions are currently being summarized + + // Swarm identity fields + hid string // H-id (tenant/cluster identity) + sid string // S-id (instance identity) + + // AutoConsume controls whether Run() automatically consumes messages from bus + // When false, Run() waits indefinitely and the coordinator handles routing + AutoConsume bool + + fallback *providers.FallbackChain + channelManager *channels.Manager } // processOptions configures how a message is processed @@ -55,92 +61,115 @@ type processOptions struct { NoHistory bool // If true, don't load session history (for heartbeat) } -// createToolRegistry creates a tool registry with common tools. -// This is shared between main agent and subagents. -func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry { - registry := tools.NewToolRegistry() - - // File system tools - registry.Register(tools.NewReadFileTool(workspace, restrict)) - registry.Register(tools.NewWriteFileTool(workspace, restrict)) - registry.Register(tools.NewListDirTool(workspace, restrict)) - registry.Register(tools.NewEditFileTool(workspace, restrict)) - registry.Register(tools.NewAppendFileTool(workspace, restrict)) - - // Shell execution - registry.Register(tools.NewExecTool(workspace, restrict)) - - // Web tools - braveAPIKey := cfg.Tools.Web.Search.APIKey - registry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults)) - registry.Register(tools.NewWebFetchTool(50000)) - - // Message tool - available to both agent and subagent - // Subagent uses it to communicate directly with user - messageTool := tools.NewMessageTool() - messageTool.SetSendCallback(func(channel, chatID, content string) error { - msgBus.PublishOutbound(bus.OutboundMessage{ - 
Channel: channel, - ChatID: chatID, - Content: content, - }) - return nil - }) - registry.Register(messageTool) - - return registry -} - func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { - workspace := cfg.WorkspacePath() - os.MkdirAll(workspace, 0755) - - restrict := cfg.Agents.Defaults.RestrictToWorkspace + registry := NewAgentRegistry(cfg, provider) - // Create tool registry for main agent - toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus) + // Register shared tools to all agents + registerSharedTools(cfg, msgBus, registry, provider) - // Create subagent manager with its own tool registry - subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus) - subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus) - // Subagent doesn't need spawn/subagent tools to avoid recursion - subagentManager.SetTools(subagentTools) + // Set up shared fallback chain + cooldown := providers.NewCooldownTracker() + fallbackChain := providers.NewFallbackChain(cooldown) - // Register spawn tool (for main agent) - spawnTool := tools.NewSpawnTool(subagentManager) - toolsRegistry.Register(spawnTool) + // Create state manager using default agent's workspace for channel recording + defaultAgent := registry.GetDefaultAgent() + var stateManager *state.Manager + if defaultAgent != nil { + stateManager = state.NewManager(defaultAgent.Workspace) + } - // Register subagent tool (synchronous execution) - subagentTool := tools.NewSubagentTool(subagentManager) - toolsRegistry.Register(subagentTool) + return &AgentLoop{ + bus: msgBus, + cfg: cfg, + registry: registry, + state: stateManager, + summarizing: sync.Map{}, + fallback: fallbackChain, + } +} - sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions")) +// registerSharedTools registers tools that are shared across all agents (web, message, spawn). 
+func registerSharedTools( + cfg *config.Config, + msgBus *bus.MessageBus, + registry *AgentRegistry, + provider providers.LLMProvider, +) { + for _, agentID := range registry.ListAgentIDs() { + agent, ok := registry.GetAgent(agentID) + if !ok { + continue + } - // Create state manager for atomic state persistence - stateManager := state.NewManager(workspace) + // Web tools + if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{ + BraveAPIKey: cfg.Tools.Web.Brave.APIKey, + BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, + BraveEnabled: cfg.Tools.Web.Brave.Enabled, + DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, + DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, + PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, + PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, + PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, + }); searchTool != nil { + agent.Tools.Register(searchTool) + } + agent.Tools.Register(tools.NewWebFetchTool(50000)) + + // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms + agent.Tools.Register(tools.NewI2CTool()) + agent.Tools.Register(tools.NewSPITool()) + + // Message tool + messageTool := tools.NewMessageTool() + messageTool.SetSendCallback(func(channel, chatID, content string) error { + msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }) + return nil + }) + agent.Tools.Register(messageTool) - // Create context builder and set tools registry - contextBuilder := NewContextBuilder(workspace) - contextBuilder.SetToolsRegistry(toolsRegistry) + // Skill discovery and installation tools + registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{ + MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches, + ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub), + }) + searchCache := skills.NewSearchCache( + cfg.Tools.Skills.SearchCache.MaxSize, + 
time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second, + ) + agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache)) + agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace)) + + // Spawn tool with allowlist checker + subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus) + subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) + spawnTool := tools.NewSpawnTool(subagentManager) + currentAgentID := agentID + spawnTool.SetAllowlistChecker(func(targetAgentID string) bool { + return registry.CanSpawnSubagent(currentAgentID, targetAgentID) + }) + agent.Tools.Register(spawnTool) - return &AgentLoop{ - bus: msgBus, - provider: provider, - workspace: workspace, - model: cfg.Agents.Defaults.Model, - contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization - maxIterations: cfg.Agents.Defaults.MaxToolIterations, - sessions: sessionsManager, - state: stateManager, - contextBuilder: contextBuilder, - tools: toolsRegistry, - summarizing: sync.Map{}, + // Update context builder with the complete tools registry + agent.ContextBuilder.SetToolsRegistry(agent.Tools) } } func (al *AgentLoop) Run(ctx context.Context) error { al.running.Store(true) + // If AutoConsume is disabled, just wait for context cancellation + // This allows the coordinator to handle message routing + if !al.AutoConsume { + <-ctx.Done() + al.running.Store(false) + return nil + } + for al.running.Load() { select { case <-ctx.Done(): @@ -157,11 +186,26 @@ func (al *AgentLoop) Run(ctx context.Context) error { } if response != "" { - al.bus.PublishOutbound(bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: response, - }) + // Check if the message tool already sent a response during this round. + // If so, skip publishing to avoid duplicate messages to the user. + // Use default agent's tools to check (message tool is shared). 
+ alreadySent := false + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent != nil { + if tool, ok := defaultAgent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + alreadySent = mt.HasSentInRound() + } + } + } + + if !alreadySent { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: msg.Channel, + ChatID: msg.ChatID, + Content: response, + }) + } } } } @@ -174,18 +218,32 @@ func (al *AgentLoop) Stop() { } func (al *AgentLoop) RegisterTool(tool tools.Tool) { - al.tools.Register(tool) + for _, agentID := range al.registry.ListAgentIDs() { + if agent, ok := al.registry.GetAgent(agentID); ok { + agent.Tools.Register(tool) + } + } +} + +func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { + al.channelManager = cm } // RecordLastChannel records the last active channel for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. func (al *AgentLoop) RecordLastChannel(channel string) error { + if al.state == nil { + return nil + } return al.state.SetLastChannel(channel) } // RecordLastChatID records the last active chat ID for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. 
func (al *AgentLoop) RecordLastChatID(chatID string) error { + if al.state == nil { + return nil + } return al.state.SetLastChatID(chatID) } @@ -193,7 +251,10 @@ func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey stri return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct") } -func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error) { +func (al *AgentLoop) ProcessDirectWithChannel( + ctx context.Context, + content, sessionKey, channel, chatID string, +) (string, error) { msg := bus.InboundMessage{ Channel: channel, SenderID: "cron", @@ -205,10 +266,17 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess return al.processMessage(ctx, msg) } +// ProcessInboundMessage processes an inbound message from the message bus. +// This is used by the coordinator to process non-workflow messages locally. +func (al *AgentLoop) ProcessInboundMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { + return al.processMessage(ctx, msg) +} + // ProcessHeartbeat processes a heartbeat request without session history. // Each heartbeat is independent and doesn't accumulate context. 
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) { - return al.runAgentLoop(ctx, processOptions{ + agent := al.registry.GetDefaultAgent() + return al.runAgentLoop(ctx, agent, processOptions{ SessionKey: "heartbeat", Channel: channel, ChatID: chatID, @@ -229,7 +297,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) logContent = utils.Truncate(msg.Content, 80) } logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent), - map[string]interface{}{ + map[string]any{ "channel": msg.Channel, "chat_id": msg.ChatID, "sender_id": msg.SenderID, @@ -241,9 +309,41 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return al.processSystemMessage(ctx, msg) } - // Process as user message - return al.runAgentLoop(ctx, processOptions{ - SessionKey: msg.SessionKey, + // Check for commands + if response, handled := al.handleCommand(ctx, msg); handled { + return response, nil + } + + // Route to determine agent and session key + route := al.registry.ResolveRoute(routing.RouteInput{ + Channel: msg.Channel, + AccountID: msg.Metadata["account_id"], + Peer: extractPeer(msg), + ParentPeer: extractParentPeer(msg), + GuildID: msg.Metadata["guild_id"], + TeamID: msg.Metadata["team_id"], + }) + + agent, ok := al.registry.GetAgent(route.AgentID) + if !ok { + agent = al.registry.GetDefaultAgent() + } + + // Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron) + sessionKey := route.SessionKey + if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") { + sessionKey = msg.SessionKey + } + + logger.InfoCF("agent", "Routed message", + map[string]any{ + "agent_id": agent.ID, + "session_key": sessionKey, + "matched_by": route.MatchedBy, + }) + + return al.runAgentLoop(ctx, agent, processOptions{ + SessionKey: sessionKey, Channel: msg.Channel, ChatID: msg.ChatID, UserMessage: 
msg.Content, @@ -254,24 +354,24 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { - // Verify this is a system message if msg.Channel != "system" { return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel) } logger.InfoCF("agent", "Processing system message", - map[string]interface{}{ + map[string]any{ "sender_id": msg.SenderID, "chat_id": msg.ChatID, }) // Parse origin channel from chat_id (format: "channel:chat_id") - var originChannel string + var originChannel, originChatID string if idx := strings.Index(msg.ChatID, ":"); idx > 0 { originChannel = msg.ChatID[:idx] + originChatID = msg.ChatID[idx+1:] } else { - // Fallback originChannel = "cli" + originChatID = msg.ChatID } // Extract subagent result from message content @@ -284,52 +384,55 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe // Skip internal channels - only log, don't send to user if constants.IsInternalChannel(originChannel) { logger.InfoCF("agent", "Subagent completed (internal channel)", - map[string]interface{}{ - "sender_id": msg.SenderID, - "content_len": len(content), - "channel": originChannel, + map[string]any{ + "sender_id": msg.SenderID, + "content_len": len(content), + "channel": originChannel, }) return "", nil } - // Agent acts as dispatcher only - subagent handles user interaction via message tool - // Don't forward result here, subagent should use message tool to communicate with user - logger.InfoCF("agent", "Subagent completed", - map[string]interface{}{ - "sender_id": msg.SenderID, - "channel": originChannel, - "content_len": len(content), - }) + // Use default agent for system messages + agent := al.registry.GetDefaultAgent() + + // Use the origin session for context + sessionKey := routing.BuildAgentMainSessionKey(agent.ID) - // Agent only logs, does not respond 
to user - return "", nil + return al.runAgentLoop(ctx, agent, processOptions{ + SessionKey: sessionKey, + Channel: originChannel, + ChatID: originChatID, + UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content), + DefaultResponse: "Background task completed.", + EnableSummary: false, + SendResponse: true, + }) } // runAgentLoop is the core message processing logic. -// It handles context building, LLM calls, tool execution, and response handling. -func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) { +func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) { // 0. Record last channel for heartbeat notifications (skip internal channels) if opts.Channel != "" && opts.ChatID != "" { // Don't record internal channels (cli, system, subagent) if !constants.IsInternalChannel(opts.Channel) { channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID) if err := al.RecordLastChannel(channelKey); err != nil { - logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()}) + logger.WarnCF("agent", "Failed to record last channel", map[string]any{"error": err.Error()}) } } } // 1. Update tool contexts - al.updateToolContexts(opts.Channel, opts.ChatID) + al.updateToolContexts(agent, opts.Channel, opts.ChatID) // 2. Build messages (skip history for heartbeat) var history []providers.Message var summary string if !opts.NoHistory { - history = al.sessions.GetHistory(opts.SessionKey) - summary = al.sessions.GetSummary(opts.SessionKey) + history = agent.Sessions.GetHistory(opts.SessionKey) + summary = agent.Sessions.GetSummary(opts.SessionKey) } - messages := al.contextBuilder.BuildMessages( + messages := agent.ContextBuilder.BuildMessages( history, summary, opts.UserMessage, @@ -339,10 +442,10 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str ) // 3. 
Save user message to session - al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) + agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) // 4. Run LLM iteration loop - finalContent, iteration, err := al.runLLMIteration(ctx, messages, opts) + finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts) if err != nil { return "", err } @@ -356,12 +459,12 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str } // 6. Save final assistant message to session - al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent) - al.sessions.Save(al.sessions.GetOrCreate(opts.SessionKey)) + agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent) + agent.Sessions.Save(opts.SessionKey) // 7. Optional: summarization if opts.EnableSummary { - al.maybeSummarize(opts.SessionKey) + al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID) } // 8. Optional: send response via bus @@ -376,7 +479,8 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str // 9. Log response responsePreview := utils.Truncate(finalContent, 120) logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview), - map[string]interface{}{ + map[string]any{ + "agent_id": agent.ID, "session_key": opts.SessionKey, "iterations": iteration, "final_length": len(finalContent), @@ -386,78 +490,156 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str } // runLLMIteration executes the LLM call loop with tool handling. -// Returns the final content, iteration count, and any error. 
-func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.Message, opts processOptions) (string, int, error) { +func (al *AgentLoop) runLLMIteration( + ctx context.Context, + agent *AgentInstance, + messages []providers.Message, + opts processOptions, +) (string, int, error) { iteration := 0 var finalContent string - for iteration < al.maxIterations { + for iteration < agent.MaxIterations { iteration++ logger.DebugCF("agent", "LLM iteration", - map[string]interface{}{ + map[string]any{ + "agent_id": agent.ID, "iteration": iteration, - "max": al.maxIterations, + "max": agent.MaxIterations, }) // Build tool definitions - providerToolDefs := al.tools.ToProviderDefs() + providerToolDefs := agent.Tools.ToProviderDefs() // Log LLM request details logger.DebugCF("agent", "LLM request", - map[string]interface{}{ + map[string]any{ + "agent_id": agent.ID, "iteration": iteration, - "model": al.model, + "model": agent.Model, "messages_count": len(messages), "tools_count": len(providerToolDefs), - "max_tokens": 8192, - "temperature": 0.7, + "max_tokens": agent.MaxTokens, + "temperature": agent.Temperature, "system_prompt_len": len(messages[0].Content), }) // Log full messages (detailed) logger.DebugCF("agent", "Full LLM request", - map[string]interface{}{ + map[string]any{ "iteration": iteration, "messages_json": formatMessagesForLog(messages), "tools_json": formatToolsForLog(providerToolDefs), }) - // Call LLM - response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{ - "max_tokens": 8192, - "temperature": 0.7, - }) + // Call LLM with fallback chain if candidates are configured. 
+ var response *providers.LLMResponse + var err error + + callLLM := func() (*providers.LLMResponse, error) { + if len(agent.Candidates) > 1 && al.fallback != nil { + fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates, + func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { + return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{ + "max_tokens": agent.MaxTokens, + "temperature": agent.Temperature, + }) + }, + ) + if fbErr != nil { + return nil, fbErr + } + if fbResult.Provider != "" && len(fbResult.Attempts) > 0 { + logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts", + fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1), + map[string]any{"agent_id": agent.ID, "iteration": iteration}) + } + return fbResult.Response, nil + } + return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]any{ + "max_tokens": agent.MaxTokens, + "temperature": agent.Temperature, + }) + } + + // Retry loop for context/token errors + maxRetries := 2 + for retry := 0; retry <= maxRetries; retry++ { + response, err = callLLM() + if err == nil { + break + } + + errMsg := strings.ToLower(err.Error()) + isContextError := strings.Contains(errMsg, "token") || + strings.Contains(errMsg, "context") || + strings.Contains(errMsg, "invalidparameter") || + strings.Contains(errMsg, "length") + + if isContextError && retry < maxRetries { + logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{ + "error": err.Error(), + "retry": retry, + }) + + if retry == 0 && !constants.IsInternalChannel(opts.Channel) { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Content: "Context window exceeded. 
Compressing history and retrying...", + }) + } + + al.forceCompression(agent, opts.SessionKey) + newHistory := agent.Sessions.GetHistory(opts.SessionKey) + newSummary := agent.Sessions.GetSummary(opts.SessionKey) + messages = agent.ContextBuilder.BuildMessages( + newHistory, newSummary, "", + nil, opts.Channel, opts.ChatID, + ) + continue + } + break + } if err != nil { logger.ErrorCF("agent", "LLM call failed", - map[string]interface{}{ + map[string]any{ + "agent_id": agent.ID, "iteration": iteration, "error": err.Error(), }) - return "", iteration, fmt.Errorf("LLM call failed: %w", err) + return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) } // Check if no tool calls - we're done if len(response.ToolCalls) == 0 { finalContent = response.Content logger.InfoCF("agent", "LLM response without tool calls (direct answer)", - map[string]interface{}{ + map[string]any{ + "agent_id": agent.ID, "iteration": iteration, "content_chars": len(finalContent), }) break } - // Log tool calls - toolNames := make([]string, 0, len(response.ToolCalls)) + normalizedToolCalls := make([]providers.ToolCall, 0, len(response.ToolCalls)) for _, tc := range response.ToolCalls { + normalizedToolCalls = append(normalizedToolCalls, providers.NormalizeToolCall(tc)) + } + + // Log tool calls + toolNames := make([]string, 0, len(normalizedToolCalls)) + for _, tc := range normalizedToolCalls { toolNames = append(toolNames, tc.Name) } logger.InfoCF("agent", "LLM requested tool calls", - map[string]interface{}{ + map[string]any{ + "agent_id": agent.ID, "tools": toolNames, - "count": len(response.ToolCalls), + "count": len(normalizedToolCalls), "iteration": iteration, }) @@ -466,29 +648,40 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M Role: "assistant", Content: response.Content, } - for _, tc := range response.ToolCalls { + for _, tc := range normalizedToolCalls { argumentsJSON, _ := json.Marshal(tc.Arguments) + // Copy ExtraContent to ensure 
thought_signature is persisted for Gemini 3 + extraContent := tc.ExtraContent + thoughtSignature := "" + if tc.Function != nil { + thoughtSignature = tc.Function.ThoughtSignature + } + assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ ID: tc.ID, Type: "function", + Name: tc.Name, Function: &providers.FunctionCall{ - Name: tc.Name, - Arguments: string(argumentsJSON), + Name: tc.Name, + Arguments: string(argumentsJSON), + ThoughtSignature: thoughtSignature, }, + ExtraContent: extraContent, + ThoughtSignature: thoughtSignature, }) } messages = append(messages, assistantMsg) // Save assistant message with tool calls to session - al.sessions.AddFullMessage(opts.SessionKey, assistantMsg) + agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg) // Execute tool calls - for _, tc := range response.ToolCalls { - // Log tool call with arguments preview + for _, tc := range normalizedToolCalls { argsJSON, _ := json.Marshal(tc.Arguments) argsPreview := utils.Truncate(string(argsJSON), 200) logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), - map[string]interface{}{ + map[string]any{ + "agent_id": agent.ID, "tool": tc.Name, "iteration": iteration, }) @@ -502,14 +695,21 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M // The agent will handle user notification via processSystemMessage if !result.Silent && result.ForUser != "" { logger.InfoCF("agent", "Async tool completed, agent will handle notification", - map[string]interface{}{ + map[string]any{ "tool": tc.Name, "content_len": len(result.ForUser), }) } } - toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback) + toolResult := agent.Tools.ExecuteWithContext( + ctx, + tc.Name, + tc.Arguments, + opts.Channel, + opts.ChatID, + asyncCallback, + ) // Send ForUser content to user immediately if not Silent if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { @@ 
-519,7 +719,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M Content: toolResult.ForUser, }) logger.DebugCF("agent", "Sent tool result to user", - map[string]interface{}{ + map[string]any{ "tool": tc.Name, "content_len": len(toolResult.ForUser), }) @@ -539,7 +739,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M messages = append(messages, toolResultMsg) // Save tool result message to session - al.sessions.AddFullMessage(opts.SessionKey, toolResultMsg) + agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) } } @@ -547,19 +747,19 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M } // updateToolContexts updates the context for tools that need channel/chatID info. -func (al *AgentLoop) updateToolContexts(channel, chatID string) { +func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID string) { // Use ContextualTool interface instead of type assertions - if tool, ok := al.tools.Get("message"); ok { + if tool, ok := agent.Tools.Get("message"); ok { if mt, ok := tool.(tools.ContextualTool); ok { mt.SetContext(channel, chatID) } } - if tool, ok := al.tools.Get("spawn"); ok { + if tool, ok := agent.Tools.Get("spawn"); ok { if st, ok := tool.(tools.ContextualTool); ok { st.SetContext(channel, chatID) } } - if tool, ok := al.tools.Get("subagent"); ok { + if tool, ok := agent.Tools.Get("subagent"); ok { if st, ok := tool.(tools.ContextualTool); ok { st.SetContext(channel, chatID) } @@ -567,34 +767,111 @@ func (al *AgentLoop) updateToolContexts(channel, chatID string) { } // maybeSummarize triggers summarization if the session history exceeds thresholds. 
-func (al *AgentLoop) maybeSummarize(sessionKey string) { - newHistory := al.sessions.GetHistory(sessionKey) +func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) { + newHistory := agent.Sessions.GetHistory(sessionKey) tokenEstimate := al.estimateTokens(newHistory) - threshold := al.contextWindow * 75 / 100 + threshold := agent.ContextWindow * 75 / 100 if len(newHistory) > 20 || tokenEstimate > threshold { - if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading { + summarizeKey := agent.ID + ":" + sessionKey + if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading { go func() { - defer al.summarizing.Delete(sessionKey) - al.summarizeSession(sessionKey) + defer al.summarizing.Delete(summarizeKey) + if !constants.IsInternalChannel(channel) { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: "Memory threshold reached. Optimizing conversation history...", + }) + } + al.summarizeSession(agent, sessionKey) }() } } } +// forceCompression aggressively reduces context when the limit is hit. +// It drops the oldest 50% of messages (keeping system prompt and last user message). +func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) { + history := agent.Sessions.GetHistory(sessionKey) + if len(history) <= 4 { + return + } + + // Keep system prompt (usually [0]) and the very last message (user's trigger) + // We want to drop the oldest half of the *conversation* + // Assuming [0] is system, [1:] is conversation + conversation := history[1 : len(history)-1] + if len(conversation) == 0 { + return + } + + // Helper to find the mid-point of the conversation + mid := len(conversation) / 2 + + // New history structure: + // 1. System Prompt (with compression note appended) + // 2. Second half of conversation + // 3. 
Last message + + droppedCount := mid + keptConversation := conversation[mid:] + + newHistory := make([]providers.Message, 0) + + // Append compression note to the original system prompt instead of adding a new system message + // This avoids having two consecutive system messages which some APIs (like Zhipu) reject + compressionNote := fmt.Sprintf( + "\n\n[System Note: Emergency compression dropped %d oldest messages due to context limit]", + droppedCount, + ) + enhancedSystemPrompt := history[0] + enhancedSystemPrompt.Content = enhancedSystemPrompt.Content + compressionNote + newHistory = append(newHistory, enhancedSystemPrompt) + + newHistory = append(newHistory, keptConversation...) + newHistory = append(newHistory, history[len(history)-1]) // Last message + + // Update session + agent.Sessions.SetHistory(sessionKey, newHistory) + agent.Sessions.Save(sessionKey) + + logger.WarnCF("agent", "Forced compression executed", map[string]any{ + "session_key": sessionKey, + "dropped_msgs": droppedCount, + "new_count": len(newHistory), + }) +} + // GetStartupInfo returns information about loaded tools and skills for logging. 
-func (al *AgentLoop) GetStartupInfo() map[string]interface{} { - info := make(map[string]interface{}) +func (al *AgentLoop) SetIdentity(hid, sid string) { + al.hid = hid + al.sid = sid +} + +func (al *AgentLoop) GetStartupInfo() map[string]any { + info := make(map[string]any) + + agent := al.registry.GetDefaultAgent() + if agent == nil { + return info + } // Tools info - tools := al.tools.List() - info["tools"] = map[string]interface{}{ - "count": len(tools), - "names": tools, + toolsList := agent.Tools.List() + info["tools"] = map[string]any{ + "count": len(toolsList), + "names": toolsList, } // Skills info - info["skills"] = al.contextBuilder.GetSkillsInfo() + info["skills"] = agent.ContextBuilder.GetSkillsInfo() + + // Agents info + info["agents"] = map[string]any{ + "count": len(al.registry.ListAgentIDs()), + "ids": al.registry.ListAgentIDs(), + } return info } @@ -605,58 +882,58 @@ func formatMessagesForLog(messages []providers.Message) string { return "[]" } - var result string - result += "[\n" + var sb strings.Builder + sb.WriteString("[\n") for i, msg := range messages { - result += fmt.Sprintf(" [%d] Role: %s\n", i, msg.Role) - if msg.ToolCalls != nil && len(msg.ToolCalls) > 0 { - result += " ToolCalls:\n" + fmt.Fprintf(&sb, " [%d] Role: %s\n", i, msg.Role) + if len(msg.ToolCalls) > 0 { + sb.WriteString(" ToolCalls:\n") for _, tc := range msg.ToolCalls { - result += fmt.Sprintf(" - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name) + fmt.Fprintf(&sb, " - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name) if tc.Function != nil { - result += fmt.Sprintf(" Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200)) + fmt.Fprintf(&sb, " Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200)) } } } if msg.Content != "" { content := utils.Truncate(msg.Content, 200) - result += fmt.Sprintf(" Content: %s\n", content) + fmt.Fprintf(&sb, " Content: %s\n", content) } if msg.ToolCallID != "" { - result += fmt.Sprintf(" ToolCallID: %s\n", 
msg.ToolCallID) + fmt.Fprintf(&sb, " ToolCallID: %s\n", msg.ToolCallID) } - result += "\n" + sb.WriteString("\n") } - result += "]" - return result + sb.WriteString("]") + return sb.String() } // formatToolsForLog formats tool definitions for logging -func formatToolsForLog(tools []providers.ToolDefinition) string { - if len(tools) == 0 { +func formatToolsForLog(toolDefs []providers.ToolDefinition) string { + if len(toolDefs) == 0 { return "[]" } - var result string - result += "[\n" - for i, tool := range tools { - result += fmt.Sprintf(" [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name) - result += fmt.Sprintf(" Description: %s\n", tool.Function.Description) + var sb strings.Builder + sb.WriteString("[\n") + for i, tool := range toolDefs { + fmt.Fprintf(&sb, " [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name) + fmt.Fprintf(&sb, " Description: %s\n", tool.Function.Description) if len(tool.Function.Parameters) > 0 { - result += fmt.Sprintf(" Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200)) + fmt.Fprintf(&sb, " Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200)) } } - result += "]" - return result + sb.WriteString("]") + return sb.String() } // summarizeSession summarizes the conversation history for a session. 
-func (al *AgentLoop) summarizeSession(sessionKey string) { +func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() - history := al.sessions.GetHistory(sessionKey) - summary := al.sessions.GetSummary(sessionKey) + history := agent.Sessions.GetHistory(sessionKey) + summary := agent.Sessions.GetSummary(sessionKey) // Keep last 4 messages for continuity if len(history) <= 4 { @@ -666,8 +943,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { toSummarize := history[:len(history)-4] // Oversized Message Guard - // Skip messages larger than 50% of context window to prevent summarizer overflow - maxMessageTokens := al.contextWindow / 2 + maxMessageTokens := agent.ContextWindow / 2 validMessages := make([]providers.Message, 0) omitted := false @@ -675,8 +951,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { if m.Role != "user" && m.Role != "assistant" { continue } - // Estimate tokens for this message - msgTokens := len(m.Content) / 4 + msgTokens := len(m.Content) / 2 if msgTokens > maxMessageTokens { omitted = true continue @@ -689,29 +964,37 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { } // Multi-Part Summarization - // Split into two parts if history is significant var finalSummary string if len(validMessages) > 10 { mid := len(validMessages) / 2 part1 := validMessages[:mid] part2 := validMessages[mid:] - s1, _ := al.summarizeBatch(ctx, part1, "") - s2, _ := al.summarizeBatch(ctx, part2, "") - - // Merge them - mergePrompt := fmt.Sprintf("Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s", s1, s2) - resp, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, al.model, map[string]interface{}{ - "max_tokens": 1024, - "temperature": 0.3, - }) + s1, _ := al.summarizeBatch(ctx, agent, part1, "") + s2, _ := al.summarizeBatch(ctx, agent, part2, "") + 
+ mergePrompt := fmt.Sprintf( + "Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s", + s1, + s2, + ) + resp, err := agent.Provider.Chat( + ctx, + []providers.Message{{Role: "user", Content: mergePrompt}}, + nil, + agent.Model, + map[string]any{ + "max_tokens": 1024, + "temperature": 0.3, + }, + ) if err == nil { finalSummary = resp.Content } else { finalSummary = s1 + " " + s2 } } else { - finalSummary, _ = al.summarizeBatch(ctx, validMessages, summary) + finalSummary, _ = al.summarizeBatch(ctx, agent, validMessages, summary) } if omitted && finalSummary != "" { @@ -719,27 +1002,42 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { } if finalSummary != "" { - al.sessions.SetSummary(sessionKey, finalSummary) - al.sessions.TruncateHistory(sessionKey, 4) - al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) + agent.Sessions.SetSummary(sessionKey, finalSummary) + agent.Sessions.TruncateHistory(sessionKey, 4) + agent.Sessions.Save(sessionKey) } } // summarizeBatch summarizes a batch of messages. 
-func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Message, existingSummary string) (string, error) { - prompt := "Provide a concise summary of this conversation segment, preserving core context and key points.\n" +func (al *AgentLoop) summarizeBatch( + ctx context.Context, + agent *AgentInstance, + batch []providers.Message, + existingSummary string, +) (string, error) { + var sb strings.Builder + sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n") if existingSummary != "" { - prompt += "Existing context: " + existingSummary + "\n" + sb.WriteString("Existing context: ") + sb.WriteString(existingSummary) + sb.WriteString("\n") } - prompt += "\nCONVERSATION:\n" + sb.WriteString("\nCONVERSATION:\n") for _, m := range batch { - prompt += fmt.Sprintf("%s: %s\n", m.Role, m.Content) + fmt.Fprintf(&sb, "%s: %s\n", m.Role, m.Content) } + prompt := sb.String() - response, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, al.model, map[string]interface{}{ - "max_tokens": 1024, - "temperature": 0.3, - }) + response, err := agent.Provider.Chat( + ctx, + []providers.Message{{Role: "user", Content: prompt}}, + nil, + agent.Model, + map[string]any{ + "max_tokens": 1024, + "temperature": 0.3, + }, + ) if err != nil { return "", err } @@ -747,10 +1045,130 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa } // estimateTokens estimates the number of tokens in a message list. +// Uses a safe heuristic of 2.5 characters per token to account for CJK and other +// overheads better than the previous 3 chars/token. 
func (al *AgentLoop) estimateTokens(messages []providers.Message) int { - total := 0 + totalChars := 0 for _, m := range messages { - total += len(m.Content) / 4 // Simple heuristic: 4 chars per token + totalChars += utf8.RuneCountInString(m.Content) + } + // 2.5 chars per token = totalChars * 2 / 5 + return totalChars * 2 / 5 +} + +func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) { + content := strings.TrimSpace(msg.Content) + if !strings.HasPrefix(content, "/") { + return "", false + } + + parts := strings.Fields(content) + if len(parts) == 0 { + return "", false + } + + cmd := parts[0] + args := parts[1:] + + switch cmd { + case "/show": + if len(args) < 1 { + return "Usage: /show [model|channel|agents]", true + } + switch args[0] { + case "model": + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + return "No default agent configured", true + } + return fmt.Sprintf("Current model: %s", defaultAgent.Model), true + case "channel": + return fmt.Sprintf("Current channel: %s", msg.Channel), true + case "agents": + agentIDs := al.registry.ListAgentIDs() + return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true + default: + return fmt.Sprintf("Unknown show target: %s", args[0]), true + } + + case "/list": + if len(args) < 1 { + return "Usage: /list [models|channels|agents]", true + } + switch args[0] { + case "models": + return "Available models: configured in config.json per agent", true + case "channels": + if al.channelManager == nil { + return "Channel manager not initialized", true + } + channels := al.channelManager.GetEnabledChannels() + if len(channels) == 0 { + return "No channels enabled", true + } + return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true + case "agents": + agentIDs := al.registry.ListAgentIDs() + return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true + default: + return fmt.Sprintf("Unknown list 
target: %s", args[0]), true + } + + case "/switch": + if len(args) < 3 || args[1] != "to" { + return "Usage: /switch [model|channel] to ", true + } + target := args[0] + value := args[2] + + switch target { + case "model": + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + return "No default agent configured", true + } + oldModel := defaultAgent.Model + defaultAgent.Model = value + return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true + case "channel": + if al.channelManager == nil { + return "Channel manager not initialized", true + } + if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" { + return fmt.Sprintf("Channel '%s' not found or not enabled", value), true + } + return fmt.Sprintf("Switched target channel to %s", value), true + default: + return fmt.Sprintf("Unknown switch target: %s", target), true + } + } + + return "", false +} + +// extractPeer extracts the routing peer from inbound message metadata. +func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { + peerKind := msg.Metadata["peer_kind"] + if peerKind == "" { + return nil + } + peerID := msg.Metadata["peer_id"] + if peerID == "" { + if peerKind == "direct" { + peerID = msg.SenderID + } else { + peerID = msg.ChatID + } + } + return &routing.RoutePeer{Kind: peerKind, ID: peerID} +} + +// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. 
+func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer { + parentKind := msg.Metadata["parent_peer_kind"] + parentID := msg.Metadata["parent_peer_id"] + if parentKind == "" || parentID == "" { + return nil } - return total + return &routing.RoutePeer{Kind: parentKind, ID: parentID} } diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 6c0ad044a..4414398b1 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -2,6 +2,7 @@ package agent import ( "context" + "fmt" "os" "path/filepath" "testing" @@ -13,20 +14,6 @@ import ( "github.com/sipeed/picoclaw/pkg/tools" ) -// mockProvider is a simple mock LLM provider for testing -type mockProvider struct{} - -func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) { - return &providers.LLMResponse{ - Content: "Mock response", - ToolCalls: []providers.ToolCall{}, - }, nil -} - -func (m *mockProvider) GetDefaultModel() string { - return "mock-model" -} - func TestRecordLastChannel(t *testing.T) { // Create temp workspace tmpDir, err := os.MkdirTemp("", "agent-test-*") @@ -184,7 +171,7 @@ func TestToolRegistry_ToolRegistration(t *testing.T) { // Verify tool is registered by checking it doesn't panic on GetStartupInfo // (actual tool retrieval is tested in tools package tests) info := al.GetStartupInfo() - toolsInfo := info["tools"].(map[string]interface{}) + toolsInfo := info["tools"].(map[string]any) toolsList := toolsInfo["names"].([]string) // Check that our custom tool name is in the list @@ -259,7 +246,7 @@ func TestToolRegistry_GetDefinitions(t *testing.T) { al.RegisterTool(testTool) info := al.GetStartupInfo() - toolsInfo := info["tools"].(map[string]interface{}) + toolsInfo := info["tools"].(map[string]any) toolsList := toolsInfo["names"].([]string) // Check that our custom tool name is in the list @@ -306,7 +293,7 @@ func 
TestAgentLoop_GetStartupInfo(t *testing.T) { t.Fatal("Expected 'tools' key in startup info") } - toolsMap, ok := toolsInfo.(map[string]interface{}) + toolsMap, ok := toolsInfo.(map[string]any) if !ok { t.Fatal("Expected 'tools' to be a map") } @@ -362,9 +349,15 @@ type simpleMockProvider struct { response string } -func (m *simpleMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) { +func (m *simpleMockProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { return &providers.LLMResponse{ - Content: m.response, + Content: m.response, ToolCalls: []providers.ToolCall{}, }, nil } @@ -384,14 +377,14 @@ func (m *mockCustomTool) Description() string { return "Mock custom tool for testing" } -func (m *mockCustomTool) Parameters() map[string]interface{} { - return map[string]interface{}{ +func (m *mockCustomTool) Parameters() map[string]any { + return map[string]any{ "type": "object", - "properties": map[string]interface{}{}, + "properties": map[string]any{}, } } -func (m *mockCustomTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult { +func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { return tools.SilentResult("Custom tool executed") } @@ -409,14 +402,14 @@ func (m *mockContextualTool) Description() string { return "Mock contextual tool" } -func (m *mockContextualTool) Parameters() map[string]interface{} { - return map[string]interface{}{ +func (m *mockContextualTool) Parameters() map[string]any { + return map[string]any{ "type": "object", - "properties": map[string]interface{}{}, + "properties": map[string]any{}, } } -func (m *mockContextualTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult { +func (m 
*mockContextualTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { return tools.SilentResult("Contextual tool executed") } @@ -475,7 +468,7 @@ func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) { SenderID: "user1", ChatID: "chat1", Content: "read test.txt", - SessionKey: "test-session", + SessionKey: "test-session", } response := helper.executeAndGetResponse(t, ctx, msg) @@ -517,7 +510,7 @@ func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) { SenderID: "user1", ChatID: "chat1", Content: "run hello", - SessionKey: "test-session", + SessionKey: "test-session", } response := helper.executeAndGetResponse(t, ctx, msg) @@ -527,3 +520,114 @@ func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) { t.Errorf("Expected 'Command output: hello world', got: %s", response) } } + +// failFirstMockProvider fails on the first N calls with a specific error +type failFirstMockProvider struct { + failures int + currentCall int + failError error + successResp string +} + +func (m *failFirstMockProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.currentCall++ + if m.currentCall <= m.failures { + return nil, m.failError + } + return &providers.LLMResponse{ + Content: m.successResp, + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *failFirstMockProvider) GetDefaultModel() string { + return "mock-fail-model" +} + +// TestAgentLoop_ContextExhaustionRetry verify that the agent retries on context errors +func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + 
}, + }, + } + + msgBus := bus.NewMessageBus() + + // Create a provider that fails once with a context error + contextErr := fmt.Errorf("InvalidParameter: Total tokens of image and text exceed max message tokens") + provider := &failFirstMockProvider{ + failures: 1, + failError: contextErr, + successResp: "Recovered from context error", + } + + al := NewAgentLoop(cfg, msgBus, provider) + + // Inject some history to simulate a full context + sessionKey := "test-session-context" + // Create dummy history + history := []providers.Message{ + {Role: "system", Content: "System prompt"}, + {Role: "user", Content: "Old message 1"}, + {Role: "assistant", Content: "Old response 1"}, + {Role: "user", Content: "Old message 2"}, + {Role: "assistant", Content: "Old response 2"}, + {Role: "user", Content: "Trigger message"}, + } + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("No default agent found") + } + defaultAgent.Sessions.SetHistory(sessionKey, history) + + // Call ProcessDirectWithChannel + // Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration + response, err := al.ProcessDirectWithChannel( + context.Background(), + "Trigger message", + sessionKey, + "test", + "test-chat", + ) + if err != nil { + t.Fatalf("Expected success after retry, got error: %v", err) + } + + if response != "Recovered from context error" { + t.Errorf("Expected 'Recovered from context error', got '%s'", response) + } + + // We expect 2 calls: 1st failed, 2nd succeeded + if provider.currentCall != 2 { + t.Errorf("Expected 2 calls (1 fail + 1 success), got %d", provider.currentCall) + } + + // Check final history length + finalHistory := defaultAgent.Sessions.GetHistory(sessionKey) + // We verify that the history has been modified (compressed) + // Original length: 6 + // Expected behavior: compression drops ~50% of history (mid slice) + // We can assert that the length is NOT what it would be without compression. 
+ // Without compression: 6 + 1 (new user msg) + 1 (assistant msg) = 8 + if len(finalHistory) >= 8 { + t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory)) + } +} diff --git a/pkg/agent/memory.go b/pkg/agent/memory.go index f27882d1c..dd5f4441c 100644 --- a/pkg/agent/memory.go +++ b/pkg/agent/memory.go @@ -10,6 +10,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "time" ) @@ -29,7 +30,7 @@ func NewMemoryStore(workspace string) *MemoryStore { memoryFile := filepath.Join(memoryDir, "MEMORY.md") // Ensure memory directory exists - os.MkdirAll(memoryDir, 0755) + os.MkdirAll(memoryDir, 0o755) return &MemoryStore{ workspace: workspace, @@ -40,8 +41,8 @@ func NewMemoryStore(workspace string) *MemoryStore { // getTodayFile returns the path to today's daily note file (memory/YYYYMM/YYYYMMDD.md). func (ms *MemoryStore) getTodayFile() string { - today := time.Now().Format("20060102") // YYYYMMDD - monthDir := today[:6] // YYYYMM + today := time.Now().Format("20060102") // YYYYMMDD + monthDir := today[:6] // YYYYMM filePath := filepath.Join(ms.memoryDir, monthDir, today+".md") return filePath } @@ -57,7 +58,7 @@ func (ms *MemoryStore) ReadLongTerm() string { // WriteLongTerm writes content to the long-term memory file (MEMORY.md). func (ms *MemoryStore) WriteLongTerm(content string) error { - return os.WriteFile(ms.memoryFile, []byte(content), 0644) + return os.WriteFile(ms.memoryFile, []byte(content), 0o644) } // ReadToday reads today's daily note. 
@@ -77,7 +78,7 @@ func (ms *MemoryStore) AppendToday(content string) error { // Ensure month directory exists monthDir := filepath.Dir(todayFile) - os.MkdirAll(monthDir, 0755) + os.MkdirAll(monthDir, 0o755) var existingContent string if data, err := os.ReadFile(todayFile); err == nil { @@ -94,68 +95,57 @@ func (ms *MemoryStore) AppendToday(content string) error { newContent = existingContent + "\n" + content } - return os.WriteFile(todayFile, []byte(newContent), 0644) + return os.WriteFile(todayFile, []byte(newContent), 0o644) } // GetRecentDailyNotes returns daily notes from the last N days. // Contents are joined with "---" separator. func (ms *MemoryStore) GetRecentDailyNotes(days int) string { - var notes []string + var sb strings.Builder + first := true for i := 0; i < days; i++ { date := time.Now().AddDate(0, 0, -i) - dateStr := date.Format("20060102") // YYYYMMDD - monthDir := dateStr[:6] // YYYYMM + dateStr := date.Format("20060102") // YYYYMMDD + monthDir := dateStr[:6] // YYYYMM filePath := filepath.Join(ms.memoryDir, monthDir, dateStr+".md") if data, err := os.ReadFile(filePath); err == nil { - notes = append(notes, string(data)) + if !first { + sb.WriteString("\n\n---\n\n") + } + sb.Write(data) + first = false } } - if len(notes) == 0 { - return "" - } - - // Join with separator - var result string - for i, note := range notes { - if i > 0 { - result += "\n\n---\n\n" - } - result += note - } - return result + return sb.String() } // GetMemoryContext returns formatted memory context for the agent prompt. // Includes long-term memory and recent daily notes. 
func (ms *MemoryStore) GetMemoryContext() string { - var parts []string - - // Long-term memory longTerm := ms.ReadLongTerm() - if longTerm != "" { - parts = append(parts, "## Long-term Memory\n\n"+longTerm) - } - - // Recent daily notes (last 3 days) recentNotes := ms.GetRecentDailyNotes(3) - if recentNotes != "" { - parts = append(parts, "## Recent Daily Notes\n\n"+recentNotes) - } - if len(parts) == 0 { + if longTerm == "" && recentNotes == "" { return "" } - // Join parts with separator - var result string - for i, part := range parts { - if i > 0 { - result += "\n\n---\n\n" + var sb strings.Builder + + if longTerm != "" { + sb.WriteString("## Long-term Memory\n\n") + sb.WriteString(longTerm) + } + + if recentNotes != "" { + if longTerm != "" { + sb.WriteString("\n\n---\n\n") } - result += part + sb.WriteString("## Recent Daily Notes\n\n") + sb.WriteString(recentNotes) } - return fmt.Sprintf("# Memory\n\n%s", result) + + return sb.String() } diff --git a/pkg/agent/mock_provider_test.go b/pkg/agent/mock_provider_test.go new file mode 100644 index 000000000..4962810dc --- /dev/null +++ b/pkg/agent/mock_provider_test.go @@ -0,0 +1,26 @@ +package agent + +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +type mockProvider struct{} + +func (m *mockProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + return &providers.LLMResponse{ + Content: "Mock response", + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *mockProvider) GetDefaultModel() string { + return "mock-model" +} diff --git a/pkg/agent/registry.go b/pkg/agent/registry.go new file mode 100644 index 000000000..77b846832 --- /dev/null +++ b/pkg/agent/registry.go @@ -0,0 +1,114 @@ +package agent + +import ( + "sync" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" + 
"github.com/sipeed/picoclaw/pkg/routing" +) + +// AgentRegistry manages multiple agent instances and routes messages to them. +type AgentRegistry struct { + agents map[string]*AgentInstance + resolver *routing.RouteResolver + mu sync.RWMutex +} + +// NewAgentRegistry creates a registry from config, instantiating all agents. +func NewAgentRegistry( + cfg *config.Config, + provider providers.LLMProvider, +) *AgentRegistry { + registry := &AgentRegistry{ + agents: make(map[string]*AgentInstance), + resolver: routing.NewRouteResolver(cfg), + } + + agentConfigs := cfg.Agents.List + if len(agentConfigs) == 0 { + implicitAgent := &config.AgentConfig{ + ID: "main", + Default: true, + } + instance := NewAgentInstance(implicitAgent, &cfg.Agents.Defaults, cfg, provider) + registry.agents["main"] = instance + logger.InfoCF("agent", "Created implicit main agent (no agents.list configured)", nil) + } else { + for i := range agentConfigs { + ac := &agentConfigs[i] + id := routing.NormalizeAgentID(ac.ID) + instance := NewAgentInstance(ac, &cfg.Agents.Defaults, cfg, provider) + registry.agents[id] = instance + logger.InfoCF("agent", "Registered agent", + map[string]any{ + "agent_id": id, + "name": ac.Name, + "workspace": instance.Workspace, + "model": instance.Model, + }) + } + } + + return registry +} + +// GetAgent returns the agent instance for a given ID. +func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + id := routing.NormalizeAgentID(agentID) + agent, ok := r.agents[id] + return agent, ok +} + +// ResolveRoute determines which agent handles the message. +func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute { + return r.resolver.ResolveRoute(input) +} + +// ListAgentIDs returns all registered agent IDs. 
+func (r *AgentRegistry) ListAgentIDs() []string { + r.mu.RLock() + defer r.mu.RUnlock() + ids := make([]string, 0, len(r.agents)) + for id := range r.agents { + ids = append(ids, id) + } + return ids +} + +// CanSpawnSubagent checks if parentAgentID is allowed to spawn targetAgentID. +func (r *AgentRegistry) CanSpawnSubagent(parentAgentID, targetAgentID string) bool { + parent, ok := r.GetAgent(parentAgentID) + if !ok { + return false + } + if parent.Subagents == nil || parent.Subagents.AllowAgents == nil { + return false + } + targetNorm := routing.NormalizeAgentID(targetAgentID) + for _, allowed := range parent.Subagents.AllowAgents { + if allowed == "*" { + return true + } + if routing.NormalizeAgentID(allowed) == targetNorm { + return true + } + } + return false +} + +// GetDefaultAgent returns the default agent instance. +func (r *AgentRegistry) GetDefaultAgent() *AgentInstance { + r.mu.RLock() + defer r.mu.RUnlock() + if agent, ok := r.agents["main"]; ok { + return agent + } + for _, agent := range r.agents { + return agent + } + return nil +} diff --git a/pkg/agent/registry_test.go b/pkg/agent/registry_test.go new file mode 100644 index 000000000..518bb441f --- /dev/null +++ b/pkg/agent/registry_test.go @@ -0,0 +1,205 @@ +package agent + +import ( + "context" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" +) + +type mockRegistryProvider struct{} + +func (m *mockRegistryProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + options map[string]any, +) (*providers.LLMResponse, error) { + return &providers.LLMResponse{Content: "mock", FinishReason: "stop"}, nil +} + +func (m *mockRegistryProvider) GetDefaultModel() string { + return "mock-model" +} + +func testCfg(agents []config.AgentConfig) *config.Config { + return &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: 
"/tmp/picoclaw-test-registry", + Model: "gpt-4", + MaxTokens: 8192, + MaxToolIterations: 10, + }, + List: agents, + }, + } +} + +func TestNewAgentRegistry_ImplicitMain(t *testing.T) { + cfg := testCfg(nil) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + ids := registry.ListAgentIDs() + if len(ids) != 1 || ids[0] != "main" { + t.Errorf("expected implicit main agent, got %v", ids) + } + + agent, ok := registry.GetAgent("main") + if !ok || agent == nil { + t.Fatal("expected to find 'main' agent") + } + if agent.ID != "main" { + t.Errorf("agent.ID = %q, want 'main'", agent.ID) + } +} + +func TestNewAgentRegistry_ExplicitAgents(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "sales", Default: true, Name: "Sales Bot"}, + {ID: "support", Name: "Support Bot"}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + ids := registry.ListAgentIDs() + if len(ids) != 2 { + t.Fatalf("expected 2 agents, got %d: %v", len(ids), ids) + } + + sales, ok := registry.GetAgent("sales") + if !ok || sales == nil { + t.Fatal("expected to find 'sales' agent") + } + if sales.Name != "Sales Bot" { + t.Errorf("sales.Name = %q, want 'Sales Bot'", sales.Name) + } + + support, ok := registry.GetAgent("support") + if !ok || support == nil { + t.Fatal("expected to find 'support' agent") + } +} + +func TestAgentRegistry_GetAgent_Normalize(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "my-agent", Default: true}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + agent, ok := registry.GetAgent("My-Agent") + if !ok || agent == nil { + t.Fatal("expected to find agent with normalized ID") + } + if agent.ID != "my-agent" { + t.Errorf("agent.ID = %q, want 'my-agent'", agent.ID) + } +} + +func TestAgentRegistry_GetDefaultAgent(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "alpha"}, + {ID: "beta", Default: true}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + // GetDefaultAgent first checks 
for "main", then returns any + agent := registry.GetDefaultAgent() + if agent == nil { + t.Fatal("expected a default agent") + } +} + +func TestAgentRegistry_CanSpawnSubagent(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + { + ID: "parent", + Default: true, + Subagents: &config.SubagentsConfig{ + AllowAgents: []string{"child1", "child2"}, + }, + }, + {ID: "child1"}, + {ID: "child2"}, + {ID: "restricted"}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + if !registry.CanSpawnSubagent("parent", "child1") { + t.Error("expected parent to be allowed to spawn child1") + } + if !registry.CanSpawnSubagent("parent", "child2") { + t.Error("expected parent to be allowed to spawn child2") + } + if registry.CanSpawnSubagent("parent", "restricted") { + t.Error("expected parent to NOT be allowed to spawn restricted") + } + if registry.CanSpawnSubagent("child1", "child2") { + t.Error("expected child1 to NOT be allowed to spawn (no subagents config)") + } +} + +func TestAgentRegistry_CanSpawnSubagent_Wildcard(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + { + ID: "admin", + Default: true, + Subagents: &config.SubagentsConfig{ + AllowAgents: []string{"*"}, + }, + }, + {ID: "any-agent"}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + if !registry.CanSpawnSubagent("admin", "any-agent") { + t.Error("expected wildcard to allow spawning any agent") + } + if !registry.CanSpawnSubagent("admin", "nonexistent") { + t.Error("expected wildcard to allow spawning even nonexistent agents") + } +} + +func TestAgentInstance_Model(t *testing.T) { + model := &config.AgentModelConfig{Primary: "claude-opus"} + cfg := testCfg([]config.AgentConfig{ + {ID: "custom", Default: true, Model: model}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + agent, _ := registry.GetAgent("custom") + if agent.Model != "claude-opus" { + t.Errorf("agent.Model = %q, want 'claude-opus'", agent.Model) + } +} + +func 
TestAgentInstance_FallbackInheritance(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "inherit", Default: true}, + }) + cfg.Agents.Defaults.ModelFallbacks = []string{"openai/gpt-4o-mini", "anthropic/haiku"} + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + agent, _ := registry.GetAgent("inherit") + if len(agent.Fallbacks) != 2 { + t.Errorf("expected 2 fallbacks inherited from defaults, got %d", len(agent.Fallbacks)) + } +} + +func TestAgentInstance_FallbackExplicitEmpty(t *testing.T) { + model := &config.AgentModelConfig{ + Primary: "gpt-4", + Fallbacks: []string{}, // explicitly empty = disable + } + cfg := testCfg([]config.AgentConfig{ + {ID: "no-fallback", Default: true, Model: model}, + }) + cfg.Agents.Defaults.ModelFallbacks = []string{"should-not-inherit"} + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + agent, _ := registry.GetAgent("no-fallback") + if len(agent.Fallbacks) != 0 { + t.Errorf("expected 0 fallbacks (explicit empty), got %d: %v", len(agent.Fallbacks), agent.Fallbacks) + } +} diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go index ecd9ba265..cf8c1c9c4 100644 --- a/pkg/auth/oauth.go +++ b/pkg/auth/oauth.go @@ -1,6 +1,7 @@ package auth import ( + "bufio" "context" "crypto/rand" "encoding/base64" @@ -11,6 +12,7 @@ import ( "net" "net/http" "net/url" + "os" "os/exec" "runtime" "strconv" @@ -19,21 +21,51 @@ import ( ) type OAuthProviderConfig struct { - Issuer string - ClientID string - Scopes string - Port int + Issuer string + ClientID string + ClientSecret string // Required for Google OAuth (confidential client) + TokenURL string // Override token endpoint (Google uses a different URL than issuer) + Scopes string + Originator string + Port int } func OpenAIOAuthConfig() OAuthProviderConfig { return OAuthProviderConfig{ - Issuer: "https://auth.openai.com", - ClientID: "app_EMoamEEZ73f0CkXaXp7hrann", - Scopes: "openid profile email offline_access", - Port: 1455, + Issuer: "https://auth.openai.com", + 
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann", + Scopes: "openid profile email offline_access", + Originator: "codex_cli_rs", + Port: 1455, } } +// GoogleAntigravityOAuthConfig returns the OAuth configuration for Google Cloud Code Assist (Antigravity). +// Client credentials are the same ones used by OpenCode/pi-ai for Cloud Code Assist access. +func GoogleAntigravityOAuthConfig() OAuthProviderConfig { + // These are the same client credentials used by the OpenCode antigravity plugin. + clientID := decodeBase64( + "MTA3MTAwNjA2MDU5MS10bWhzc2luMmgyMWxjcmUyMzV2dG9sb2poNGc0MDNlcC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbQ==", + ) + clientSecret := decodeBase64("R09DU1BYLUs1OEZXUjQ4NkxkTEoxbUxCOHNYQzR6NnFEQWY=") + return OAuthProviderConfig{ + Issuer: "https://accounts.google.com/o/oauth2/v2", + TokenURL: "https://oauth2.googleapis.com/token", + ClientID: clientID, + ClientSecret: clientSecret, + Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/cclog https://www.googleapis.com/auth/experimentsandconfigs", + Port: 51121, + } +} + +func decodeBase64(s string) string { + data, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return s + } + return string(data) +} + func generateState() (string, error) { buf := make([]byte, 32) if _, err := rand.Read(buf); err != nil { @@ -99,8 +131,22 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) { fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL) } - fmt.Println("If you're running in a headless environment, use: picoclaw auth login --provider openai --device-code") - fmt.Println("Waiting for authentication in browser...") + fmt.Printf( + "Wait! 
If you are in a headless environment (like Coolify/VPS) and cannot reach localhost:%d,\n", + cfg.Port, + ) + fmt.Println( + "please complete the login in your local browser and then PASTE the final redirect URL (or just the code) here.", + ) + fmt.Println("Waiting for authentication (browser or manual paste)...") + + // Start manual input in a goroutine + manualCh := make(chan string) + go func() { + reader := bufio.NewReader(os.Stdin) + input, _ := reader.ReadString('\n') + manualCh <- strings.TrimSpace(input) + }() select { case result := <-resultCh: @@ -108,6 +154,22 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) { return nil, result.err } return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI) + case manualInput := <-manualCh: + if manualInput == "" { + return nil, fmt.Errorf("manual input cancelled") + } + // Extract code from URL if it's a full URL + code := manualInput + if strings.Contains(manualInput, "?") { + u, err := url.Parse(manualInput) + if err == nil { + code = u.Query().Get("code") + } + } + if code == "" { + return nil, fmt.Errorf("could not find authorization code in input") + } + return exchangeCodeForTokens(cfg, code, pkce.CodeVerifier, redirectURI) case <-time.After(5 * time.Minute): return nil, fmt.Errorf("authentication timed out after 5 minutes") } @@ -198,8 +260,11 @@ func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) { deviceResp.Interval = 5 } - fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n", - cfg.Issuer, deviceResp.UserCode) + fmt.Printf( + "\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n", + cfg.Issuer, + deviceResp.UserCode, + ) deadline := time.After(15 * time.Minute) ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second) @@ -267,8 +332,16 @@ func 
RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre "refresh_token": {cred.RefreshToken}, "scope": {"openid profile email"}, } + if cfg.ClientSecret != "" { + data.Set("client_secret", cfg.ClientSecret) + } + + tokenURL := cfg.Issuer + "/oauth/token" + if cfg.TokenURL != "" { + tokenURL = cfg.TokenURL + } - resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data) + resp, err := http.PostForm(tokenURL, data) if err != nil { return nil, fmt.Errorf("refreshing token: %w", err) } @@ -279,7 +352,23 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre return nil, fmt.Errorf("token refresh failed: %s", string(body)) } - return parseTokenResponse(body, cred.Provider) + refreshed, err := parseTokenResponse(body, cred.Provider) + if err != nil { + return nil, err + } + if refreshed.RefreshToken == "" { + refreshed.RefreshToken = cred.RefreshToken + } + if refreshed.AccountID == "" { + refreshed.AccountID = cred.AccountID + } + if cred.Email != "" && refreshed.Email == "" { + refreshed.Email = cred.Email + } + if cred.ProjectID != "" && refreshed.ProjectID == "" { + refreshed.ProjectID = cred.ProjectID + } + return refreshed, nil } func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { @@ -296,7 +385,29 @@ func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU "code_challenge_method": {"S256"}, "state": {state}, } - return cfg.Issuer + "/authorize?" 
+ params.Encode() + + isGoogle := strings.Contains(strings.ToLower(cfg.Issuer), "accounts.google.com") + if isGoogle { + // Google OAuth requires these for refresh token support + params.Set("access_type", "offline") + params.Set("prompt", "consent") + } else { + // OpenAI-specific parameters + params.Set("id_token_add_organizations", "true") + params.Set("codex_cli_simplified_flow", "true") + if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") { + params.Set("originator", "picoclaw") + } + if cfg.Originator != "" { + params.Set("originator", cfg.Originator) + } + } + + // Google uses /auth path, OpenAI uses /oauth/authorize + if isGoogle { + return cfg.Issuer + "/auth?" + params.Encode() + } + return cfg.Issuer + "/oauth/authorize?" + params.Encode() } func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) { @@ -307,8 +418,22 @@ func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirect "client_id": {cfg.ClientID}, "code_verifier": {codeVerifier}, } + if cfg.ClientSecret != "" { + data.Set("client_secret", cfg.ClientSecret) + } - resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data) + tokenURL := cfg.Issuer + "/oauth/token" + if cfg.TokenURL != "" { + tokenURL = cfg.TokenURL + } + + // Determine provider name from config + provider := "openai" + if cfg.TokenURL != "" && strings.Contains(cfg.TokenURL, "googleapis.com") { + provider = "google-antigravity" + } + + resp, err := http.PostForm(tokenURL, data) if err != nil { return nil, fmt.Errorf("exchanging code for tokens: %w", err) } @@ -319,7 +444,7 @@ func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirect return nil, fmt.Errorf("token exchange failed: %s", string(body)) } - return parseTokenResponse(body, "openai") + return parseTokenResponse(body, provider) } func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) { @@ -350,19 +475,57 @@ func 
parseTokenResponse(body []byte, provider string) (*AuthCredential, error) { AuthMethod: "oauth", } - if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" { + // Recent OpenAI OAuth responses may only include chatgpt_account_id in + // id_token claims, so check the id_token before falling back to the + // access token. + if accountID := extractAccountID(tokenResp.IDToken); accountID != "" { + cred.AccountID = accountID + } else if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" { cred.AccountID = accountID } return cred, nil } -func extractAccountID(accessToken string) string { - parts := strings.Split(accessToken, ".") - if len(parts) < 2 { +func extractAccountID(token string) string { + claims, err := parseJWTClaims(token) + if err != nil { return "" } + if accountID, ok := claims["chatgpt_account_id"].(string); ok && accountID != "" { + return accountID + } + + if accountID, ok := claims["https://api.openai.com/auth.chatgpt_account_id"].(string); ok && accountID != "" { + return accountID + } + + if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]any); ok { + if accountID, ok := authClaim["chatgpt_account_id"].(string); ok && accountID != "" { + return accountID + } + } + + if orgs, ok := claims["organizations"].([]any); ok { + for _, org := range orgs { + if orgMap, ok := org.(map[string]any); ok { + if accountID, ok := orgMap["id"].(string); ok && accountID != "" { + return accountID + } + } + } + } + + return "" +} + +func parseJWTClaims(token string) (map[string]any, error) { + parts := strings.Split(token, ".") + if len(parts) < 2 { + return nil, fmt.Errorf("token is not a JWT") + } + payload := parts[1] switch len(payload) % 4 { case 2: @@ -373,21 +536,15 @@ func extractAccountID(accessToken string) string { decoded, err := base64URLDecode(payload) if err != nil { - return "" + return nil, err } - var claims map[string]interface{} + var claims map[string]any if 
err := json.Unmarshal(decoded, &claims); err != nil { - return "" - } - - if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok { - if accountID, ok := authClaim["chatgpt_account_id"].(string); ok { - return accountID - } + return nil, err } - return "" + return claims, nil } func base64URLDecode(s string) ([]byte, error) { diff --git a/pkg/auth/oauth_test.go b/pkg/auth/oauth_test.go index 9f8013217..0cb589069 100644 --- a/pkg/auth/oauth_test.go +++ b/pkg/auth/oauth_test.go @@ -1,19 +1,34 @@ package auth import ( + "encoding/base64" "encoding/json" "net/http" "net/http/httptest" + "net/url" "strings" "testing" ) +func makeJWTForClaims(t *testing.T, claims map[string]any) string { + t.Helper() + + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + payloadJSON, err := json.Marshal(claims) + if err != nil { + t.Fatalf("marshal claims: %v", err) + } + payload := base64.RawURLEncoding.EncodeToString(payloadJSON) + return header + "." 
+ payload + ".sig" +} + func TestBuildAuthorizeURL(t *testing.T) { cfg := OAuthProviderConfig{ - Issuer: "https://auth.example.com", - ClientID: "test-client-id", - Scopes: "openid profile", - Port: 1455, + Issuer: "https://auth.example.com", + ClientID: "test-client-id", + Scopes: "openid profile", + Originator: "codex_cli_rs", + Port: 1455, } pkce := PKCECodes{ CodeVerifier: "test-verifier", @@ -22,7 +37,7 @@ func TestBuildAuthorizeURL(t *testing.T) { u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback") - if !strings.HasPrefix(u, "https://auth.example.com/authorize?") { + if !strings.HasPrefix(u, "https://auth.example.com/oauth/authorize?") { t.Errorf("URL does not start with expected prefix: %s", u) } if !strings.Contains(u, "client_id=test-client-id") { @@ -40,10 +55,41 @@ func TestBuildAuthorizeURL(t *testing.T) { if !strings.Contains(u, "response_type=code") { t.Error("URL missing response_type") } + if !strings.Contains(u, "id_token_add_organizations=true") { + t.Error("URL missing id_token_add_organizations") + } + if !strings.Contains(u, "codex_cli_simplified_flow=true") { + t.Error("URL missing codex_cli_simplified_flow") + } + if !strings.Contains(u, "originator=codex_cli_rs") { + t.Error("URL missing originator") + } +} + +func TestBuildAuthorizeURLOpenAIExtras(t *testing.T) { + cfg := OpenAIOAuthConfig() + pkce := PKCECodes{CodeVerifier: "test-verifier", CodeChallenge: "test-challenge"} + + u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback") + parsed, err := url.Parse(u) + if err != nil { + t.Fatalf("url.Parse() error: %v", err) + } + q := parsed.Query() + + if q.Get("id_token_add_organizations") != "true" { + t.Errorf("id_token_add_organizations = %q, want true", q.Get("id_token_add_organizations")) + } + if q.Get("codex_cli_simplified_flow") != "true" { + t.Errorf("codex_cli_simplified_flow = %q, want true", q.Get("codex_cli_simplified_flow")) + } + if q.Get("originator") != 
"codex_cli_rs" { + t.Errorf("originator = %q, want codex_cli_rs", q.Get("originator")) + } } func TestParseTokenResponse(t *testing.T) { - resp := map[string]interface{}{ + resp := map[string]any{ "access_token": "test-access-token", "refresh_token": "test-refresh-token", "expires_in": 3600, @@ -73,6 +119,37 @@ func TestParseTokenResponse(t *testing.T) { } } +func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) { + idToken := makeJWTForClaims(t, map[string]any{"chatgpt_account_id": "acc-id-from-id-token"}) + resp := map[string]any{ + "access_token": "opaque-access-token", + "refresh_token": "test-refresh-token", + "expires_in": 3600, + "id_token": idToken, + } + body, _ := json.Marshal(resp) + + cred, err := parseTokenResponse(body, "openai") + if err != nil { + t.Fatalf("parseTokenResponse() error: %v", err) + } + if cred.AccountID != "acc-id-from-id-token" { + t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-id-from-id-token") + } +} + +func TestExtractAccountIDFromOrganizationsFallback(t *testing.T) { + token := makeJWTForClaims(t, map[string]any{ + "organizations": []any{ + map[string]any{"id": "org_from_orgs"}, + }, + }) + + if got := extractAccountID(token); got != "org_from_orgs" { + t.Errorf("extractAccountID() = %q, want %q", got, "org_from_orgs") + } +} + func TestParseTokenResponseNoAccessToken(t *testing.T) { body := []byte(`{"refresh_token": "test"}`) _, err := parseTokenResponse(body, "openai") @@ -81,6 +158,34 @@ func TestParseTokenResponseNoAccessToken(t *testing.T) { } } +func TestParseTokenResponseAccountIDFromIDToken(t *testing.T) { + idToken := makeJWTWithAccountID("acc-from-id") + resp := map[string]any{ + "access_token": "not-a-jwt", + "refresh_token": "test-refresh-token", + "expires_in": 3600, + "id_token": idToken, + } + body, _ := json.Marshal(resp) + + cred, err := parseTokenResponse(body, "openai") + if err != nil { + t.Fatalf("parseTokenResponse() error: %v", err) + } + + if cred.AccountID != "acc-from-id" { + 
t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-from-id") + } +} + +func makeJWTWithAccountID(accountID string) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + payload := base64.RawURLEncoding.EncodeToString( + []byte(`{"https://api.openai.com/auth":{"chatgpt_account_id":"` + accountID + `"}}`), + ) + return header + "." + payload + ".sig" +} + func TestExchangeCodeForTokens(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/oauth/token" { @@ -98,7 +203,7 @@ func TestExchangeCodeForTokens(t *testing.T) { return } - resp := map[string]interface{}{ + resp := map[string]any{ "access_token": "mock-access-token", "refresh_token": "mock-refresh-token", "expires_in": 3600, @@ -137,7 +242,7 @@ func TestRefreshAccessToken(t *testing.T) { return } - resp := map[string]interface{}{ + resp := map[string]any{ "access_token": "refreshed-access-token", "refresh_token": "refreshed-refresh-token", "expires_in": 3600, @@ -185,6 +290,37 @@ func TestRefreshAccessTokenNoRefreshToken(t *testing.T) { } } +func TestRefreshAccessTokenPreservesRefreshAndAccountID(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]any{ + "access_token": "new-access-token-only", + "expires_in": 3600, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := OAuthProviderConfig{Issuer: server.URL, ClientID: "test-client"} + cred := &AuthCredential{ + AccessToken: "old-access", + RefreshToken: "existing-refresh", + AccountID: "acc_existing", + Provider: "openai", + AuthMethod: "oauth", + } + + refreshed, err := RefreshAccessToken(cred, cfg) + if err != nil { + t.Fatalf("RefreshAccessToken() error: %v", err) + } + if refreshed.RefreshToken != "existing-refresh" { + t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "existing-refresh") + } + if refreshed.AccountID != 
"acc_existing" { + t.Errorf("AccountID = %q, want %q", refreshed.AccountID, "acc_existing") + } +} + func TestOpenAIOAuthConfig(t *testing.T) { cfg := OpenAIOAuthConfig() if cfg.Issuer != "https://auth.openai.com" { diff --git a/pkg/auth/store.go b/pkg/auth/store.go index 20724929a..64708421b 100644 --- a/pkg/auth/store.go +++ b/pkg/auth/store.go @@ -14,6 +14,8 @@ type AuthCredential struct { ExpiresAt time.Time `json:"expires_at,omitempty"` Provider string `json:"provider"` AuthMethod string `json:"auth_method"` + Email string `json:"email,omitempty"` + ProjectID string `json:"project_id,omitempty"` } type AuthStore struct { @@ -62,7 +64,7 @@ func LoadStore() (*AuthStore, error) { func SaveStore(store *AuthStore) error { path := authFilePath() dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0755); err != nil { + if err := os.MkdirAll(dir, 0o755); err != nil { return err } @@ -70,7 +72,7 @@ func SaveStore(store *AuthStore) error { if err != nil { return err } - return os.WriteFile(path, data, 0600) + return os.WriteFile(path, data, 0o600) } func GetCredential(provider string) (*AuthCredential, error) { diff --git a/pkg/auth/store_test.go b/pkg/auth/store_test.go index d96b460a1..f6793cfce 100644 --- a/pkg/auth/store_test.go +++ b/pkg/auth/store_test.go @@ -108,7 +108,7 @@ func TestStoreFilePermissions(t *testing.T) { t.Fatalf("Stat() error: %v", err) } perm := info.Mode().Perm() - if perm != 0600 { + if perm != 0o600 { t.Errorf("file permissions = %o, want 0600", perm) } } diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 6283251a4..58c0a25d5 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -9,6 +9,7 @@ type MessageBus struct { inbound chan InboundMessage outbound chan OutboundMessage handlers map[string]MessageHandler + closed bool mu sync.RWMutex } @@ -21,6 +22,11 @@ func NewMessageBus() *MessageBus { } func (mb *MessageBus) PublishInbound(msg InboundMessage) { + mb.mu.RLock() + defer mb.mu.RUnlock() + if mb.closed { + return + } mb.inbound <- msg } @@ 
-34,6 +40,11 @@ func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) } func (mb *MessageBus) PublishOutbound(msg OutboundMessage) { + mb.mu.RLock() + defer mb.mu.RUnlock() + if mb.closed { + return + } mb.outbound <- msg } @@ -60,6 +71,12 @@ func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) { } func (mb *MessageBus) Close() { + mb.mu.Lock() + defer mb.mu.Unlock() + if mb.closed { + return + } + mb.closed = true close(mb.inbound) close(mb.outbound) } diff --git a/pkg/channels/base.go b/pkg/channels/base.go index fabec1a86..cd6419ebb 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -2,7 +2,6 @@ package channels import ( "context" - "fmt" "strings" "github.com/sipeed/picoclaw/pkg/bus" @@ -18,14 +17,14 @@ type Channel interface { } type BaseChannel struct { - config interface{} + config any bus *bus.MessageBus running bool name string allowList []string } -func NewBaseChannel(name string, config interface{}, bus *bus.MessageBus, allowList []string) *BaseChannel { +func NewBaseChannel(name string, config any, bus *bus.MessageBus, allowList []string) *BaseChannel { return &BaseChannel{ config: config, bus: bus, @@ -59,7 +58,22 @@ func (c *BaseChannel) IsAllowed(senderID string) bool { for _, allowed := range c.allowList { // Strip leading "@" from allowed value for username matching trimmed := strings.TrimPrefix(allowed, "@") - if senderID == allowed || idPart == allowed || senderID == trimmed || idPart == trimmed || (userPart != "" && (userPart == allowed || userPart == trimmed)) { + allowedID := trimmed + allowedUser := "" + if idx := strings.Index(trimmed, "|"); idx > 0 { + allowedID = trimmed[:idx] + allowedUser = trimmed[idx+1:] + } + + // Support either side using "id|username" compound form. + // This keeps backward compatibility with legacy Telegram allowlist entries. 
+ if senderID == allowed || + idPart == allowed || + senderID == trimmed || + idPart == trimmed || + idPart == allowedID || + (allowedUser != "" && senderID == allowedUser) || + (userPart != "" && (userPart == allowed || userPart == trimmed || userPart == allowedUser)) { return true } } @@ -72,17 +86,13 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st return } - // Build session key: channel:chatID - sessionKey := fmt.Sprintf("%s:%s", c.name, chatID) - msg := bus.InboundMessage{ - Channel: c.name, - SenderID: senderID, - ChatID: chatID, - Content: content, - Media: media, - SessionKey: sessionKey, - Metadata: metadata, + Channel: c.name, + SenderID: senderID, + ChatID: chatID, + Content: content, + Media: media, + Metadata: metadata, } c.bus.PublishInbound(msg) diff --git a/pkg/channels/base_test.go b/pkg/channels/base_test.go new file mode 100644 index 000000000..78c6d1d66 --- /dev/null +++ b/pkg/channels/base_test.go @@ -0,0 +1,52 @@ +package channels + +import "testing" + +func TestBaseChannelIsAllowed(t *testing.T) { + tests := []struct { + name string + allowList []string + senderID string + want bool + }{ + { + name: "empty allowlist allows all", + allowList: nil, + senderID: "anyone", + want: true, + }, + { + name: "compound sender matches numeric allowlist", + allowList: []string{"123456"}, + senderID: "123456|alice", + want: true, + }, + { + name: "compound sender matches username allowlist", + allowList: []string{"@alice"}, + senderID: "123456|alice", + want: true, + }, + { + name: "numeric sender matches legacy compound allowlist", + allowList: []string{"123456|alice"}, + senderID: "123456", + want: true, + }, + { + name: "non matching sender is denied", + allowList: []string{"123456"}, + senderID: "654321|bob", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := NewBaseChannel("test", nil, nil, tt.allowList) + if got := ch.IsAllowed(tt.senderID); got != tt.want { + 
t.Fatalf("IsAllowed(%q) = %v, want %v", tt.senderID, got, tt.want) + } + }) + } +} diff --git a/pkg/channels/dingtalk.go b/pkg/channels/dingtalk.go index 5c6f29f0c..662fba3b7 100644 --- a/pkg/channels/dingtalk.go +++ b/pkg/channels/dingtalk.go @@ -10,6 +10,7 @@ import ( "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" @@ -20,12 +21,12 @@ import ( // It uses WebSocket for receiving messages via stream mode and API for sending type DingTalkChannel struct { *BaseChannel - config config.DingTalkConfig - clientID string - clientSecret string - streamClient *client.StreamClient - ctx context.Context - cancel context.CancelFunc + config config.DingTalkConfig + clientID string + clientSecret string + streamClient *client.StreamClient + ctx context.Context + cancel context.CancelFunc // Map to store session webhooks for each chat sessionWebhooks sync.Map // chatID -> sessionWebhook } @@ -108,9 +109,9 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID) } - logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{ - "chat_id": msg.ChatID, - "preview": utils.Truncate(msg.Content, 100), + logger.DebugCF("dingtalk", "Sending message", map[string]any{ + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), }) // Use the session webhook to send the reply @@ -120,12 +121,15 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err // onChatBotMessageReceived implements the IChatBotMessageHandler function signature // This is called by the Stream SDK when a new message arrives // IChatBotMessageHandler is: func(c context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) -func (c *DingTalkChannel) onChatBotMessageReceived(ctx 
context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) { +func (c *DingTalkChannel) onChatBotMessageReceived( + ctx context.Context, + data *chatbot.BotCallbackDataModel, +) ([]byte, error) { // Extract message content from Text field content := data.Text.Content if content == "" { // Try to extract from Content interface{} if Text is empty - if contentMap, ok := data.Content.(map[string]interface{}); ok { + if contentMap, ok := data.Content.(map[string]any); ok { if textContent, ok := contentMap["content"].(string); ok { content = textContent } @@ -155,7 +159,15 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch "session_webhook": data.SessionWebhook, } - logger.DebugCF("dingtalk", "Received message", map[string]interface{}{ + if data.ConversationType == "1" { + metadata["peer_kind"] = "direct" + metadata["peer_id"] = senderID + } else { + metadata["peer_kind"] = "group" + metadata["peer_id"] = data.ConversationId + } + + logger.DebugCF("dingtalk", "Received message", map[string]any{ "sender_nick": senderNick, "sender_id": senderID, "preview": utils.Truncate(content, 50), @@ -184,7 +196,6 @@ func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, c titleBytes, contentBytes, ) - if err != nil { return fmt.Errorf("failed to send reply: %w", err) } diff --git a/pkg/channels/discord.go b/pkg/channels/discord.go index e65c99eec..20f3b267c 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord.go @@ -4,9 +4,12 @@ import ( "context" "fmt" "os" + "strings" + "sync" "time" "github.com/bwmarrin/discordgo" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" @@ -25,6 +28,9 @@ type DiscordChannel struct { config config.DiscordConfig transcriber *voice.GroqTranscriber ctx context.Context + typingMu sync.Mutex + typingStop map[string]chan struct{} // chatID → stop signal + botUserID string // stored for mention checking } func 
NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { @@ -41,6 +47,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC config: cfg, transcriber: nil, ctx: context.Background(), + typingStop: make(map[string]chan struct{}), }, nil } @@ -59,6 +66,14 @@ func (c *DiscordChannel) Start(ctx context.Context) error { logger.InfoC("discord", "Starting Discord bot") c.ctx = ctx + + // Get bot user ID before opening session to avoid race condition + botUser, err := c.session.User("@me") + if err != nil { + return fmt.Errorf("failed to get bot user: %w", err) + } + c.botUserID = botUser.ID + c.session.AddHandler(c.handleMessage) if err := c.session.Open(); err != nil { @@ -67,10 +82,6 @@ func (c *DiscordChannel) Start(ctx context.Context) error { c.setRunning(true) - botUser, err := c.session.User("@me") - if err != nil { - return fmt.Errorf("failed to get bot user: %w", err) - } logger.InfoCF("discord", "Discord bot connected", map[string]any{ "username": botUser.Username, "user_id": botUser.ID, @@ -83,6 +94,14 @@ func (c *DiscordChannel) Stop(ctx context.Context) error { logger.InfoC("discord", "Stopping Discord bot") c.setRunning(false) + // Stop all typing goroutines before closing session + c.typingMu.Lock() + for chatID, stop := range c.typingStop { + close(stop) + delete(c.typingStop, chatID) + } + c.typingMu.Unlock() + if err := c.session.Close(); err != nil { return fmt.Errorf("failed to close discord session: %w", err) } @@ -91,6 +110,8 @@ func (c *DiscordChannel) Stop(ctx context.Context) error { } func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + c.stopTyping(msg.ChatID) + if !c.IsRunning() { return fmt.Errorf("discord bot not running") } @@ -100,15 +121,30 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro return fmt.Errorf("channel ID is empty") } - message := msg.Content + runes := []rune(msg.Content) + if len(runes) == 0 { 
+ return nil + } + + chunks := utils.SplitMessage(msg.Content, 2000) // Split messages into chunks, Discord length limit: 2000 chars + + for _, chunk := range chunks { + if err := c.sendChunk(ctx, channelID, chunk); err != nil { + return err + } + } + + return nil +} - // 使用传入的 ctx 进行超时控制 +func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error { + // Use the passed ctx for timeout control sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) defer cancel() done := make(chan error, 1) go func() { - _, err := c.session.ChannelMessageSend(channelID, message) + _, err := c.session.ChannelMessageSend(channelID, content) done <- err }() @@ -123,7 +159,7 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro } } -// appendContent 安全地追加内容到现有文本 +// appendContent safely appends content to existing text func appendContent(content, suffix string) string { if content == "" { return suffix @@ -140,7 +176,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag return } - // 检查白名单,避免为被拒绝的用户下载附件和转录 + // Check allowlist first to avoid downloading attachments and transcribing for rejected users if !c.IsAllowed(m.Author.ID) { logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ "user_id": m.Author.ID, @@ -148,6 +184,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag return } + // If configured to only respond to mentions, check if bot is mentioned + // Skip this check for DMs (GuildID is empty) - DMs should always be responded to + if c.config.MentionOnly && m.GuildID != "" { + isMentioned := false + for _, mention := range m.Mentions { + if mention.ID == c.botUserID { + isMentioned = true + break + } + } + if !isMentioned { + logger.DebugCF("discord", "Message ignored - bot not mentioned", map[string]any{ + "user_id": m.Author.ID, + }) + return + } + } + senderID := m.Author.ID senderName := m.Author.Username if 
m.Author.Discriminator != "" && m.Author.Discriminator != "0" { @@ -155,10 +209,11 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag } content := m.Content + content = c.stripBotMention(content) mediaPaths := make([]string, 0, len(m.Attachments)) localFiles := make([]string, 0, len(m.Attachments)) - // 确保临时文件在函数返回时被清理 + // Ensure temp files are cleaned up when function returns defer func() { for _, file := range localFiles { if err := os.Remove(file); err != nil { @@ -182,7 +237,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag if c.transcriber != nil && c.transcriber.IsAvailable() { ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) result, err := c.transcriber.Transcribe(ctx, localPath) - cancel() // 立即释放context资源,避免在for循环中泄漏 + cancel() // Release context resources immediately to avoid leaks in for loop if err != nil { logger.ErrorCF("discord", "Voice transcription failed", map[string]any{ @@ -222,12 +277,22 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag content = "[media only]" } + // Start typing after all early returns — guaranteed to have a matching Send() + c.startTyping(m.ChannelID) + logger.DebugCF("discord", "Received message", map[string]any{ "sender_name": senderName, "sender_id": senderID, "preview": utils.Truncate(content, 50), }) + peerKind := "channel" + peerID := m.ChannelID + if m.GuildID == "" { + peerKind = "direct" + peerID = senderID + } + metadata := map[string]string{ "message_id": m.ID, "user_id": senderID, @@ -236,13 +301,73 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag "guild_id": m.GuildID, "channel_id": m.ChannelID, "is_dm": fmt.Sprintf("%t", m.GuildID == ""), + "peer_kind": peerKind, + "peer_id": peerID, } c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) } +// startTyping starts a continuous typing indicator loop for the given chatID. 
+// It stops any existing typing loop for that chatID before starting a new one. +func (c *DiscordChannel) startTyping(chatID string) { + c.typingMu.Lock() + // Stop existing loop for this chatID if any + if stop, ok := c.typingStop[chatID]; ok { + close(stop) + } + stop := make(chan struct{}) + c.typingStop[chatID] = stop + c.typingMu.Unlock() + + go func() { + if err := c.session.ChannelTyping(chatID); err != nil { + logger.DebugCF("discord", "ChannelTyping error", map[string]any{"chatID": chatID, "err": err}) + } + ticker := time.NewTicker(8 * time.Second) + defer ticker.Stop() + timeout := time.After(5 * time.Minute) + for { + select { + case <-stop: + return + case <-timeout: + return + case <-c.ctx.Done(): + return + case <-ticker.C: + if err := c.session.ChannelTyping(chatID); err != nil { + logger.DebugCF("discord", "ChannelTyping error", map[string]any{"chatID": chatID, "err": err}) + } + } + } + }() +} + +// stopTyping stops the typing indicator loop for the given chatID. +func (c *DiscordChannel) stopTyping(chatID string) { + c.typingMu.Lock() + defer c.typingMu.Unlock() + if stop, ok := c.typingStop[chatID]; ok { + close(stop) + delete(c.typingStop, chatID) + } +} + func (c *DiscordChannel) downloadAttachment(url, filename string) string { return utils.DownloadFile(url, filename, utils.DownloadOptions{ LoggerPrefix: "discord", }) } + +// stripBotMention removes the bot mention from the message content. +// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname). 
+func (c *DiscordChannel) stripBotMention(text string) string { + if c.botUserID == "" { + return text + } + // Remove both regular mention <@USER_ID> and nickname mention <@!USER_ID> + text = strings.ReplaceAll(text, fmt.Sprintf("<@%s>", c.botUserID), "") + text = strings.ReplaceAll(text, fmt.Sprintf("<@!%s>", c.botUserID), "") + return strings.TrimSpace(text) +} diff --git a/pkg/channels/feishu_32.go b/pkg/channels/feishu_32.go new file mode 100644 index 000000000..5109b8195 --- /dev/null +++ b/pkg/channels/feishu_32.go @@ -0,0 +1,38 @@ +//go:build !amd64 && !arm64 && !riscv64 && !mips64 && !ppc64 + +package channels + +import ( + "context" + "errors" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +// FeishuChannel is a stub implementation for 32-bit architectures +type FeishuChannel struct { + *BaseChannel +} + +// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported +func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { + return nil, errors.New( + "feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). 
Please use a 64-bit system or disable feishu in your config", + ) +} + +// Start is a stub method to satisfy the Channel interface +func (c *FeishuChannel) Start(ctx context.Context) error { + return nil +} + +// Stop is a stub method to satisfy the Channel interface +func (c *FeishuChannel) Stop(ctx context.Context) error { + return nil +} + +// Send is a stub method to satisfy the Channel interface +func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + return errors.New("feishu channel is not supported on 32-bit architectures") +} diff --git a/pkg/channels/feishu.go b/pkg/channels/feishu_64.go similarity index 92% rename from pkg/channels/feishu.go rename to pkg/channels/feishu_64.go index 11dbd6748..42e74980f 100644 --- a/pkg/channels/feishu.go +++ b/pkg/channels/feishu_64.go @@ -1,3 +1,5 @@ +//go:build amd64 || arm64 || riscv64 || mips64 || ppc64 + package channels import ( @@ -63,7 +65,7 @@ func (c *FeishuChannel) Start(ctx context.Context) error { go func() { if err := wsClient.Start(runCtx); err != nil { - logger.ErrorCF("feishu", "Feishu websocket stopped with error", map[string]interface{}{ + logger.ErrorCF("feishu", "Feishu websocket stopped with error", map[string]any{ "error": err.Error(), }) } @@ -119,7 +121,7 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return fmt.Errorf("feishu api error: code=%d msg=%s", resp.Code, resp.Msg) } - logger.DebugCF("feishu", "Feishu message sent", map[string]interface{}{ + logger.DebugCF("feishu", "Feishu message sent", map[string]any{ "chat_id": msg.ChatID, }) @@ -163,7 +165,16 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 metadata["tenant_key"] = *sender.TenantKey } - logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{ + chatType := stringValue(message.ChatType) + if chatType == "p2p" { + metadata["peer_kind"] = "direct" + metadata["peer_id"] = senderID + } else { + metadata["peer_kind"] = 
"group" + metadata["peer_id"] = chatID + } + + logger.InfoCF("feishu", "Feishu message received", map[string]any{ "sender_id": senderID, "chat_id": chatID, "preview": utils.Truncate(content, 80), diff --git a/pkg/channels/line.go b/pkg/channels/line.go new file mode 100644 index 000000000..44134996f --- /dev/null +++ b/pkg/channels/line.go @@ -0,0 +1,606 @@ +package channels + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +const ( + lineAPIBase = "https://api.line.me/v2/bot" + lineDataAPIBase = "https://api-data.line.me/v2/bot" + lineReplyEndpoint = lineAPIBase + "/message/reply" + linePushEndpoint = lineAPIBase + "/message/push" + lineContentEndpoint = lineDataAPIBase + "/message/%s/content" + lineBotInfoEndpoint = lineAPIBase + "/info" + lineLoadingEndpoint = lineAPIBase + "/chat/loading/start" + lineReplyTokenMaxAge = 25 * time.Second +) + +type replyTokenEntry struct { + token string + timestamp time.Time +} + +// LINEChannel implements the Channel interface for LINE Official Account +// using the LINE Messaging API with HTTP webhook for receiving messages +// and REST API for sending messages. +type LINEChannel struct { + *BaseChannel + config config.LINEConfig + httpServer *http.Server + botUserID string // Bot's user ID + botBasicID string // Bot's basic ID (e.g. @216ru...) + botDisplayName string // Bot's display name for text-based mention detection + replyTokens sync.Map // chatID -> replyTokenEntry + quoteTokens sync.Map // chatID -> quoteToken (string) + ctx context.Context + cancel context.CancelFunc +} + +// NewLINEChannel creates a new LINE channel instance. 
+func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINEChannel, error) { + if cfg.ChannelSecret == "" || cfg.ChannelAccessToken == "" { + return nil, fmt.Errorf("line channel_secret and channel_access_token are required") + } + + base := NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom) + + return &LINEChannel{ + BaseChannel: base, + config: cfg, + }, nil +} + +// Start launches the HTTP webhook server. +func (c *LINEChannel) Start(ctx context.Context) error { + logger.InfoC("line", "Starting LINE channel (Webhook Mode)") + + c.ctx, c.cancel = context.WithCancel(ctx) + + // Fetch bot profile to get bot's userId for mention detection + if err := c.fetchBotInfo(); err != nil { + logger.WarnCF("line", "Failed to fetch bot info (mention detection disabled)", map[string]any{ + "error": err.Error(), + }) + } else { + logger.InfoCF("line", "Bot info fetched", map[string]any{ + "bot_user_id": c.botUserID, + "basic_id": c.botBasicID, + "display_name": c.botDisplayName, + }) + } + + mux := http.NewServeMux() + path := c.config.WebhookPath + if path == "" { + path = "/webhook/line" + } + mux.HandleFunc(path, c.webhookHandler) + + addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) + c.httpServer = &http.Server{ + Addr: addr, + Handler: mux, + } + + go func() { + logger.InfoCF("line", "LINE webhook server listening", map[string]any{ + "addr": addr, + "path": path, + }) + if err := c.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("line", "Webhook server error", map[string]any{ + "error": err.Error(), + }) + } + }() + + c.setRunning(true) + logger.InfoC("line", "LINE channel started (Webhook Mode)") + return nil +} + +// fetchBotInfo retrieves the bot's userId, basicId, and displayName from the LINE API. 
+func (c *LINEChannel) fetchBotInfo() error { + req, err := http.NewRequest(http.MethodGet, lineBotInfoEndpoint, nil) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("bot info API returned status %d", resp.StatusCode) + } + + var info struct { + UserID string `json:"userId"` + BasicID string `json:"basicId"` + DisplayName string `json:"displayName"` + } + if err := json.NewDecoder(resp.Body).Decode(&info); err != nil { + return err + } + + c.botUserID = info.UserID + c.botBasicID = info.BasicID + c.botDisplayName = info.DisplayName + return nil +} + +// Stop gracefully shuts down the HTTP server. +func (c *LINEChannel) Stop(ctx context.Context) error { + logger.InfoC("line", "Stopping LINE channel") + + if c.cancel != nil { + c.cancel() + } + + if c.httpServer != nil { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := c.httpServer.Shutdown(shutdownCtx); err != nil { + logger.ErrorCF("line", "Webhook server shutdown error", map[string]any{ + "error": err.Error(), + }) + } + } + + c.setRunning(false) + logger.InfoC("line", "LINE channel stopped") + return nil +} + +// webhookHandler handles incoming LINE webhook requests. 
+func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + logger.ErrorCF("line", "Failed to read request body", map[string]any{ + "error": err.Error(), + }) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + signature := r.Header.Get("X-Line-Signature") + if !c.verifySignature(body, signature) { + logger.WarnC("line", "Invalid webhook signature") + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + var payload struct { + Events []lineEvent `json:"events"` + } + if err := json.Unmarshal(body, &payload); err != nil { + logger.ErrorCF("line", "Failed to parse webhook payload", map[string]any{ + "error": err.Error(), + }) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Return 200 immediately, process events asynchronously + w.WriteHeader(http.StatusOK) + + for _, event := range payload.Events { + go c.processEvent(event) + } +} + +// verifySignature validates the X-Line-Signature using HMAC-SHA256. 
func (c *LINEChannel) verifySignature(body []byte, signature string) bool {
	// A missing header is always rejected.
	if signature == "" {
		return false
	}

	// LINE signs the raw request body with HMAC-SHA256 keyed by the channel
	// secret and sends the base64-encoded digest in X-Line-Signature.
	mac := hmac.New(sha256.New, []byte(c.config.ChannelSecret))
	mac.Write(body)
	expected := base64.StdEncoding.EncodeToString(mac.Sum(nil))

	// hmac.Equal compares in constant time, avoiding a timing side channel.
	return hmac.Equal([]byte(expected), []byte(signature))
}

// LINE webhook event types

// lineEvent is one entry of the webhook "events" array. Message is kept raw
// and decoded lazily, since only "message" events are processed.
type lineEvent struct {
	Type       string          `json:"type"`
	ReplyToken string          `json:"replyToken"`
	Source     lineSource      `json:"source"`
	Message    json.RawMessage `json:"message"`
	Timestamp  int64           `json:"timestamp"`
}

// lineSource identifies where an event originated; exactly one of the ID
// fields is populated depending on Type.
type lineSource struct {
	Type    string `json:"type"` // "user", "group", "room"
	UserID  string `json:"userId"`
	GroupID string `json:"groupId"`
	RoomID  string `json:"roomId"`
}

// lineMessage is the decoded "message" object of a message event.
type lineMessage struct {
	ID         string `json:"id"`
	Type       string `json:"type"` // "text", "image", "video", "audio", "file", "sticker"
	Text       string `json:"text"`
	QuoteToken string `json:"quoteToken"`
	// Mention is present when the message @-mentions someone.
	Mention *struct {
		Mentionees []lineMentionee `json:"mentionees"`
	} `json:"mention"`
	ContentProvider struct {
		Type string `json:"type"`
	} `json:"contentProvider"`
}

// lineMentionee describes one @-mention span inside the message text.
// NOTE(review): Index/Length are treated as rune offsets throughout this
// file — TODO confirm against the LINE API docs; if the API counts UTF-16
// code units, texts containing astral characters (emoji) would mis-slice.
type lineMentionee struct {
	Index  int    `json:"index"`
	Length int    `json:"length"`
	Type   string `json:"type"` // "user", "all"
	UserID string `json:"userId"`
}

// processEvent handles one webhook event: filters to message events, applies
// the group mention gate, caches reply/quote tokens, converts media to local
// temp files, and forwards the normalized message to the bus.
func (c *LINEChannel) processEvent(event lineEvent) {
	if event.Type != "message" {
		logger.DebugCF("line", "Ignoring non-message event", map[string]any{
			"type": event.Type,
		})
		return
	}

	senderID := event.Source.UserID
	chatID := c.resolveChatID(event.Source)
	isGroup := event.Source.Type == "group" || event.Source.Type == "room"

	var msg lineMessage
	if err := json.Unmarshal(event.Message, &msg); err != nil {
		logger.ErrorCF("line", "Failed to parse message", map[string]any{
			"error": err.Error(),
		})
		return
	}

	// In group chats, only respond when the bot is mentioned
	if isGroup && !c.isBotMentioned(msg) {
		logger.DebugCF("line", "Ignoring group message without mention", map[string]any{
			"chat_id": chatID,
		})
		return
	}

	// Store reply token for later use. Send() consumes it within
	// lineReplyTokenMaxAge to use the free Reply API.
	if event.ReplyToken != "" {
		c.replyTokens.Store(chatID, replyTokenEntry{
			token:     event.ReplyToken,
			timestamp: time.Now(),
		})
	}

	// Store quote token for quoting the original message in reply
	if msg.QuoteToken != "" {
		c.quoteTokens.Store(chatID, msg.QuoteToken)
	}

	var content string
	var mediaPaths []string
	localFiles := []string{}

	// Downloaded media is removed once HandleMessage returns; failures are
	// logged at debug level only (best-effort cleanup).
	defer func() {
		for _, file := range localFiles {
			if err := os.Remove(file); err != nil {
				logger.DebugCF("line", "Failed to cleanup temp file", map[string]any{
					"file":  file,
					"error": err.Error(),
				})
			}
		}
	}()

	// Map each message type to a text placeholder plus (for media) a local
	// file path. A failed download leaves content empty, so the message is
	// dropped by the blank check below.
	switch msg.Type {
	case "text":
		content = msg.Text
		// Strip bot mention from text in group chats
		if isGroup {
			content = c.stripBotMention(content, msg)
		}
	case "image":
		localPath := c.downloadContent(msg.ID, "image.jpg")
		if localPath != "" {
			localFiles = append(localFiles, localPath)
			mediaPaths = append(mediaPaths, localPath)
			content = "[image]"
		}
	case "audio":
		localPath := c.downloadContent(msg.ID, "audio.m4a")
		if localPath != "" {
			localFiles = append(localFiles, localPath)
			mediaPaths = append(mediaPaths, localPath)
			content = "[audio]"
		}
	case "video":
		localPath := c.downloadContent(msg.ID, "video.mp4")
		if localPath != "" {
			localFiles = append(localFiles, localPath)
			mediaPaths = append(mediaPaths, localPath)
			content = "[video]"
		}
	case "file":
		content = "[file]"
	case "sticker":
		content = "[sticker]"
	default:
		content = fmt.Sprintf("[%s]", msg.Type)
	}

	if strings.TrimSpace(content) == "" {
		return
	}

	metadata := map[string]string{
		"platform":    "line",
		"source_type": event.Source.Type,
		"message_id":  msg.ID,
	}

	if isGroup {
		metadata["peer_kind"] = "group"
		metadata["peer_id"] = chatID
	} else {
		metadata["peer_kind"] = "direct"
		metadata["peer_id"] = senderID
	}

	logger.DebugCF("line", "Received message", map[string]any{
		"sender_id":    senderID,
		"chat_id":      chatID,
		"message_type": msg.Type,
		"is_group":     isGroup,
		"preview":      utils.Truncate(content, 50),
	})

	// Show typing/loading indicator (requires user ID, not group ID)
	c.sendLoading(senderID)

	c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
}

// isBotMentioned checks if the bot is mentioned in the message.
// It first checks the mention metadata (userId match), then falls back
// to text-based detection using the bot's display name, since LINE may
// not include userId in mentionees for Official Accounts.
func (c *LINEChannel) isBotMentioned(msg lineMessage) bool {
	// Check mention metadata
	if msg.Mention != nil {
		for _, m := range msg.Mention.Mentionees {
			// "@All" mentions count as mentioning the bot.
			if m.Type == "all" {
				return true
			}
			if c.botUserID != "" && m.UserID == c.botUserID {
				return true
			}
		}
		// Mention metadata exists with mentionees but bot not matched by userId.
		// The bot IS likely mentioned (LINE includes mention struct when bot is @-ed),
		// so check if any mentionee overlaps with bot display name in text.
		if c.botDisplayName != "" {
			for _, m := range msg.Mention.Mentionees {
				if m.Index >= 0 && m.Length > 0 {
					// NOTE(review): assumes rune offsets — see lineMentionee.
					runes := []rune(msg.Text)
					end := m.Index + m.Length
					if end <= len(runes) {
						mentionText := string(runes[m.Index:end])
						if strings.Contains(mentionText, c.botDisplayName) {
							return true
						}
					}
				}
			}
		}
	}

	// Fallback: text-based detection with display name
	if c.botDisplayName != "" && strings.Contains(msg.Text, "@"+c.botDisplayName) {
		return true
	}

	return false
}

// stripBotMention removes the @BotName mention text from the message.
+func (c *LINEChannel) stripBotMention(text string, msg lineMessage) string { + stripped := false + + // Try to strip using mention metadata indices + if msg.Mention != nil { + runes := []rune(text) + for i := len(msg.Mention.Mentionees) - 1; i >= 0; i-- { + m := msg.Mention.Mentionees[i] + // Strip if userId matches OR if the mention text contains the bot display name + shouldStrip := false + if c.botUserID != "" && m.UserID == c.botUserID { + shouldStrip = true + } else if c.botDisplayName != "" && m.Index >= 0 && m.Length > 0 { + end := m.Index + m.Length + if end <= len(runes) { + mentionText := string(runes[m.Index:end]) + if strings.Contains(mentionText, c.botDisplayName) { + shouldStrip = true + } + } + } + if shouldStrip { + start := m.Index + end := m.Index + m.Length + if start >= 0 && end <= len(runes) { + runes = append(runes[:start], runes[end:]...) + stripped = true + } + } + } + if stripped { + return strings.TrimSpace(string(runes)) + } + } + + // Fallback: strip @DisplayName from text + if c.botDisplayName != "" { + text = strings.ReplaceAll(text, "@"+c.botDisplayName, "") + } + + return strings.TrimSpace(text) +} + +// resolveChatID determines the chat ID from the event source. +// For group/room messages, use the group/room ID; for 1:1, use the user ID. +func (c *LINEChannel) resolveChatID(source lineSource) string { + switch source.Type { + case "group": + return source.GroupID + case "room": + return source.RoomID + default: + return source.UserID + } +} + +// Send sends a message to LINE. It first tries the Reply API (free) +// using a cached reply token, then falls back to the Push API. 
func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
	if !c.IsRunning() {
		return fmt.Errorf("line channel not running")
	}

	// Load and consume quote token for this chat.
	// NOTE(review): the token is consumed (LoadAndDelete) even if the send
	// below ultimately fails, so a retry would go out unquoted.
	var quoteToken string
	if qt, ok := c.quoteTokens.LoadAndDelete(msg.ChatID); ok {
		quoteToken = qt.(string)
	}

	// Try reply token first (free, valid for ~25 seconds). The token is
	// single-use, so it is removed from the cache regardless of outcome;
	// an expired or failed token falls through to the Push API.
	if entry, ok := c.replyTokens.LoadAndDelete(msg.ChatID); ok {
		tokenEntry := entry.(replyTokenEntry)
		if time.Since(tokenEntry.timestamp) < lineReplyTokenMaxAge {
			if err := c.sendReply(ctx, tokenEntry.token, msg.Content, quoteToken); err == nil {
				logger.DebugCF("line", "Message sent via Reply API", map[string]any{
					"chat_id": msg.ChatID,
					"quoted":  quoteToken != "",
				})
				return nil
			}
			logger.DebugC("line", "Reply API failed, falling back to Push API")
		}
	}

	// Fall back to Push API
	return c.sendPush(ctx, msg.ChatID, msg.Content, quoteToken)
}

// buildTextMessage creates a text message object, optionally with quoteToken.
func buildTextMessage(content, quoteToken string) map[string]string {
	msg := map[string]string{
		"type": "text",
		"text": content,
	}
	if quoteToken != "" {
		msg["quoteToken"] = quoteToken
	}
	return msg
}

// sendReply sends a message using the LINE Reply API.
func (c *LINEChannel) sendReply(ctx context.Context, replyToken, content, quoteToken string) error {
	payload := map[string]any{
		"replyToken": replyToken,
		"messages":   []map[string]string{buildTextMessage(content, quoteToken)},
	}

	return c.callAPI(ctx, lineReplyEndpoint, payload)
}

// sendPush sends a message using the LINE Push API.
func (c *LINEChannel) sendPush(ctx context.Context, to, content, quoteToken string) error {
	payload := map[string]any{
		"to":       to,
		"messages": []map[string]string{buildTextMessage(content, quoteToken)},
	}

	return c.callAPI(ctx, linePushEndpoint, payload)
}

// sendLoading sends a loading animation indicator to the chat.
+func (c *LINEChannel) sendLoading(chatID string) { + payload := map[string]any{ + "chatId": chatID, + "loadingSeconds": 60, + } + if err := c.callAPI(c.ctx, lineLoadingEndpoint, payload); err != nil { + logger.DebugCF("line", "Failed to send loading indicator", map[string]any{ + "error": err.Error(), + }) + } +} + +// callAPI makes an authenticated POST request to the LINE API. +func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any) error { + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("API request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("LINE API error (status %d): %s", resp.StatusCode, string(respBody)) + } + + return nil +} + +// downloadContent downloads media content from the LINE API. 
+func (c *LINEChannel) downloadContent(messageID, filename string) string { + url := fmt.Sprintf(lineContentEndpoint, messageID) + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "line", + ExtraHeaders: map[string]string{ + "Authorization": "Bearer " + c.config.ChannelAccessToken, + }, + }) +} diff --git a/pkg/channels/maixcam.go b/pkg/channels/maixcam.go index 5fc19adbe..34ce62b20 100644 --- a/pkg/channels/maixcam.go +++ b/pkg/channels/maixcam.go @@ -18,14 +18,13 @@ type MaixCamChannel struct { listener net.Listener clients map[net.Conn]bool clientsMux sync.RWMutex - running bool } type MaixCamMessage struct { - Type string `json:"type"` - Tips string `json:"tips"` - Timestamp float64 `json:"timestamp"` - Data map[string]interface{} `json:"data"` + Type string `json:"type"` + Tips string `json:"tips"` + Timestamp float64 `json:"timestamp"` + Data map[string]any `json:"data"` } func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamChannel, error) { @@ -35,7 +34,6 @@ func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamC BaseChannel: base, config: cfg, clients: make(map[net.Conn]bool), - running: false, }, nil } @@ -51,7 +49,7 @@ func (c *MaixCamChannel) Start(ctx context.Context) error { c.listener = listener c.setRunning(true) - logger.InfoCF("maixcam", "MaixCam server listening", map[string]interface{}{ + logger.InfoCF("maixcam", "MaixCam server listening", map[string]any{ "host": c.config.Host, "port": c.config.Port, }) @@ -73,14 +71,14 @@ func (c *MaixCamChannel) acceptConnections(ctx context.Context) { conn, err := c.listener.Accept() if err != nil { if c.running { - logger.ErrorCF("maixcam", "Failed to accept connection", map[string]interface{}{ + logger.ErrorCF("maixcam", "Failed to accept connection", map[string]any{ "error": err.Error(), }) } return } - logger.InfoCF("maixcam", "New connection from MaixCam device", map[string]interface{}{ + logger.InfoCF("maixcam", "New 
connection from MaixCam device", map[string]any{ "remote_addr": conn.RemoteAddr().String(), }) @@ -114,7 +112,7 @@ func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) { var msg MaixCamMessage if err := decoder.Decode(&msg); err != nil { if err.Error() != "EOF" { - logger.ErrorCF("maixcam", "Failed to decode message", map[string]interface{}{ + logger.ErrorCF("maixcam", "Failed to decode message", map[string]any{ "error": err.Error(), }) } @@ -135,14 +133,14 @@ func (c *MaixCamChannel) processMessage(msg MaixCamMessage, conn net.Conn) { case "status": c.handleStatusUpdate(msg) default: - logger.WarnCF("maixcam", "Unknown message type", map[string]interface{}{ + logger.WarnCF("maixcam", "Unknown message type", map[string]any{ "type": msg.Type, }) } } func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { - logger.InfoCF("maixcam", "", map[string]interface{}{ + logger.InfoCF("maixcam", "", map[string]any{ "timestamp": msg.Timestamp, "data": msg.Data, }) @@ -172,13 +170,15 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { "y": fmt.Sprintf("%.0f", y), "w": fmt.Sprintf("%.0f", w), "h": fmt.Sprintf("%.0f", h), + "peer_kind": "channel", + "peer_id": "default", } c.HandleMessage(senderID, chatID, content, []string{}, metadata) } func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) { - logger.InfoCF("maixcam", "Status update from MaixCam", map[string]interface{}{ + logger.InfoCF("maixcam", "Status update from MaixCam", map[string]any{ "status": msg.Data, }) } @@ -216,7 +216,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro return fmt.Errorf("no connected MaixCam devices") } - response := map[string]interface{}{ + response := map[string]any{ "type": "command", "timestamp": float64(0), "message": msg.Content, @@ -231,7 +231,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro var sendErr error for conn := range c.clients { if _, err := 
conn.Write(data); err != nil { - logger.ErrorCF("maixcam", "Failed to send to client", map[string]interface{}{ + logger.ErrorCF("maixcam", "Failed to send to client", map[string]any{ "client": conn.RemoteAddr().String(), "error": err.Error(), }) diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 772551a4e..75edaf49e 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -48,9 +48,9 @@ func (m *Manager) initChannels() error { if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" { logger.DebugC("channels", "Attempting to initialize Telegram channel") - telegram, err := NewTelegramChannel(m.config.Channels.Telegram, m.bus) + telegram, err := NewTelegramChannel(m.config, m.bus) if err != nil { - logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]interface{}{ + logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]any{ "error": err.Error(), }) } else { @@ -63,7 +63,7 @@ func (m *Manager) initChannels() error { logger.DebugC("channels", "Attempting to initialize WhatsApp channel") whatsapp, err := NewWhatsAppChannel(m.config.Channels.WhatsApp, m.bus) if err != nil { - logger.ErrorCF("channels", "Failed to initialize WhatsApp channel", map[string]interface{}{ + logger.ErrorCF("channels", "Failed to initialize WhatsApp channel", map[string]any{ "error": err.Error(), }) } else { @@ -76,7 +76,7 @@ func (m *Manager) initChannels() error { logger.DebugC("channels", "Attempting to initialize Feishu channel") feishu, err := NewFeishuChannel(m.config.Channels.Feishu, m.bus) if err != nil { - logger.ErrorCF("channels", "Failed to initialize Feishu channel", map[string]interface{}{ + logger.ErrorCF("channels", "Failed to initialize Feishu channel", map[string]any{ "error": err.Error(), }) } else { @@ -89,7 +89,7 @@ func (m *Manager) initChannels() error { logger.DebugC("channels", "Attempting to initialize Discord channel") discord, err := 
NewDiscordChannel(m.config.Channels.Discord, m.bus) if err != nil { - logger.ErrorCF("channels", "Failed to initialize Discord channel", map[string]interface{}{ + logger.ErrorCF("channels", "Failed to initialize Discord channel", map[string]any{ "error": err.Error(), }) } else { @@ -102,7 +102,7 @@ func (m *Manager) initChannels() error { logger.DebugC("channels", "Attempting to initialize MaixCam channel") maixcam, err := NewMaixCamChannel(m.config.Channels.MaixCam, m.bus) if err != nil { - logger.ErrorCF("channels", "Failed to initialize MaixCam channel", map[string]interface{}{ + logger.ErrorCF("channels", "Failed to initialize MaixCam channel", map[string]any{ "error": err.Error(), }) } else { @@ -115,7 +115,7 @@ func (m *Manager) initChannels() error { logger.DebugC("channels", "Attempting to initialize QQ channel") qq, err := NewQQChannel(m.config.Channels.QQ, m.bus) if err != nil { - logger.ErrorCF("channels", "Failed to initialize QQ channel", map[string]interface{}{ + logger.ErrorCF("channels", "Failed to initialize QQ channel", map[string]any{ "error": err.Error(), }) } else { @@ -128,7 +128,7 @@ func (m *Manager) initChannels() error { logger.DebugC("channels", "Attempting to initialize DingTalk channel") dingtalk, err := NewDingTalkChannel(m.config.Channels.DingTalk, m.bus) if err != nil { - logger.ErrorCF("channels", "Failed to initialize DingTalk channel", map[string]interface{}{ + logger.ErrorCF("channels", "Failed to initialize DingTalk channel", map[string]any{ "error": err.Error(), }) } else { @@ -141,7 +141,7 @@ func (m *Manager) initChannels() error { logger.DebugC("channels", "Attempting to initialize Slack channel") slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus) if err != nil { - logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]interface{}{ + logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]any{ "error": err.Error(), }) } else { @@ -150,7 +150,59 @@ func (m *Manager) 
initChannels() error { } } - logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{ + if m.config.Channels.LINE.Enabled && m.config.Channels.LINE.ChannelAccessToken != "" { + logger.DebugC("channels", "Attempting to initialize LINE channel") + line, err := NewLINEChannel(m.config.Channels.LINE, m.bus) + if err != nil { + logger.ErrorCF("channels", "Failed to initialize LINE channel", map[string]any{ + "error": err.Error(), + }) + } else { + m.channels["line"] = line + logger.InfoC("channels", "LINE channel enabled successfully") + } + } + + if m.config.Channels.OneBot.Enabled && m.config.Channels.OneBot.WSUrl != "" { + logger.DebugC("channels", "Attempting to initialize OneBot channel") + onebot, err := NewOneBotChannel(m.config.Channels.OneBot, m.bus) + if err != nil { + logger.ErrorCF("channels", "Failed to initialize OneBot channel", map[string]any{ + "error": err.Error(), + }) + } else { + m.channels["onebot"] = onebot + logger.InfoC("channels", "OneBot channel enabled successfully") + } + } + + if m.config.Channels.WeCom.Enabled && m.config.Channels.WeCom.Token != "" { + logger.DebugC("channels", "Attempting to initialize WeCom channel") + wecom, err := NewWeComBotChannel(m.config.Channels.WeCom, m.bus) + if err != nil { + logger.ErrorCF("channels", "Failed to initialize WeCom channel", map[string]any{ + "error": err.Error(), + }) + } else { + m.channels["wecom"] = wecom + logger.InfoC("channels", "WeCom channel enabled successfully") + } + } + + if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" { + logger.DebugC("channels", "Attempting to initialize WeCom App channel") + wecomApp, err := NewWeComAppChannel(m.config.Channels.WeComApp, m.bus) + if err != nil { + logger.ErrorCF("channels", "Failed to initialize WeCom App channel", map[string]any{ + "error": err.Error(), + }) + } else { + m.channels["wecom_app"] = wecomApp + logger.InfoC("channels", "WeCom App channel enabled successfully") + } + } + + 
logger.InfoCF("channels", "Channel initialization completed", map[string]any{ "enabled_channels": len(m.channels), }) @@ -174,11 +226,11 @@ func (m *Manager) StartAll(ctx context.Context) error { go m.dispatchOutbound(dispatchCtx) for name, channel := range m.channels { - logger.InfoCF("channels", "Starting channel", map[string]interface{}{ + logger.InfoCF("channels", "Starting channel", map[string]any{ "channel": name, }) if err := channel.Start(ctx); err != nil { - logger.ErrorCF("channels", "Failed to start channel", map[string]interface{}{ + logger.ErrorCF("channels", "Failed to start channel", map[string]any{ "channel": name, "error": err.Error(), }) @@ -201,11 +253,11 @@ func (m *Manager) StopAll(ctx context.Context) error { } for name, channel := range m.channels { - logger.InfoCF("channels", "Stopping channel", map[string]interface{}{ + logger.InfoCF("channels", "Stopping channel", map[string]any{ "channel": name, }) if err := channel.Stop(ctx); err != nil { - logger.ErrorCF("channels", "Error stopping channel", map[string]interface{}{ + logger.ErrorCF("channels", "Error stopping channel", map[string]any{ "channel": name, "error": err.Error(), }) @@ -240,14 +292,14 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { m.mu.RUnlock() if !exists { - logger.WarnCF("channels", "Unknown channel for outbound message", map[string]interface{}{ + logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{ "channel": msg.Channel, }) continue } if err := channel.Send(ctx, msg); err != nil { - logger.ErrorCF("channels", "Error sending message to channel", map[string]interface{}{ + logger.ErrorCF("channels", "Error sending message to channel", map[string]any{ "channel": msg.Channel, "error": err.Error(), }) @@ -263,13 +315,13 @@ func (m *Manager) GetChannel(name string) (Channel, bool) { return channel, ok } -func (m *Manager) GetStatus() map[string]interface{} { +func (m *Manager) GetStatus() map[string]any { m.mu.RLock() defer 
m.mu.RUnlock() - status := make(map[string]interface{}) + status := make(map[string]any) for name, channel := range m.channels { - status[name] = map[string]interface{}{ + status[name] = map[string]any{ "enabled": true, "running": channel.IsRunning(), } diff --git a/pkg/channels/onebot.go b/pkg/channels/onebot.go new file mode 100644 index 000000000..cee8ad9d3 --- /dev/null +++ b/pkg/channels/onebot.go @@ -0,0 +1,982 @@ +package channels + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" +) + +type OneBotChannel struct { + *BaseChannel + config config.OneBotConfig + conn *websocket.Conn + ctx context.Context + cancel context.CancelFunc + dedup map[string]struct{} + dedupRing []string + dedupIdx int + mu sync.Mutex + writeMu sync.Mutex + echoCounter int64 + selfID int64 + pending map[string]chan json.RawMessage + pendingMu sync.Mutex + transcriber *voice.GroqTranscriber + lastMessageID sync.Map + pendingEmojiMsg sync.Map +} + +type oneBotRawEvent struct { + PostType string `json:"post_type"` + MessageType string `json:"message_type"` + SubType string `json:"sub_type"` + MessageID json.RawMessage `json:"message_id"` + UserID json.RawMessage `json:"user_id"` + GroupID json.RawMessage `json:"group_id"` + RawMessage string `json:"raw_message"` + Message json.RawMessage `json:"message"` + Sender json.RawMessage `json:"sender"` + SelfID json.RawMessage `json:"self_id"` + Time json.RawMessage `json:"time"` + MetaEventType string `json:"meta_event_type"` + NoticeType string `json:"notice_type"` + Echo string `json:"echo"` + RetCode json.RawMessage `json:"retcode"` + Status json.RawMessage `json:"status"` + Data json.RawMessage `json:"data"` +} + +type BotStatus struct { + 
Online bool `json:"online"` + Good bool `json:"good"` +} + +func isAPIResponse(raw json.RawMessage) bool { + if len(raw) == 0 { + return false + } + var s string + if json.Unmarshal(raw, &s) == nil { + return s == "ok" || s == "failed" + } + var bs BotStatus + if json.Unmarshal(raw, &bs) == nil { + return bs.Online || bs.Good + } + return false +} + +type oneBotSender struct { + UserID json.RawMessage `json:"user_id"` + Nickname string `json:"nickname"` + Card string `json:"card"` +} + +type oneBotAPIRequest struct { + Action string `json:"action"` + Params any `json:"params"` + Echo string `json:"echo,omitempty"` +} + +type oneBotMessageSegment struct { + Type string `json:"type"` + Data map[string]any `json:"data"` +} + +func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) { + base := NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom) + + const dedupSize = 1024 + return &OneBotChannel{ + BaseChannel: base, + config: cfg, + dedup: make(map[string]struct{}, dedupSize), + dedupRing: make([]string, dedupSize), + dedupIdx: 0, + pending: make(map[string]chan json.RawMessage), + }, nil +} + +func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { + c.transcriber = transcriber +} + +func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) { + go func() { + _, err := c.sendAPIRequest("set_msg_emoji_like", map[string]any{ + "message_id": messageID, + "emoji_id": emojiID, + "set": set, + }, 5*time.Second) + if err != nil { + logger.DebugCF("onebot", "Failed to set emoji like", map[string]any{ + "message_id": messageID, + "error": err.Error(), + }) + } + }() +} + +func (c *OneBotChannel) Start(ctx context.Context) error { + if c.config.WSUrl == "" { + return fmt.Errorf("OneBot ws_url not configured") + } + + logger.InfoCF("onebot", "Starting OneBot channel", map[string]any{ + "ws_url": c.config.WSUrl, + }) + + c.ctx, c.cancel = context.WithCancel(ctx) + + if err := c.connect(); 
err != nil { + logger.WarnCF("onebot", "Initial connection failed, will retry in background", map[string]any{ + "error": err.Error(), + }) + } else { + go c.listen() + c.fetchSelfID() + } + + if c.config.ReconnectInterval > 0 { + go c.reconnectLoop() + } else { + if c.conn == nil { + return fmt.Errorf("failed to connect to OneBot and reconnect is disabled") + } + } + + c.setRunning(true) + logger.InfoC("onebot", "OneBot channel started successfully") + + return nil +} + +func (c *OneBotChannel) connect() error { + dialer := websocket.DefaultDialer + dialer.HandshakeTimeout = 10 * time.Second + + header := make(map[string][]string) + if c.config.AccessToken != "" { + header["Authorization"] = []string{"Bearer " + c.config.AccessToken} + } + + conn, _, err := dialer.Dial(c.config.WSUrl, header) + if err != nil { + return err + } + + conn.SetPongHandler(func(appData string) error { + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + + c.mu.Lock() + c.conn = conn + c.mu.Unlock() + + go c.pinger(conn) + + logger.InfoC("onebot", "WebSocket connected") + return nil +} + +func (c *OneBotChannel) pinger(conn *websocket.Conn) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + c.writeMu.Lock() + err := conn.WriteMessage(websocket.PingMessage, nil) + c.writeMu.Unlock() + if err != nil { + logger.DebugCF("onebot", "Ping write failed, stopping pinger", map[string]any{ + "error": err.Error(), + }) + return + } + } + } +} + +func (c *OneBotChannel) fetchSelfID() { + resp, err := c.sendAPIRequest("get_login_info", nil, 5*time.Second) + if err != nil { + logger.WarnCF("onebot", "Failed to get_login_info", map[string]any{ + "error": err.Error(), + }) + return + } + + type loginInfo struct { + UserID json.RawMessage `json:"user_id"` + Nickname string `json:"nickname"` + } + for _, extract := range []func() 
(*loginInfo, error){ + func() (*loginInfo, error) { + var w struct { + Data loginInfo `json:"data"` + } + err := json.Unmarshal(resp, &w) + return &w.Data, err + }, + func() (*loginInfo, error) { + var f loginInfo + err := json.Unmarshal(resp, &f) + return &f, err + }, + } { + info, err := extract() + if err != nil || len(info.UserID) == 0 { + continue + } + if uid, err := parseJSONInt64(info.UserID); err == nil && uid > 0 { + atomic.StoreInt64(&c.selfID, uid) + logger.InfoCF("onebot", "Bot self ID retrieved", map[string]any{ + "self_id": uid, + "nickname": info.Nickname, + }) + return + } + } + + logger.WarnCF("onebot", "Could not parse self ID from get_login_info response", map[string]any{ + "response": string(resp), + }) +} + +func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.Duration) (json.RawMessage, error) { + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + return nil, fmt.Errorf("WebSocket not connected") + } + + echo := fmt.Sprintf("api_%d_%d", time.Now().UnixNano(), atomic.AddInt64(&c.echoCounter, 1)) + + ch := make(chan json.RawMessage, 1) + c.pendingMu.Lock() + c.pending[echo] = ch + c.pendingMu.Unlock() + + defer func() { + c.pendingMu.Lock() + delete(c.pending, echo) + c.pendingMu.Unlock() + }() + + req := oneBotAPIRequest{ + Action: action, + Params: params, + Echo: echo, + } + + data, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal API request: %w", err) + } + + c.writeMu.Lock() + err = conn.WriteMessage(websocket.TextMessage, data) + c.writeMu.Unlock() + + if err != nil { + return nil, fmt.Errorf("failed to write API request: %w", err) + } + + select { + case resp := <-ch: + return resp, nil + case <-time.After(timeout): + return nil, fmt.Errorf("API request %s timed out after %v", action, timeout) + case <-c.ctx.Done(): + return nil, fmt.Errorf("context cancelled") + } +} + +func (c *OneBotChannel) reconnectLoop() { + interval := 
time.Duration(c.config.ReconnectInterval) * time.Second + if interval < 5*time.Second { + interval = 5 * time.Second + } + + for { + select { + case <-c.ctx.Done(): + return + case <-time.After(interval): + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + logger.InfoC("onebot", "Attempting to reconnect...") + if err := c.connect(); err != nil { + logger.ErrorCF("onebot", "Reconnect failed", map[string]any{ + "error": err.Error(), + }) + } else { + go c.listen() + c.fetchSelfID() + } + } + } + } +} + +func (c *OneBotChannel) Stop(ctx context.Context) error { + logger.InfoC("onebot", "Stopping OneBot channel") + c.setRunning(false) + + if c.cancel != nil { + c.cancel() + } + + c.pendingMu.Lock() + for echo, ch := range c.pending { + close(ch) + delete(c.pending, echo) + } + c.pendingMu.Unlock() + + c.mu.Lock() + if c.conn != nil { + c.conn.Close() + c.conn = nil + } + c.mu.Unlock() + + return nil +} + +func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("OneBot channel not running") + } + + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + return fmt.Errorf("OneBot WebSocket not connected") + } + + action, params, err := c.buildSendRequest(msg) + if err != nil { + return err + } + + echo := fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1)) + + req := oneBotAPIRequest{ + Action: action, + Params: params, + Echo: echo, + } + + data, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("failed to marshal OneBot request: %w", err) + } + + c.writeMu.Lock() + err = conn.WriteMessage(websocket.TextMessage, data) + c.writeMu.Unlock() + + if err != nil { + logger.ErrorCF("onebot", "Failed to send message", map[string]any{ + "error": err.Error(), + }) + return err + } + + if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok { + if mid, ok := msgID.(string); ok && mid != "" { + c.setMsgEmojiLike(mid, 289, false) + } + } + + return nil +} + 
+func (c *OneBotChannel) buildMessageSegments(chatID, content string) []oneBotMessageSegment { + var segments []oneBotMessageSegment + + if lastMsgID, ok := c.lastMessageID.Load(chatID); ok { + if msgID, ok := lastMsgID.(string); ok && msgID != "" { + segments = append(segments, oneBotMessageSegment{ + Type: "reply", + Data: map[string]any{"id": msgID}, + }) + } + } + + segments = append(segments, oneBotMessageSegment{ + Type: "text", + Data: map[string]any{"text": content}, + }) + + return segments +} + +func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, any, error) { + chatID := msg.ChatID + segments := c.buildMessageSegments(chatID, msg.Content) + + var action, idKey string + var rawID string + if rest, ok := strings.CutPrefix(chatID, "group:"); ok { + action, idKey, rawID = "send_group_msg", "group_id", rest + } else if rest, ok := strings.CutPrefix(chatID, "private:"); ok { + action, idKey, rawID = "send_private_msg", "user_id", rest + } else { + action, idKey, rawID = "send_private_msg", "user_id", chatID + } + + id, err := strconv.ParseInt(rawID, 10, 64) + if err != nil { + return "", nil, fmt.Errorf("invalid %s in chatID: %s", idKey, chatID) + } + return action, map[string]any{idKey: id, "message": segments}, nil +} + +func (c *OneBotChannel) listen() { + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + logger.WarnC("onebot", "WebSocket connection is nil, listener exiting") + return + } + + for { + select { + case <-c.ctx.Done(): + return + default: + _, message, err := conn.ReadMessage() + if err != nil { + logger.ErrorCF("onebot", "WebSocket read error", map[string]any{ + "error": err.Error(), + }) + c.mu.Lock() + if c.conn == conn { + c.conn.Close() + c.conn = nil + } + c.mu.Unlock() + return + } + + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + + var raw oneBotRawEvent + if err := json.Unmarshal(message, &raw); err != nil { + logger.WarnCF("onebot", "Failed to unmarshal raw event", 
map[string]any{ + "error": err.Error(), + "payload": string(message), + }) + continue + } + + logger.DebugCF("onebot", "WebSocket event", map[string]any{ + "length": len(message), + "post_type": raw.PostType, + "sub_type": raw.SubType, + }) + + if raw.Echo != "" { + c.pendingMu.Lock() + ch, ok := c.pending[raw.Echo] + c.pendingMu.Unlock() + + if ok { + select { + case ch <- message: + default: + } + } else { + logger.DebugCF("onebot", "Received API response (no waiter)", map[string]any{ + "echo": raw.Echo, + "status": string(raw.Status), + }) + } + continue + } + + if isAPIResponse(raw.Status) { + logger.DebugCF("onebot", "Received API response without echo, skipping", map[string]any{ + "status": string(raw.Status), + }) + continue + } + + c.handleRawEvent(&raw) + } + } +} + +func parseJSONInt64(raw json.RawMessage) (int64, error) { + if len(raw) == 0 { + return 0, nil + } + + var n int64 + if err := json.Unmarshal(raw, &n); err == nil { + return n, nil + } + + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return strconv.ParseInt(s, 10, 64) + } + return 0, fmt.Errorf("cannot parse as int64: %s", string(raw)) +} + +func parseJSONString(raw json.RawMessage) string { + if len(raw) == 0 { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s + } + + return string(raw) +} + +type parseMessageResult struct { + Text string + IsBotMentioned bool + Media []string + LocalFiles []string + ReplyTo string +} + +func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) parseMessageResult { + if len(raw) == 0 { + return parseMessageResult{} + } + + var s string + if err := json.Unmarshal(raw, &s); err == nil { + mentioned := false + if selfID > 0 { + cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID) + if strings.Contains(s, cqAt) { + mentioned = true + s = strings.ReplaceAll(s, cqAt, "") + s = strings.TrimSpace(s) + } + } + return parseMessageResult{Text: s, IsBotMentioned: mentioned} + } + + var segments 
[]map[string]any + if err := json.Unmarshal(raw, &segments); err != nil { + return parseMessageResult{} + } + + var textParts []string + mentioned := false + selfIDStr := strconv.FormatInt(selfID, 10) + var media []string + var localFiles []string + var replyTo string + + for _, seg := range segments { + segType, _ := seg["type"].(string) + data, _ := seg["data"].(map[string]any) + + switch segType { + case "text": + if data != nil { + if t, ok := data["text"].(string); ok { + textParts = append(textParts, t) + } + } + + case "at": + if data != nil && selfID > 0 { + qqVal := fmt.Sprintf("%v", data["qq"]) + if qqVal == selfIDStr || qqVal == "all" { + mentioned = true + } + } + + case "image", "video", "file": + if data != nil { + url, _ := data["url"].(string) + if url != "" { + defaults := map[string]string{"image": "image.jpg", "video": "video.mp4", "file": "file"} + filename := defaults[segType] + if f, ok := data["file"].(string); ok && f != "" { + filename = f + } else if n, ok := data["name"].(string); ok && n != "" { + filename = n + } + localPath := utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "onebot", + }) + if localPath != "" { + media = append(media, localPath) + localFiles = append(localFiles, localPath) + textParts = append(textParts, fmt.Sprintf("[%s]", segType)) + } + } + } + + case "record": + if data != nil { + url, _ := data["url"].(string) + if url != "" { + localPath := utils.DownloadFile(url, "voice.amr", utils.DownloadOptions{ + LoggerPrefix: "onebot", + }) + if localPath != "" { + localFiles = append(localFiles, localPath) + if c.transcriber != nil && c.transcriber.IsAvailable() { + tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second) + result, err := c.transcriber.Transcribe(tctx, localPath) + tcancel() + if err != nil { + logger.WarnCF("onebot", "Voice transcription failed", map[string]any{ + "error": err.Error(), + }) + textParts = append(textParts, "[voice (transcription failed)]") + media = 
append(media, localPath) + } else { + textParts = append(textParts, fmt.Sprintf("[voice transcription: %s]", result.Text)) + } + } else { + textParts = append(textParts, "[voice]") + media = append(media, localPath) + } + } + } + } + + case "reply": + if data != nil { + if id, ok := data["id"]; ok { + replyTo = fmt.Sprintf("%v", id) + } + } + + case "face": + if data != nil { + faceID, _ := data["id"] + textParts = append(textParts, fmt.Sprintf("[face:%v]", faceID)) + } + + case "forward": + textParts = append(textParts, "[forward message]") + + default: + + } + } + + return parseMessageResult{ + Text: strings.TrimSpace(strings.Join(textParts, "")), + IsBotMentioned: mentioned, + Media: media, + LocalFiles: localFiles, + ReplyTo: replyTo, + } +} + +func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) { + switch raw.PostType { + case "message": + if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 { + if !c.IsAllowed(strconv.FormatInt(userID, 10)) { + logger.DebugCF("onebot", "Message rejected by allowlist", map[string]any{ + "user_id": userID, + }) + return + } + } + c.handleMessage(raw) + + case "message_sent": + logger.DebugCF("onebot", "Bot sent message event", map[string]any{ + "message_type": raw.MessageType, + "message_id": parseJSONString(raw.MessageID), + }) + + case "meta_event": + c.handleMetaEvent(raw) + + case "notice": + c.handleNoticeEvent(raw) + + case "request": + logger.DebugCF("onebot", "Request event received", map[string]any{ + "sub_type": raw.SubType, + }) + + case "": + logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]any{ + "echo": raw.Echo, + "status": raw.Status, + }) + + default: + logger.DebugCF("onebot", "Unknown post_type", map[string]any{ + "post_type": raw.PostType, + }) + } +} + +func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) { + if raw.MetaEventType == "lifecycle" { + logger.InfoCF("onebot", "Lifecycle event", map[string]any{"sub_type": raw.SubType}) + 
} else if raw.MetaEventType != "heartbeat" { + logger.DebugCF("onebot", "Meta event: "+raw.MetaEventType, nil) + } +} + +func (c *OneBotChannel) handleNoticeEvent(raw *oneBotRawEvent) { + fields := map[string]any{ + "notice_type": raw.NoticeType, + "sub_type": raw.SubType, + "group_id": parseJSONString(raw.GroupID), + "user_id": parseJSONString(raw.UserID), + "message_id": parseJSONString(raw.MessageID), + } + switch raw.NoticeType { + case "group_recall", "group_increase", "group_decrease", + "friend_add", "group_admin", "group_ban": + logger.InfoCF("onebot", "Notice: "+raw.NoticeType, fields) + default: + logger.DebugCF("onebot", "Notice: "+raw.NoticeType, fields) + } +} + +func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { + // Parse fields from raw event + userID, err := parseJSONInt64(raw.UserID) + if err != nil { + logger.WarnCF("onebot", "Failed to parse user_id", map[string]any{ + "error": err.Error(), + "raw": string(raw.UserID), + }) + return + } + + groupID, _ := parseJSONInt64(raw.GroupID) + selfID, _ := parseJSONInt64(raw.SelfID) + messageID := parseJSONString(raw.MessageID) + + if selfID == 0 { + selfID = atomic.LoadInt64(&c.selfID) + } + + parsed := c.parseMessageSegments(raw.Message, selfID) + isBotMentioned := parsed.IsBotMentioned + + content := raw.RawMessage + if content == "" { + content = parsed.Text + } else if selfID > 0 { + cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID) + if strings.Contains(content, cqAt) { + isBotMentioned = true + content = strings.ReplaceAll(content, cqAt, "") + content = strings.TrimSpace(content) + } + } + + if parsed.Text != "" && content != parsed.Text && (len(parsed.Media) > 0 || parsed.ReplyTo != "") { + content = parsed.Text + } + + var sender oneBotSender + if len(raw.Sender) > 0 { + if err := json.Unmarshal(raw.Sender, &sender); err != nil { + logger.WarnCF("onebot", "Failed to parse sender", map[string]any{ + "error": err.Error(), + "sender": string(raw.Sender), + }) + } + } + + // Clean up temp files 
when done + if len(parsed.LocalFiles) > 0 { + defer func() { + for _, f := range parsed.LocalFiles { + if err := os.Remove(f); err != nil { + logger.DebugCF("onebot", "Failed to remove temp file", map[string]any{ + "path": f, + "error": err.Error(), + }) + } + } + }() + } + + if c.isDuplicate(messageID) { + logger.DebugCF("onebot", "Duplicate message, skipping", map[string]any{ + "message_id": messageID, + }) + return + } + + if content == "" { + logger.DebugCF("onebot", "Received empty message, ignoring", map[string]any{ + "message_id": messageID, + }) + return + } + + senderID := strconv.FormatInt(userID, 10) + var chatID string + + metadata := map[string]string{ + "message_id": messageID, + } + + if parsed.ReplyTo != "" { + metadata["reply_to_message_id"] = parsed.ReplyTo + } + + switch raw.MessageType { + case "private": + chatID = "private:" + senderID + metadata["peer_kind"] = "direct" + metadata["peer_id"] = senderID + + case "group": + groupIDStr := strconv.FormatInt(groupID, 10) + chatID = "group:" + groupIDStr + metadata["peer_kind"] = "group" + metadata["peer_id"] = groupIDStr + metadata["group_id"] = groupIDStr + + senderUserID, _ := parseJSONInt64(sender.UserID) + if senderUserID > 0 { + metadata["sender_user_id"] = strconv.FormatInt(senderUserID, 10) + } + + if sender.Card != "" { + metadata["sender_name"] = sender.Card + } else if sender.Nickname != "" { + metadata["sender_name"] = sender.Nickname + } + + triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned) + if !triggered { + logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]any{ + "sender": senderID, + "group": groupIDStr, + "is_mentioned": isBotMentioned, + "content": truncate(content, 100), + }) + return + } + content = strippedContent + + default: + logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]any{ + "type": raw.MessageType, + "message_id": messageID, + "user_id": userID, + }) + return + } + + logger.InfoCF("onebot", 
"Received "+raw.MessageType+" message", map[string]any{ + "sender": senderID, + "chat_id": chatID, + "message_id": messageID, + "length": len(content), + "content": truncate(content, 100), + "media_count": len(parsed.Media), + }) + + if sender.Nickname != "" { + metadata["nickname"] = sender.Nickname + } + + c.lastMessageID.Store(chatID, messageID) + + if raw.MessageType == "group" && messageID != "" && messageID != "0" { + c.setMsgEmojiLike(messageID, 289, true) + c.pendingEmojiMsg.Store(chatID, messageID) + } + + c.HandleMessage(senderID, chatID, content, parsed.Media, metadata) +} + +func (c *OneBotChannel) isDuplicate(messageID string) bool { + if messageID == "" || messageID == "0" { + return false + } + + c.mu.Lock() + defer c.mu.Unlock() + + if _, exists := c.dedup[messageID]; exists { + return true + } + + if old := c.dedupRing[c.dedupIdx]; old != "" { + delete(c.dedup, old) + } + c.dedupRing[c.dedupIdx] = messageID + c.dedup[messageID] = struct{}{} + c.dedupIdx = (c.dedupIdx + 1) % len(c.dedupRing) + + return false +} + +func truncate(s string, n int) string { + runes := []rune(s) + if len(runes) <= n { + return s + } + return string(runes[:n]) + "..." 
+} + +func (c *OneBotChannel) checkGroupTrigger( + content string, + isBotMentioned bool, +) (triggered bool, strippedContent string) { + if isBotMentioned { + return true, strings.TrimSpace(content) + } + + for _, prefix := range c.config.GroupTriggerPrefix { + if prefix == "" { + continue + } + if strings.HasPrefix(content, prefix) { + return true, strings.TrimSpace(strings.TrimPrefix(content, prefix)) + } + } + + return false, content +} diff --git a/pkg/channels/qq.go b/pkg/channels/qq.go index 18b4ca0e0..e66cac533 100644 --- a/pkg/channels/qq.go +++ b/pkg/channels/qq.go @@ -77,7 +77,7 @@ func (c *QQChannel) Start(ctx context.Context) error { return fmt.Errorf("failed to get websocket info: %w", err) } - logger.InfoCF("qq", "Got WebSocket info", map[string]interface{}{ + logger.InfoCF("qq", "Got WebSocket info", map[string]any{ "shards": wsInfo.Shards, }) @@ -87,7 +87,7 @@ func (c *QQChannel) Start(ctx context.Context) error { // 在 goroutine 中启动 WebSocket 连接,避免阻塞 go func() { if err := c.sessionManager.Start(wsInfo, c.tokenSource, &intent); err != nil { - logger.ErrorCF("qq", "WebSocket session error", map[string]interface{}{ + logger.ErrorCF("qq", "WebSocket session error", map[string]any{ "error": err.Error(), }) c.setRunning(false) @@ -124,7 +124,7 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { // C2C 消息发送 _, err := c.api.PostC2CMessage(ctx, msg.ChatID, msgToCreate) if err != nil { - logger.ErrorCF("qq", "Failed to send C2C message", map[string]interface{}{ + logger.ErrorCF("qq", "Failed to send C2C message", map[string]any{ "error": err.Error(), }) return err @@ -157,7 +157,7 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { return nil } - logger.InfoCF("qq", "Received C2C message", map[string]interface{}{ + logger.InfoCF("qq", "Received C2C message", map[string]any{ "sender": senderID, "length": len(content), }) @@ -165,6 +165,8 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { 
// 转发到消息总线 metadata := map[string]string{ "message_id": data.ID, + "peer_kind": "direct", + "peer_id": senderID, } c.HandleMessage(senderID, senderID, content, []string{}, metadata) @@ -197,7 +199,7 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { return nil } - logger.InfoCF("qq", "Received group AT message", map[string]interface{}{ + logger.InfoCF("qq", "Received group AT message", map[string]any{ "sender": senderID, "group": data.GroupID, "length": len(content), @@ -207,6 +209,8 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { metadata := map[string]string{ "message_id": data.ID, "group_id": data.GroupID, + "peer_kind": "group", + "peer_id": data.GroupID, } c.HandleMessage(senderID, data.GroupID, content, []string{}, metadata) diff --git a/pkg/channels/slack.go b/pkg/channels/slack.go index b3ac12e01..f7359cd6d 100644 --- a/pkg/channels/slack.go +++ b/pkg/channels/slack.go @@ -25,6 +25,7 @@ type SlackChannel struct { api *slack.Client socketClient *socketmode.Client botUserID string + teamID string transcriber *voice.GroqTranscriber ctx context.Context cancel context.CancelFunc @@ -72,8 +73,9 @@ func (c *SlackChannel) Start(ctx context.Context) error { return fmt.Errorf("slack auth test failed: %w", err) } c.botUserID = authResp.UserID + c.teamID = authResp.TeamID - logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{ + logger.InfoCF("slack", "Slack bot connected", map[string]any{ "bot_user_id": c.botUserID, "team": authResp.Team, }) @@ -83,7 +85,7 @@ func (c *SlackChannel) Start(ctx context.Context) error { go func() { if err := c.socketClient.RunContext(c.ctx); err != nil { if c.ctx.Err() == nil { - logger.ErrorCF("slack", "Socket Mode connection error", map[string]interface{}{ + logger.ErrorCF("slack", "Socket Mode connection error", map[string]any{ "error": err.Error(), }) } @@ -138,7 +140,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error }) } 
- logger.DebugCF("slack", "Message sent", map[string]interface{}{ + logger.DebugCF("slack", "Message sent", map[string]any{ "channel_id": channelID, "thread_ts": threadTS, }) @@ -200,7 +202,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { // 检查白名单,避免为被拒绝的用户下载附件 if !c.IsAllowed(ev.User) { - logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{ + logger.DebugCF("slack", "Message rejected by allowlist", map[string]any{ "user_id": ev.User, }) return @@ -236,7 +238,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { defer func() { for _, file := range localFiles { if err := os.Remove(file); err != nil { - logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{ + logger.DebugCF("slack", "Failed to cleanup temp file", map[string]any{ "file": file, "error": err.Error(), }) @@ -259,7 +261,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { result, err := c.transcriber.Transcribe(ctx, localPath) if err != nil { - logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()}) + logger.ErrorCF("slack", "Voice transcription failed", map[string]any{"error": err.Error()}) content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name) } else { content += fmt.Sprintf("\n[voice transcription: %s]", result.Text) @@ -274,17 +276,27 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { return } + peerKind := "channel" + peerID := channelID + if strings.HasPrefix(channelID, "D") { + peerKind = "direct" + peerID = senderID + } + metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, "thread_ts": threadTS, "platform": "slack", + "peer_kind": peerKind, + "peer_id": peerID, + "team_id": c.teamID, } - logger.DebugCF("slack", "Received message", map[string]interface{}{ - "sender_id": senderID, - "chat_id": chatID, - "preview": utils.Truncate(content, 50), + 
logger.DebugCF("slack", "Received message", map[string]any{ + "sender_id": senderID, + "chat_id": chatID, + "preview": utils.Truncate(content, 50), "has_thread": threadTS != "", }) @@ -296,6 +308,13 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { return } + if !c.IsAllowed(ev.User) { + logger.DebugCF("slack", "Mention rejected by allowlist", map[string]any{ + "user_id": ev.User, + }) + return + } + senderID := ev.User channelID := ev.Channel threadTS := ev.ThreadTimeStamp @@ -324,12 +343,22 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { return } + mentionPeerKind := "channel" + mentionPeerID := channelID + if strings.HasPrefix(channelID, "D") { + mentionPeerKind = "direct" + mentionPeerID = senderID + } + metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, "thread_ts": threadTS, "platform": "slack", "is_mention": "true", + "peer_kind": mentionPeerKind, + "peer_id": mentionPeerID, + "team_id": c.teamID, } c.HandleMessage(senderID, chatID, content, nil, metadata) @@ -345,6 +374,13 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { c.socketClient.Ack(*event.Request) } + if !c.IsAllowed(cmd.UserID) { + logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]any{ + "user_id": cmd.UserID, + }) + return + } + senderID := cmd.UserID channelID := cmd.ChannelID chatID := channelID @@ -359,9 +395,12 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "platform": "slack", "is_command": "true", "trigger_id": cmd.TriggerID, + "peer_kind": "channel", + "peer_id": channelID, + "team_id": c.teamID, } - logger.DebugCF("slack", "Slash command received", map[string]interface{}{ + logger.DebugCF("slack", "Slash command received", map[string]any{ "sender_id": senderID, "command": cmd.Command, "text": utils.Truncate(content, 50), @@ -376,7 +415,7 @@ func (c *SlackChannel) downloadSlackFile(file slack.File) string { downloadURL = 
file.URLPrivate } if downloadURL == "" { - logger.ErrorCF("slack", "No download URL for file", map[string]interface{}{"file_id": file.ID}) + logger.ErrorCF("slack", "No download URL for file", map[string]any{"file_id": file.ID}) return "" } diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index 3ad4818c3..a0a1c8d0a 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -12,6 +12,8 @@ import ( "time" "github.com/mymmrac/telego" + "github.com/mymmrac/telego/telegohandler" + th "github.com/mymmrac/telego/telegohandler" tu "github.com/mymmrac/telego/telegoutil" "github.com/sipeed/picoclaw/pkg/bus" @@ -24,7 +26,8 @@ import ( type TelegramChannel struct { *BaseChannel bot *telego.Bot - config config.TelegramConfig + commands TelegramCommander + config *config.Config chatIDs map[string]int64 transcriber *voice.GroqTranscriber placeholders sync.Map // chatID -> messageID @@ -41,30 +44,39 @@ func (c *thinkingCancel) Cancel() { } } -func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) { +func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { var opts []telego.BotOption + telegramCfg := cfg.Channels.Telegram - if cfg.Proxy != "" { - proxyURL, parseErr := url.Parse(cfg.Proxy) + if telegramCfg.Proxy != "" { + proxyURL, parseErr := url.Parse(telegramCfg.Proxy) if parseErr != nil { - return nil, fmt.Errorf("invalid proxy URL %q: %w", cfg.Proxy, parseErr) + return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr) } opts = append(opts, telego.WithHTTPClient(&http.Client{ Transport: &http.Transport{ Proxy: http.ProxyURL(proxyURL), }, })) + } else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" { + // Use environment proxy if configured + opts = append(opts, telego.WithHTTPClient(&http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + }, + })) } - bot, err := telego.NewBot(cfg.Token, opts...) 
+ bot, err := telego.NewBot(telegramCfg.Token, opts...) if err != nil { return nil, fmt.Errorf("failed to create telegram bot: %w", err) } - base := NewBaseChannel("telegram", cfg, bus, cfg.AllowFrom) + base := NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom) return &TelegramChannel{ BaseChannel: base, + commands: NewTelegramCommands(bot, cfg), bot: bot, config: cfg, chatIDs: make(map[string]int64), @@ -88,26 +100,41 @@ func (c *TelegramChannel) Start(ctx context.Context) error { return fmt.Errorf("failed to start long polling: %w", err) } + bh, err := telegohandler.NewBotHandler(c.bot, updates) + if err != nil { + return fmt.Errorf("failed to create bot handler: %w", err) + } + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + c.commands.Help(ctx, message) + return nil + }, th.CommandEqual("help")) + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.commands.Start(ctx, message) + }, th.CommandEqual("start")) + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.commands.Show(ctx, message) + }, th.CommandEqual("show")) + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.commands.List(ctx, message) + }, th.CommandEqual("list")) + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.handleMessage(ctx, &message) + }, th.AnyMessage()) + c.setRunning(true) - logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{ + logger.InfoCF("telegram", "Telegram bot connected", map[string]any{ "username": c.bot.Username(), }) + go bh.Start() + go func() { - for { - select { - case <-ctx.Done(): - return - case update, ok := <-updates: - if !ok { - logger.InfoC("telegram", "Updates channel closed, reconnecting...") - return - } - if update.Message != nil { - c.handleMessage(ctx, update) - } - } - } + <-ctx.Done() + bh.Stop() }() return nil @@ -155,7 +182,7 @@ func (c *TelegramChannel) Send(ctx 
context.Context, msg bus.OutboundMessage) err tgMsg.ParseMode = telego.ModeHTML if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { - logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{ + logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]any{ "error": err.Error(), }) tgMsg.ParseMode = "" @@ -166,15 +193,14 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return nil } -func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) { - message := update.Message +func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error { if message == nil { - return + return fmt.Errorf("message is nil") } user := message.From if user == nil { - return + return fmt.Errorf("message sender (user) is nil") } senderID := fmt.Sprintf("%d", user.ID) @@ -184,10 +210,10 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat // 检查白名单,避免为被拒绝的用户下载附件 if !c.IsAllowed(senderID) { - logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{ + logger.DebugCF("telegram", "Message rejected by allowlist", map[string]any{ "user_id": senderID, }) - return + return nil } chatID := message.Chat.ID @@ -201,7 +227,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat defer func() { for _, file := range localFiles { if err := os.Remove(file); err != nil { - logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{ + logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]any{ "file": file, "error": err.Error(), }) @@ -220,7 +246,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat content += message.Caption } - if message.Photo != nil && len(message.Photo) > 0 { + if len(message.Photo) > 0 { photo := message.Photo[len(message.Photo)-1] photoPath := c.downloadPhoto(ctx, photo.FileID) 
if photoPath != "" { @@ -229,7 +255,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat if content != "" { content += "\n" } - content += fmt.Sprintf("[image: photo]") + content += "[image: photo]" } } @@ -241,24 +267,24 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + transcriberCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - result, err := c.transcriber.Transcribe(ctx, voicePath) + result, err := c.transcriber.Transcribe(transcriberCtx, voicePath) if err != nil { - logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{ + logger.ErrorCF("telegram", "Voice transcription failed", map[string]any{ "error": err.Error(), "path": voicePath, }) - transcribedText = fmt.Sprintf("[voice (transcription failed)]") + transcribedText = "[voice (transcription failed)]" } else { transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) - logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{ + logger.InfoCF("telegram", "Voice transcribed successfully", map[string]any{ "text": result.Text, }) } } else { - transcribedText = fmt.Sprintf("[voice]") + transcribedText = "[voice]" } if content != "" { @@ -276,7 +302,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat if content != "" { content += "\n" } - content += fmt.Sprintf("[audio]") + content += "[audio]" } } @@ -288,7 +314,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat if content != "" { content += "\n" } - content += fmt.Sprintf("[file]") + content += "[file]" } } @@ -296,7 +322,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat content = "[empty message]" } - logger.DebugCF("telegram", "Received message", 
map[string]interface{}{ + logger.DebugCF("telegram", "Received message", map[string]any{ "sender_id": senderID, "chat_id": fmt.Sprintf("%d", chatID), "preview": utils.Truncate(content, 50), @@ -305,7 +331,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat // Thinking indicator err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping)) if err != nil { - logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{ + logger.ErrorCF("telegram", "Failed to send chat action", map[string]any{ "error": err.Error(), }) } @@ -318,37 +344,21 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat } } - // Create new context for thinking animation with timeout - thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) + // Create cancel function for thinking state + _, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel}) pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 
💭")) if err == nil { pID := pMsg.MessageID c.placeholders.Store(chatIDStr, pID) + } - go func(cid int64, mid int) { - dots := []string{".", "..", "..."} - emotes := []string{"💭", "🤔", "☁️"} - i := 0 - ticker := time.NewTicker(2000 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-thinkCtx.Done(): - return - case <-ticker.C: - i++ - text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)]) - _, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text)) - if editErr != nil { - logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{ - "error": editErr.Error(), - }) - } - } - } - }(chatID, pID) + peerKind := "direct" + peerID := fmt.Sprintf("%d", user.ID) + if message.Chat.Type != "private" { + peerKind = "group" + peerID = fmt.Sprintf("%d", chatID) } metadata := map[string]string{ @@ -357,15 +367,18 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat "username": user.Username, "first_name": user.FirstName, "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), + "peer_kind": peerKind, + "peer_id": peerID, } c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) + return nil } func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) if err != nil { - logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{ + logger.ErrorCF("telegram", "Failed to get photo file", map[string]any{ "error": err.Error(), }) return "" @@ -380,7 +393,7 @@ func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) st } url := c.bot.FileDownloadURL(file.FilePath) - logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url}) + logger.DebugCF("telegram", "File URL", map[string]any{"url": url}) // Use FilePath as filename for better identification filename := 
file.FilePath + ext @@ -392,7 +405,7 @@ func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) st func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string { file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) if err != nil { - logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{ + logger.ErrorCF("telegram", "Failed to get file", map[string]any{ "error": err.Error(), }) return "" @@ -450,7 +463,11 @@ func markdownToTelegramHTML(text string) string { for i, code := range codeBlocks.codes { escaped := escapeHTML(code) - text = strings.ReplaceAll(text, fmt.Sprintf("\x00CB%d\x00", i), fmt.Sprintf("
%s
", escaped)) + text = strings.ReplaceAll( + text, + fmt.Sprintf("\x00CB%d\x00", i), + fmt.Sprintf("
%s
", escaped), + ) } return text @@ -470,8 +487,11 @@ func extractCodeBlocks(text string) codeBlockMatch { codes = append(codes, match[1]) } + i := 0 text = re.ReplaceAllStringFunc(text, func(m string) string { - return fmt.Sprintf("\x00CB%d\x00", len(codes)-1) + placeholder := fmt.Sprintf("\x00CB%d\x00", i) + i++ + return placeholder }) return codeBlockMatch{text: text, codes: codes} @@ -491,8 +511,11 @@ func extractInlineCodes(text string) inlineCodeMatch { codes = append(codes, match[1]) } + i := 0 text = re.ReplaceAllStringFunc(text, func(m string) string { - return fmt.Sprintf("\x00IC%d\x00", len(codes)-1) + placeholder := fmt.Sprintf("\x00IC%d\x00", i) + i++ + return placeholder }) return inlineCodeMatch{text: text, codes: codes} diff --git a/pkg/channels/telegram_commands.go b/pkg/channels/telegram_commands.go new file mode 100644 index 000000000..a084b641b --- /dev/null +++ b/pkg/channels/telegram_commands.go @@ -0,0 +1,156 @@ +package channels + +import ( + "context" + "fmt" + "strings" + + "github.com/mymmrac/telego" + + "github.com/sipeed/picoclaw/pkg/config" +) + +type TelegramCommander interface { + Help(ctx context.Context, message telego.Message) error + Start(ctx context.Context, message telego.Message) error + Show(ctx context.Context, message telego.Message) error + List(ctx context.Context, message telego.Message) error +} + +type cmd struct { + bot *telego.Bot + config *config.Config +} + +func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander { + return &cmd{ + bot: bot, + config: cfg, + } +} + +func commandArgs(text string) string { + parts := strings.SplitN(text, " ", 2) + if len(parts) < 2 { + return "" + } + return strings.TrimSpace(parts[1]) +} + +func (c *cmd) Help(ctx context.Context, message telego.Message) error { + msg := `/start - Start the bot +/help - Show this help message +/show [model|channel] - Show current configuration +/list [models|channels] - List available options + ` + _, err := 
c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: msg, + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} + +func (c *cmd) Start(ctx context.Context, message telego.Message) error { + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: "Hello! I am PicoClaw 🦞", + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} + +func (c *cmd) Show(ctx context.Context, message telego.Message) error { + args := commandArgs(message.Text) + if args == "" { + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: "Usage: /show [model|channel]", + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err + } + + var response string + switch args { + case "model": + response = fmt.Sprintf("Current Model: %s (Provider: %s)", + c.config.Agents.Defaults.Model, + c.config.Agents.Defaults.Provider) + case "channel": + response = "Current Channel: telegram" + default: + response = fmt.Sprintf("Unknown parameter: %s. 
Try 'model' or 'channel'.", args) + } + + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: response, + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} + +func (c *cmd) List(ctx context.Context, message telego.Message) error { + args := commandArgs(message.Text) + if args == "" { + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: "Usage: /list [models|channels]", + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err + } + + var response string + switch args { + case "models": + provider := c.config.Agents.Defaults.Provider + if provider == "" { + provider = "configured default" + } + response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml", + c.config.Agents.Defaults.Model, provider) + + case "channels": + var enabled []string + if c.config.Channels.Telegram.Enabled { + enabled = append(enabled, "telegram") + } + if c.config.Channels.WhatsApp.Enabled { + enabled = append(enabled, "whatsapp") + } + if c.config.Channels.Feishu.Enabled { + enabled = append(enabled, "feishu") + } + if c.config.Channels.Discord.Enabled { + enabled = append(enabled, "discord") + } + if c.config.Channels.Slack.Enabled { + enabled = append(enabled, "slack") + } + response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- ")) + + default: + response = fmt.Sprintf("Unknown parameter: %s. 
Try 'models' or 'channels'.", args) + } + + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: response, + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} diff --git a/pkg/channels/wecom.go b/pkg/channels/wecom.go new file mode 100644 index 000000000..f8daf89de --- /dev/null +++ b/pkg/channels/wecom.go @@ -0,0 +1,605 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// WeCom Bot (企业微信智能机器人) channel implementation +// Uses webhook callback mode for receiving messages and webhook API for sending replies + +package channels + +import ( + "bytes" + "context" + "crypto/aes" + "crypto/cipher" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "encoding/json" + "encoding/xml" + "fmt" + "io" + "net/http" + "sort" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +// WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人) +// Uses webhook callback mode - simpler than WeCom App but only supports passive replies +type WeComBotChannel struct { + *BaseChannel + config config.WeComConfig + server *http.Server + ctx context.Context + cancel context.CancelFunc + processedMsgs map[string]bool // Message deduplication: msg_id -> processed + msgMu sync.RWMutex +} + +// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT) +type WeComBotMessage struct { + MsgID string `json:"msgid"` + AIBotID string `json:"aibotid"` + ChatID string `json:"chatid"` // Session ID, only present for group chats + ChatType string `json:"chattype"` // "single" for DM, "group" for group chat + From struct { + UserID string `json:"userid"` + } `json:"from"` + ResponseURL string `json:"response_url"` + MsgType string `json:"msgtype"` // text, image, voice, file, mixed + Text struct { + Content 
string `json:"content"` + } `json:"text"` + Image struct { + URL string `json:"url"` + } `json:"image"` + Voice struct { + Content string `json:"content"` // Voice to text content + } `json:"voice"` + File struct { + URL string `json:"url"` + } `json:"file"` + Mixed struct { + MsgItem []struct { + MsgType string `json:"msgtype"` + Text struct { + Content string `json:"content"` + } `json:"text"` + Image struct { + URL string `json:"url"` + } `json:"image"` + } `json:"msg_item"` + } `json:"mixed"` + Quote struct { + MsgType string `json:"msgtype"` + Text struct { + Content string `json:"content"` + } `json:"text"` + } `json:"quote"` +} + +// WeComBotReplyMessage represents the reply message structure +type WeComBotReplyMessage struct { + MsgType string `json:"msgtype"` + Text struct { + Content string `json:"content"` + } `json:"text,omitempty"` +} + +// NewWeComBotChannel creates a new WeCom Bot channel instance +func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComBotChannel, error) { + if cfg.Token == "" || cfg.WebhookURL == "" { + return nil, fmt.Errorf("wecom token and webhook_url are required") + } + + base := NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom) + + return &WeComBotChannel{ + BaseChannel: base, + config: cfg, + processedMsgs: make(map[string]bool), + }, nil +} + +// Name returns the channel name +func (c *WeComBotChannel) Name() string { + return "wecom" +} + +// Start initializes the WeCom Bot channel with HTTP webhook server +func (c *WeComBotChannel) Start(ctx context.Context) error { + logger.InfoC("wecom", "Starting WeCom Bot channel...") + + c.ctx, c.cancel = context.WithCancel(ctx) + + // Setup HTTP server for webhook + mux := http.NewServeMux() + webhookPath := c.config.WebhookPath + if webhookPath == "" { + webhookPath = "/webhook/wecom" + } + mux.HandleFunc(webhookPath, c.handleWebhook) + + // Health check endpoint + mux.HandleFunc("/health/wecom", c.handleHealth) + + addr := fmt.Sprintf("%s:%d", 
c.config.WebhookHost, c.config.WebhookPort) + c.server = &http.Server{ + Addr: addr, + Handler: mux, + } + + c.setRunning(true) + logger.InfoCF("wecom", "WeCom Bot channel started", map[string]any{ + "address": addr, + "path": webhookPath, + }) + + // Start server in goroutine + go func() { + if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("wecom", "HTTP server error", map[string]any{ + "error": err.Error(), + }) + } + }() + + return nil +} + +// Stop gracefully stops the WeCom Bot channel +func (c *WeComBotChannel) Stop(ctx context.Context) error { + logger.InfoC("wecom", "Stopping WeCom Bot channel...") + + if c.cancel != nil { + c.cancel() + } + + if c.server != nil { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + c.server.Shutdown(shutdownCtx) + } + + c.setRunning(false) + logger.InfoC("wecom", "WeCom Bot channel stopped") + return nil +} + +// Send sends a message to WeCom user via webhook API +// Note: WeCom Bot can only reply within the configured timeout (default 5 seconds) of receiving a message +// For delayed responses, we use the webhook URL +func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("wecom channel not running") + } + + logger.DebugCF("wecom", "Sending message via webhook", map[string]any{ + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), + }) + + return c.sendWebhookReply(ctx, msg.ChatID, msg.Content) +} + +// handleWebhook handles incoming webhook requests from WeCom +func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + if r.Method == http.MethodGet { + // Handle verification request + c.handleVerification(ctx, w, r) + return + } + + if r.Method == http.MethodPost { + // Handle message callback + c.handleMessageCallback(ctx, w, r) + return + } + + http.Error(w, "Method not allowed", 
http.StatusMethodNotAllowed) +} + +// handleVerification handles the URL verification request from WeCom +func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + msgSignature := query.Get("msg_signature") + timestamp := query.Get("timestamp") + nonce := query.Get("nonce") + echostr := query.Get("echostr") + + if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" { + http.Error(w, "Missing parameters", http.StatusBadRequest) + return + } + + // Verify signature + if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { + logger.WarnC("wecom", "Signature verification failed") + http.Error(w, "Invalid signature", http.StatusForbidden) + return + } + + // Decrypt echostr + // For AIBOT (智能机器人), receiveid should be empty string "" + // Reference: https://developer.work.weixin.qq.com/document/path/101033 + decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, "") + if err != nil { + logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]any{ + "error": err.Error(), + }) + http.Error(w, "Decryption failed", http.StatusInternalServerError) + return + } + + // Remove BOM and whitespace as per WeCom documentation + // The response must be plain text without quotes, BOM, or newlines + decryptedEchoStr = strings.TrimSpace(decryptedEchoStr) + decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM + w.Write([]byte(decryptedEchoStr)) +} + +// handleMessageCallback handles incoming messages from WeCom +func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + msgSignature := query.Get("msg_signature") + timestamp := query.Get("timestamp") + nonce := query.Get("nonce") + + if msgSignature == "" || timestamp == "" || nonce == "" { + http.Error(w, "Missing parameters", http.StatusBadRequest) + return + 
} + + // Read request body + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read body", http.StatusBadRequest) + return + } + defer r.Body.Close() + + // Parse XML to get encrypted message + var encryptedMsg struct { + XMLName xml.Name `xml:"xml"` + ToUserName string `xml:"ToUserName"` + Encrypt string `xml:"Encrypt"` + AgentID string `xml:"AgentID"` + } + + if err = xml.Unmarshal(body, &encryptedMsg); err != nil { + logger.ErrorCF("wecom", "Failed to parse XML", map[string]any{ + "error": err.Error(), + }) + http.Error(w, "Invalid XML", http.StatusBadRequest) + return + } + + // Verify signature + if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + logger.WarnC("wecom", "Message signature verification failed") + http.Error(w, "Invalid signature", http.StatusForbidden) + return + } + + // Decrypt message + // For AIBOT (智能机器人), receiveid should be empty string "" + // Reference: https://developer.work.weixin.qq.com/document/path/101033 + decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "") + if err != nil { + logger.ErrorCF("wecom", "Failed to decrypt message", map[string]any{ + "error": err.Error(), + }) + http.Error(w, "Decryption failed", http.StatusInternalServerError) + return + } + + // Parse decrypted JSON message (AIBOT uses JSON format) + var msg WeComBotMessage + if err := json.Unmarshal([]byte(decryptedMsg), &msg); err != nil { + logger.ErrorCF("wecom", "Failed to parse decrypted message", map[string]any{ + "error": err.Error(), + }) + http.Error(w, "Invalid message format", http.StatusBadRequest) + return + } + + // Process the message asynchronously with context + go c.processMessage(ctx, msg) + + // Return success response immediately + // WeCom Bot requires response within configured timeout (default 5 seconds) + w.Write([]byte("success")) +} + +// processMessage processes the received message +func (c *WeComBotChannel) 
processMessage(ctx context.Context, msg WeComBotMessage) { + // Skip unsupported message types + if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" && msg.MsgType != "file" && + msg.MsgType != "mixed" { + logger.DebugCF("wecom", "Skipping non-supported message type", map[string]any{ + "msg_type": msg.MsgType, + }) + return + } + + // Message deduplication: Use msg_id to prevent duplicate processing + msgID := msg.MsgID + c.msgMu.Lock() + if c.processedMsgs[msgID] { + c.msgMu.Unlock() + logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{ + "msg_id": msgID, + }) + return + } + c.processedMsgs[msgID] = true + c.msgMu.Unlock() + + // Clean up old messages periodically (keep last 1000) + if len(c.processedMsgs) > 1000 { + c.msgMu.Lock() + c.processedMsgs = make(map[string]bool) + c.msgMu.Unlock() + } + + senderID := msg.From.UserID + + // Determine if this is a group chat or direct message + // ChatType: "single" for DM, "group" for group chat + isGroupChat := msg.ChatType == "group" + + var chatID, peerKind, peerID string + if isGroupChat { + // Group chat: use ChatID as chatID and peer_id + chatID = msg.ChatID + peerKind = "group" + peerID = msg.ChatID + } else { + // Direct message: use senderID as chatID and peer_id + chatID = senderID + peerKind = "direct" + peerID = senderID + } + + // Extract content based on message type + var content string + switch msg.MsgType { + case "text": + content = msg.Text.Content + case "voice": + content = msg.Voice.Content // Voice to text content + case "mixed": + // For mixed messages, concatenate text items + for _, item := range msg.Mixed.MsgItem { + if item.MsgType == "text" { + content += item.Text.Content + } + } + case "image", "file": + // For image and file, we don't have text content + content = "" + } + + // Build metadata + metadata := map[string]string{ + "msg_type": msg.MsgType, + "msg_id": msg.MsgID, + "platform": "wecom", + "peer_kind": peerKind, + "peer_id": peerID, + 
"response_url": msg.ResponseURL, + } + if isGroupChat { + metadata["chat_id"] = msg.ChatID + metadata["sender_id"] = senderID + } + + logger.DebugCF("wecom", "Received message", map[string]any{ + "sender_id": senderID, + "msg_type": msg.MsgType, + "peer_kind": peerKind, + "is_group_chat": isGroupChat, + "preview": utils.Truncate(content, 50), + }) + + // Handle the message through the base channel + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +// sendWebhookReply sends a reply using the webhook URL +func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content string) error { + reply := WeComBotReplyMessage{ + MsgType: "text", + } + reply.Text.Content = content + + jsonData, err := json.Marshal(reply) + if err != nil { + return fmt.Errorf("failed to marshal reply: %w", err) + } + + // Use configurable timeout (default 5 seconds) + timeout := c.config.ReplyTimeout + if timeout <= 0 { + timeout = 5 + } + + reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.config.WebhookURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: time.Duration(timeout) * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to send webhook reply: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + // Check response + var result struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + } + if err := json.Unmarshal(body, &result); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if result.ErrCode != 0 { + return fmt.Errorf("webhook API error: %s (code: %d)", result.ErrMsg, result.ErrCode) + } + + return nil +} + 
+// handleHealth handles health check requests +func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) { + status := map[string]any{ + "status": "ok", + "running": c.IsRunning(), + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(status) +} + +// WeCom common utilities for both WeCom Bot and WeCom App +// The following functions were moved from wecom_common.go + +// WeComVerifySignature verifies the message signature for WeCom +// This is a common function used by both WeCom Bot and WeCom App +func WeComVerifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { + if token == "" { + return true // Skip verification if token is not set + } + + // Sort parameters + params := []string{token, timestamp, nonce, msgEncrypt} + sort.Strings(params) + + // Concatenate + str := strings.Join(params, "") + + // SHA1 hash + hash := sha1.Sum([]byte(str)) + expectedSignature := fmt.Sprintf("%x", hash) + + return expectedSignature == msgSignature +} + +// WeComDecryptMessage decrypts the encrypted message using AES +// This is a common function used by both WeCom Bot and WeCom App +// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id +func WeComDecryptMessage(encryptedMsg, encodingAESKey string) (string, error) { + return WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, "") +} + +// WeComDecryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid +// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification. 
+func WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) { + if encodingAESKey == "" { + // No encryption, return as is (base64 decode) + decoded, err := base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return "", err + } + return string(decoded), nil + } + + // Decode AES key (base64) + aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") + if err != nil { + return "", fmt.Errorf("failed to decode AES key: %w", err) + } + + // Decode encrypted message + cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return "", fmt.Errorf("failed to decode message: %w", err) + } + + // AES decrypt + block, err := aes.NewCipher(aesKey) + if err != nil { + return "", fmt.Errorf("failed to create cipher: %w", err) + } + + if len(cipherText) < aes.BlockSize { + return "", fmt.Errorf("ciphertext too short") + } + + // IV is the first 16 bytes of AESKey + iv := aesKey[:aes.BlockSize] + mode := cipher.NewCBCDecrypter(block, iv) + plainText := make([]byte, len(cipherText)) + mode.CryptBlocks(plainText, cipherText) + + // Remove PKCS7 padding + plainText, err = pkcs7UnpadWeCom(plainText) + if err != nil { + return "", fmt.Errorf("failed to unpad: %w", err) + } + + // Parse message structure + // Format: random(16) + msg_len(4) + msg + receiveid + if len(plainText) < 20 { + return "", fmt.Errorf("decrypted message too short") + } + + msgLen := binary.BigEndian.Uint32(plainText[16:20]) + if int(msgLen) > len(plainText)-20 { + return "", fmt.Errorf("invalid message length") + } + + msg := plainText[20 : 20+msgLen] + + // Verify receiveid if provided + if receiveid != "" && len(plainText) > 20+int(msgLen) { + actualReceiveID := string(plainText[20+msgLen:]) + if actualReceiveID != receiveid { + return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID) + } + } + + return string(msg), nil +} + +// pkcs7UnpadWeCom removes PKCS7 padding with validation 
+// WeCom uses block size of 32 (not standard AES block size of 16) +const wecomBlockSize = 32 + +func pkcs7UnpadWeCom(data []byte) ([]byte, error) { + if len(data) == 0 { + return data, nil + } + padding := int(data[len(data)-1]) + // WeCom uses 32-byte block size for PKCS7 padding + if padding == 0 || padding > wecomBlockSize { + return nil, fmt.Errorf("invalid padding size: %d", padding) + } + if padding > len(data) { + return nil, fmt.Errorf("padding size larger than data") + } + // Verify all padding bytes + for i := 0; i < padding; i++ { + if data[len(data)-1-i] != byte(padding) { + return nil, fmt.Errorf("invalid padding byte at position %d", i) + } + } + return data[:len(data)-padding], nil +} diff --git a/pkg/channels/wecom_app.go b/pkg/channels/wecom_app.go new file mode 100644 index 000000000..715c48707 --- /dev/null +++ b/pkg/channels/wecom_app.go @@ -0,0 +1,639 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// WeCom App (企业微信自建应用) channel implementation +// Supports receiving messages via webhook callback and sending messages proactively + +package channels + +import ( + "bytes" + "context" + "encoding/json" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +const ( + wecomAPIBase = "https://qyapi.weixin.qq.com" +) + +// WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用) +type WeComAppChannel struct { + *BaseChannel + config config.WeComAppConfig + server *http.Server + accessToken string + tokenExpiry time.Time + tokenMu sync.RWMutex + ctx context.Context + cancel context.CancelFunc + processedMsgs map[string]bool // Message deduplication: msg_id -> processed + msgMu sync.RWMutex +} + +// WeComXMLMessage represents the XML message structure from WeCom +type WeComXMLMessage struct { + XMLName xml.Name `xml:"xml"` + 
ToUserName string `xml:"ToUserName"` + FromUserName string `xml:"FromUserName"` + CreateTime int64 `xml:"CreateTime"` + MsgType string `xml:"MsgType"` + Content string `xml:"Content"` + MsgId int64 `xml:"MsgId"` + AgentID int64 `xml:"AgentID"` + PicUrl string `xml:"PicUrl"` + MediaId string `xml:"MediaId"` + Format string `xml:"Format"` + ThumbMediaId string `xml:"ThumbMediaId"` + LocationX float64 `xml:"Location_X"` + LocationY float64 `xml:"Location_Y"` + Scale int `xml:"Scale"` + Label string `xml:"Label"` + Title string `xml:"Title"` + Description string `xml:"Description"` + Url string `xml:"Url"` + Event string `xml:"Event"` + EventKey string `xml:"EventKey"` +} + +// WeComTextMessage represents text message for sending +type WeComTextMessage struct { + ToUser string `json:"touser"` + MsgType string `json:"msgtype"` + AgentID int64 `json:"agentid"` + Text struct { + Content string `json:"content"` + } `json:"text"` + Safe int `json:"safe,omitempty"` +} + +// WeComMarkdownMessage represents markdown message for sending +type WeComMarkdownMessage struct { + ToUser string `json:"touser"` + MsgType string `json:"msgtype"` + AgentID int64 `json:"agentid"` + Markdown struct { + Content string `json:"content"` + } `json:"markdown"` +} + +// WeComImageMessage represents image message for sending +type WeComImageMessage struct { + ToUser string `json:"touser"` + MsgType string `json:"msgtype"` + AgentID int64 `json:"agentid"` + Image struct { + MediaID string `json:"media_id"` + } `json:"image"` +} + +// WeComAccessTokenResponse represents the access token API response +type WeComAccessTokenResponse struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` +} + +// WeComSendMessageResponse represents the send message API response +type WeComSendMessageResponse struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + InvalidUser string `json:"invaliduser"` + 
InvalidParty string `json:"invalidparty"` + InvalidTag string `json:"invalidtag"` +} + +// PKCS7Padding adds PKCS7 padding +type PKCS7Padding struct{} + +// NewWeComAppChannel creates a new WeCom App channel instance +func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (*WeComAppChannel, error) { + if cfg.CorpID == "" || cfg.CorpSecret == "" || cfg.AgentID == 0 { + return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required") + } + + base := NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom) + + return &WeComAppChannel{ + BaseChannel: base, + config: cfg, + processedMsgs: make(map[string]bool), + }, nil +} + +// Name returns the channel name +func (c *WeComAppChannel) Name() string { + return "wecom_app" +} + +// Start initializes the WeCom App channel with HTTP webhook server +func (c *WeComAppChannel) Start(ctx context.Context) error { + logger.InfoC("wecom_app", "Starting WeCom App channel...") + + c.ctx, c.cancel = context.WithCancel(ctx) + + // Get initial access token + if err := c.refreshAccessToken(); err != nil { + logger.WarnCF("wecom_app", "Failed to get initial access token", map[string]any{ + "error": err.Error(), + }) + } + + // Start token refresh goroutine + go c.tokenRefreshLoop() + + // Setup HTTP server for webhook + mux := http.NewServeMux() + webhookPath := c.config.WebhookPath + if webhookPath == "" { + webhookPath = "/webhook/wecom-app" + } + mux.HandleFunc(webhookPath, c.handleWebhook) + + // Health check endpoint + mux.HandleFunc("/health/wecom-app", c.handleHealth) + + addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) + c.server = &http.Server{ + Addr: addr, + Handler: mux, + } + + c.setRunning(true) + logger.InfoCF("wecom_app", "WeCom App channel started", map[string]any{ + "address": addr, + "path": webhookPath, + }) + + // Start server in goroutine + go func() { + if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + 
logger.ErrorCF("wecom_app", "HTTP server error", map[string]any{ + "error": err.Error(), + }) + } + }() + + return nil +} + +// Stop gracefully stops the WeCom App channel +func (c *WeComAppChannel) Stop(ctx context.Context) error { + logger.InfoC("wecom_app", "Stopping WeCom App channel...") + + if c.cancel != nil { + c.cancel() + } + + if c.server != nil { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + c.server.Shutdown(shutdownCtx) + } + + c.setRunning(false) + logger.InfoC("wecom_app", "WeCom App channel stopped") + return nil +} + +// Send sends a message to WeCom user proactively using access token +func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("wecom_app channel not running") + } + + accessToken := c.getAccessToken() + if accessToken == "" { + return fmt.Errorf("no valid access token available") + } + + logger.DebugCF("wecom_app", "Sending message", map[string]any{ + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), + }) + + return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content) +} + +// handleWebhook handles incoming webhook requests from WeCom +func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Log all incoming requests for debugging + logger.DebugCF("wecom_app", "Received webhook request", map[string]any{ + "method": r.Method, + "url": r.URL.String(), + "path": r.URL.Path, + "query": r.URL.RawQuery, + }) + + if r.Method == http.MethodGet { + // Handle verification request + c.handleVerification(ctx, w, r) + return + } + + if r.Method == http.MethodPost { + // Handle message callback + c.handleMessageCallback(ctx, w, r) + return + } + + logger.WarnCF("wecom_app", "Method not allowed", map[string]any{ + "method": r.Method, + }) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) +} + +// handleVerification handles the URL verification 
request from WeCom +func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + msgSignature := query.Get("msg_signature") + timestamp := query.Get("timestamp") + nonce := query.Get("nonce") + echostr := query.Get("echostr") + + logger.DebugCF("wecom_app", "Handling verification request", map[string]any{ + "msg_signature": msgSignature, + "timestamp": timestamp, + "nonce": nonce, + "echostr": echostr, + "corp_id": c.config.CorpID, + }) + + if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" { + logger.ErrorC("wecom_app", "Missing parameters in verification request") + http.Error(w, "Missing parameters", http.StatusBadRequest) + return + } + + // Verify signature + if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { + logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{ + "token": c.config.Token, + "msg_signature": msgSignature, + "timestamp": timestamp, + "nonce": nonce, + }) + http.Error(w, "Invalid signature", http.StatusForbidden) + return + } + + logger.DebugC("wecom_app", "Signature verification passed") + + // Decrypt echostr with CorpID verification + // For WeCom App (自建应用), receiveid should be corp_id + logger.DebugCF("wecom_app", "Attempting to decrypt echostr", map[string]any{ + "encoding_aes_key": c.config.EncodingAESKey, + "corp_id": c.config.CorpID, + }) + decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID) + if err != nil { + logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{ + "error": err.Error(), + "encoding_aes_key": c.config.EncodingAESKey, + "corp_id": c.config.CorpID, + }) + http.Error(w, "Decryption failed", http.StatusInternalServerError) + return + } + + logger.DebugCF("wecom_app", "Successfully decrypted echostr", map[string]any{ + "decrypted": decryptedEchoStr, + }) + + // Remove BOM and whitespace as per WeCom 
documentation + // The response must be plain text without quotes, BOM, or newlines + decryptedEchoStr = strings.TrimSpace(decryptedEchoStr) + decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM + w.Write([]byte(decryptedEchoStr)) +} + +// handleMessageCallback handles incoming messages from WeCom +func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + msgSignature := query.Get("msg_signature") + timestamp := query.Get("timestamp") + nonce := query.Get("nonce") + + if msgSignature == "" || timestamp == "" || nonce == "" { + http.Error(w, "Missing parameters", http.StatusBadRequest) + return + } + + // Read request body + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read body", http.StatusBadRequest) + return + } + defer r.Body.Close() + + // Parse XML to get encrypted message + var encryptedMsg struct { + XMLName xml.Name `xml:"xml"` + ToUserName string `xml:"ToUserName"` + Encrypt string `xml:"Encrypt"` + AgentID string `xml:"AgentID"` + } + + if err = xml.Unmarshal(body, &encryptedMsg); err != nil { + logger.ErrorCF("wecom_app", "Failed to parse XML", map[string]any{ + "error": err.Error(), + }) + http.Error(w, "Invalid XML", http.StatusBadRequest) + return + } + + // Verify signature + if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + logger.WarnC("wecom_app", "Message signature verification failed") + http.Error(w, "Invalid signature", http.StatusForbidden) + return + } + + // Decrypt message with CorpID verification + // For WeCom App (自建应用), receiveid should be corp_id + decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID) + if err != nil { + logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]any{ + "error": err.Error(), + }) + http.Error(w, "Decryption failed", 
http.StatusInternalServerError) + return + } + + // Parse decrypted XML message + var msg WeComXMLMessage + if err := xml.Unmarshal([]byte(decryptedMsg), &msg); err != nil { + logger.ErrorCF("wecom_app", "Failed to parse decrypted message", map[string]any{ + "error": err.Error(), + }) + http.Error(w, "Invalid message format", http.StatusBadRequest) + return + } + + // Process the message with context + go c.processMessage(ctx, msg) + + // Return success response immediately + // WeCom App requires response within configured timeout (default 5 seconds) + w.Write([]byte("success")) +} + +// processMessage processes the received message +func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessage) { + // Skip non-text messages for now (can be extended) + if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" { + logger.DebugCF("wecom_app", "Skipping non-supported message type", map[string]any{ + "msg_type": msg.MsgType, + }) + return + } + + // Message deduplication: Use msg_id to prevent duplicate processing + // As per WeCom documentation, use msg_id for deduplication + msgID := fmt.Sprintf("%d", msg.MsgId) + c.msgMu.Lock() + if c.processedMsgs[msgID] { + c.msgMu.Unlock() + logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{ + "msg_id": msgID, + }) + return + } + c.processedMsgs[msgID] = true + c.msgMu.Unlock() + + // Clean up old messages periodically (keep last 1000) + if len(c.processedMsgs) > 1000 { + c.msgMu.Lock() + c.processedMsgs = make(map[string]bool) + c.msgMu.Unlock() + } + + senderID := msg.FromUserName + chatID := senderID // WeCom App uses user ID as chat ID for direct messages + + // Build metadata + // WeCom App only supports direct messages (private chat) + metadata := map[string]string{ + "msg_type": msg.MsgType, + "msg_id": fmt.Sprintf("%d", msg.MsgId), + "agent_id": fmt.Sprintf("%d", msg.AgentID), + "platform": "wecom_app", + "media_id": msg.MediaId, + "create_time": 
fmt.Sprintf("%d", msg.CreateTime), + "peer_kind": "direct", + "peer_id": senderID, + } + + content := msg.Content + + logger.DebugCF("wecom_app", "Received message", map[string]any{ + "sender_id": senderID, + "msg_type": msg.MsgType, + "preview": utils.Truncate(content, 50), + }) + + // Handle the message through the base channel + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +// tokenRefreshLoop periodically refreshes the access token +func (c *WeComAppChannel) tokenRefreshLoop() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + if err := c.refreshAccessToken(); err != nil { + logger.ErrorCF("wecom_app", "Failed to refresh access token", map[string]any{ + "error": err.Error(), + }) + } + } + } +} + +// refreshAccessToken gets a new access token from WeCom API +func (c *WeComAppChannel) refreshAccessToken() error { + apiURL := fmt.Sprintf("%s/cgi-bin/gettoken?corpid=%s&corpsecret=%s", + wecomAPIBase, url.QueryEscape(c.config.CorpID), url.QueryEscape(c.config.CorpSecret)) + + resp, err := http.Get(apiURL) + if err != nil { + return fmt.Errorf("failed to request access token: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + var tokenResp WeComAccessTokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if tokenResp.ErrCode != 0 { + return fmt.Errorf("API error: %s (code: %d)", tokenResp.ErrMsg, tokenResp.ErrCode) + } + + c.tokenMu.Lock() + c.accessToken = tokenResp.AccessToken + c.tokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second) // Refresh 5 minutes early + c.tokenMu.Unlock() + + logger.DebugC("wecom_app", "Access token refreshed successfully") + return nil +} + +// getAccessToken returns the current valid access token +func (c 
*WeComAppChannel) getAccessToken() string { + c.tokenMu.RLock() + defer c.tokenMu.RUnlock() + + if time.Now().After(c.tokenExpiry) { + return "" + } + + return c.accessToken +} + +// sendTextMessage sends a text message to a user +func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error { + apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) + + msg := WeComTextMessage{ + ToUser: userID, + MsgType: "text", + AgentID: c.config.AgentID, + } + msg.Text.Content = content + + jsonData, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + // Use configurable timeout (default 5 seconds) + timeout := c.config.ReplyTimeout + if timeout <= 0 { + timeout = 5 + } + + reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: time.Duration(timeout) * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + var sendResp WeComSendMessageResponse + if err := json.Unmarshal(body, &sendResp); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if sendResp.ErrCode != 0 { + return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode) + } + + return nil +} + +// sendMarkdownMessage sends a markdown message to a user +func (c *WeComAppChannel) sendMarkdownMessage(ctx context.Context, accessToken, userID, content string) error { + apiURL := 
fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) + + msg := WeComMarkdownMessage{ + ToUser: userID, + MsgType: "markdown", + AgentID: c.config.AgentID, + } + msg.Markdown.Content = content + + jsonData, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + // Use configurable timeout (default 5 seconds) + timeout := c.config.ReplyTimeout + if timeout <= 0 { + timeout = 5 + } + + reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: time.Duration(timeout) * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + var sendResp WeComSendMessageResponse + if err := json.Unmarshal(body, &sendResp); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if sendResp.ErrCode != 0 { + return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode) + } + + return nil +} + +// handleHealth handles health check requests +func (c *WeComAppChannel) handleHealth(w http.ResponseWriter, r *http.Request) { + status := map[string]any{ + "status": "ok", + "running": c.IsRunning(), + "has_token": c.getAccessToken() != "", + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(status) +} diff --git a/pkg/channels/wecom_app_test.go b/pkg/channels/wecom_app_test.go new file mode 100644 index 000000000..abf15c52b --- /dev/null +++ b/pkg/channels/wecom_app_test.go @@ -0,0 +1,1104 @@ +// PicoClaw - Ultra-lightweight 
personal AI agent +// WeCom App (企业微信自建应用) channel tests + +package channels + +import ( + "bytes" + "context" + "crypto/aes" + "crypto/cipher" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "encoding/json" + "encoding/xml" + "fmt" + "net/http" + "net/http/httptest" + "sort" + "strings" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +// generateTestAESKeyApp generates a valid test AES key for WeCom App +func generateTestAESKeyApp() string { + // AES key needs to be 32 bytes (256 bits) for AES-256 + key := make([]byte, 32) + for i := range key { + key[i] = byte(i + 1) + } + // Return base64 encoded key without padding + return base64.StdEncoding.EncodeToString(key)[:43] +} + +// encryptTestMessageApp encrypts a message for testing WeCom App +func encryptTestMessageApp(message, aesKey string) (string, error) { + // Decode AES key + key, err := base64.StdEncoding.DecodeString(aesKey + "=") + if err != nil { + return "", err + } + + // Prepare message: random(16) + msg_len(4) + msg + corp_id + random := make([]byte, 0, 16) + for i := 0; i < 16; i++ { + random = append(random, byte(i+1)) + } + + msgBytes := []byte(message) + corpID := []byte("test_corp_id") + + msgLen := uint32(len(msgBytes)) + lenBytes := make([]byte, 4) + binary.BigEndian.PutUint32(lenBytes, msgLen) + + plainText := append(random, lenBytes...) + plainText = append(plainText, msgBytes...) + plainText = append(plainText, corpID...) + + // PKCS7 padding + blockSize := aes.BlockSize + padding := blockSize - len(plainText)%blockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + plainText = append(plainText, padText...) 
+ + // Encrypt + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize]) + cipherText := make([]byte, len(plainText)) + mode.CryptBlocks(cipherText, plainText) + + return base64.StdEncoding.EncodeToString(cipherText), nil +} + +// generateSignatureApp generates a signature for testing WeCom App +func generateSignatureApp(token, timestamp, nonce, msgEncrypt string) string { + params := []string{token, timestamp, nonce, msgEncrypt} + sort.Strings(params) + str := strings.Join(params, "") + hash := sha1.Sum([]byte(str)) + return fmt.Sprintf("%x", hash) +} + +func TestNewWeComAppChannel(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("missing corp_id", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "", + CorpSecret: "test_secret", + AgentID: 1000002, + } + _, err := NewWeComAppChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing corp_id, got nil") + } + }) + + t.Run("missing corp_secret", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "", + AgentID: 1000002, + } + _, err := NewWeComAppChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing corp_secret, got nil") + } + }) + + t.Run("missing agent_id", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 0, + } + _, err := NewWeComAppChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing agent_id, got nil") + } + }) + + t.Run("valid config", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + AllowFrom: []string{"user1", "user2"}, + } + ch, err := NewWeComAppChannel(cfg, msgBus) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ch.Name() != "wecom_app" { + t.Errorf("Name() = %q, want %q", ch.Name(), "wecom_app") + } + if ch.IsRunning() { + 
t.Error("new channel should not be running") + } + }) +} + +func TestWeComAppChannelIsAllowed(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("empty allowlist allows all", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + AllowFrom: []string{}, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + if !ch.IsAllowed("any_user") { + t.Error("empty allowlist should allow all users") + } + }) + + t.Run("allowlist restricts users", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + AllowFrom: []string{"allowed_user"}, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + if !ch.IsAllowed("allowed_user") { + t.Error("allowed user should pass allowlist check") + } + if ch.IsAllowed("blocked_user") { + t.Error("non-allowed user should be blocked") + } + }) +} + +func TestWeComAppVerifySignature(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + Token: "test_token", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("valid signature", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + msgEncrypt := "test_message" + expectedSig := generateSignatureApp("test_token", timestamp, nonce, msgEncrypt) + + if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { + t.Error("valid signature should pass verification") + } + }) + + t.Run("invalid signature", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + msgEncrypt := "test_message" + + if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { + t.Error("invalid signature should fail verification") + } + }) + + t.Run("empty token skips verification", func(t *testing.T) { + cfgEmpty := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", 
+ AgentID: 1000002, + Token: "", + } + chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus) + + if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + t.Error("empty token should skip verification and return true") + } + }) +} + +func TestWeComAppDecryptMessage(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("decrypt without AES key", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: "", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + // Without AES key, message should be base64 decoded only + plainText := "hello world" + encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) + + result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != plainText { + t.Errorf("decryptMessage() = %q, want %q", result, plainText) + } + }) + + t.Run("decrypt with AES key", func(t *testing.T) { + aesKey := generateTestAESKeyApp() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: aesKey, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + originalMsg := "Hello" + encrypted, err := encryptTestMessageApp(originalMsg, aesKey) + if err != nil { + t.Fatalf("failed to encrypt test message: %v", err) + } + + result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != originalMsg { + t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg) + } + }) + + t.Run("invalid base64", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: "", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + _, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) + if err == nil { + 
t.Error("expected error for invalid base64, got nil") + } + }) + + t.Run("invalid AES key", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: "invalid_key", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) + if err == nil { + t.Error("expected error for invalid AES key, got nil") + } + }) + + t.Run("ciphertext too short", func(t *testing.T) { + aesKey := generateTestAESKeyApp() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: aesKey, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + // Encrypt a very short message that results in ciphertext less than block size + shortData := make([]byte, 8) + _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey) + if err == nil { + t.Error("expected error for short ciphertext, got nil") + } + }) +} + +func TestWeComAppPKCS7Unpad(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + }{ + { + name: "empty input", + input: []byte{}, + expected: []byte{}, + }, + { + name: "valid padding 3 bytes", + input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...), + expected: []byte("hello"), + }, + { + name: "valid padding 16 bytes (full block)", + input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...), + expected: []byte("123456789012345"), + }, + { + name: "invalid padding larger than data", + input: []byte{20}, + expected: nil, // should return error + }, + { + name: "invalid padding zero", + input: append([]byte("test"), byte(0)), + expected: nil, // should return error + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := pkcs7UnpadWeCom(tt.input) + if tt.expected == nil { + // This case should return an error + if 
err == nil { + t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result) + } + return + } + if err != nil { + t.Errorf("pkcs7Unpad() unexpected error: %v", err) + return + } + if !bytes.Equal(result, tt.expected) { + t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestWeComAppHandleVerification(t *testing.T) { + msgBus := bus.NewMessageBus() + aesKey := generateTestAESKeyApp() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + Token: "test_token", + EncodingAESKey: aesKey, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("valid verification request", func(t *testing.T) { + echostr := "test_echostr_123" + encryptedEchostr, _ := encryptTestMessageApp(echostr, aesKey) + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, encryptedEchostr) + + req := httptest.NewRequest( + http.MethodGet, + "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, + nil, + ) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != echostr { + t.Errorf("response body = %q, want %q", w.Body.String(), echostr) + } + }) + + t.Run("missing parameters", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature=sig×tamp=ts", nil) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid signature", func(t *testing.T) { + echostr := "test_echostr" + encryptedEchostr, _ := encryptTestMessageApp(echostr, aesKey) + timestamp := "1234567890" + nonce := "test_nonce" + + req := httptest.NewRequest( + 
http.MethodGet, + "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, + nil, + ) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) + } + }) +} + +func TestWeComAppHandleMessageCallback(t *testing.T) { + msgBus := bus.NewMessageBus() + aesKey := generateTestAESKeyApp() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + Token: "test_token", + EncodingAESKey: aesKey, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("valid message callback", func(t *testing.T) { + // Create XML message + xmlMsg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "text", + Content: "Hello World", + MsgId: 123456, + AgentID: 1000002, + } + xmlData, _ := xml.Marshal(xmlMsg) + + // Encrypt message + encrypted, _ := encryptTestMessageApp(string(xmlData), aesKey) + + // Create encrypted XML wrapper + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: encrypted, + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, encrypted) + + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, + bytes.NewReader(wrapperData), + ) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != "success" { + t.Errorf("response body = %q, want %q", w.Body.String(), "success") + } + }) + + t.Run("missing parameters", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, 
"/webhook/wecom-app?msg_signature=sig", nil) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid XML", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, "") + + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, + strings.NewReader("invalid xml"), + ) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid signature", func(t *testing.T) { + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: "encrypted_data", + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce, + bytes.NewReader(wrapperData), + ) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) + } + }) +} + +func TestWeComAppProcessMessage(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("process text message", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "text", + Content: "Hello World", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + 
ch.processMessage(context.Background(), msg) + }) + + t.Run("process image message", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "image", + PicUrl: "https://example.com/image.jpg", + MediaId: "media_123", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process voice message", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "voice", + MediaId: "media_123", + Format: "amr", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("skip unsupported message type", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "video", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process event message", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "event", + Event: "subscribe", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + ch.processMessage(context.Background(), msg) + }) +} + +func TestWeComAppHandleWebhook(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + Token: "test_token", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("GET request calls verification", func(t *testing.T) { + echostr := "test_echostr" + encoded := base64.StdEncoding.EncodeToString([]byte(echostr)) + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, encoded) + + req := httptest.NewRequest( + http.MethodGet, + 
"/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, + nil, + ) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + }) + + t.Run("POST request calls message callback", func(t *testing.T) { + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: base64.StdEncoding.EncodeToString([]byte("test")), + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, encryptedWrapper.Encrypt) + + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, + bytes.NewReader(wrapperData), + ) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + // Should not be method not allowed + if w.Code == http.StatusMethodNotAllowed { + t.Error("POST request should not return Method Not Allowed") + } + }) + + t.Run("unsupported method", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPut, "/webhook/wecom-app", nil) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed) + } + }) +} + +func TestWeComAppHandleHealth(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + req := httptest.NewRequest(http.MethodGet, "/health/wecom-app", nil) + w := httptest.NewRecorder() + + ch.handleHealth(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + + contentType := w.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Content-Type 
= %q, want %q", contentType, "application/json") + } + + body := w.Body.String() + if !strings.Contains(body, "status") || !strings.Contains(body, "running") || !strings.Contains(body, "has_token") { + t.Errorf("response body should contain status, running, and has_token fields, got: %s", body) + } +} + +func TestWeComAppAccessToken(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("get empty access token initially", func(t *testing.T) { + token := ch.getAccessToken() + if token != "" { + t.Errorf("getAccessToken() = %q, want empty string", token) + } + }) + + t.Run("set and get access token", func(t *testing.T) { + ch.tokenMu.Lock() + ch.accessToken = "test_token_123" + ch.tokenExpiry = time.Now().Add(1 * time.Hour) + ch.tokenMu.Unlock() + + token := ch.getAccessToken() + if token != "test_token_123" { + t.Errorf("getAccessToken() = %q, want %q", token, "test_token_123") + } + }) + + t.Run("expired token returns empty", func(t *testing.T) { + ch.tokenMu.Lock() + ch.accessToken = "expired_token" + ch.tokenExpiry = time.Now().Add(-1 * time.Hour) + ch.tokenMu.Unlock() + + token := ch.getAccessToken() + if token != "" { + t.Errorf("getAccessToken() = %q, want empty string for expired token", token) + } + }) +} + +func TestWeComAppMessageStructures(t *testing.T) { + t.Run("WeComTextMessage structure", func(t *testing.T) { + msg := WeComTextMessage{ + ToUser: "user123", + MsgType: "text", + AgentID: 1000002, + } + msg.Text.Content = "Hello World" + + if msg.ToUser != "user123" { + t.Errorf("ToUser = %q, want %q", msg.ToUser, "user123") + } + if msg.MsgType != "text" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") + } + if msg.AgentID != 1000002 { + t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) + } + if msg.Text.Content != "Hello World" { + t.Errorf("Text.Content = %q, want %q", 
msg.Text.Content, "Hello World") + } + + // Test JSON marshaling + jsonData, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal JSON: %v", err) + } + + var unmarshaled WeComTextMessage + err = json.Unmarshal(jsonData, &unmarshaled) + if err != nil { + t.Fatalf("failed to unmarshal JSON: %v", err) + } + + if unmarshaled.ToUser != msg.ToUser { + t.Errorf("JSON round-trip failed for ToUser") + } + }) + + t.Run("WeComMarkdownMessage structure", func(t *testing.T) { + msg := WeComMarkdownMessage{ + ToUser: "user123", + MsgType: "markdown", + AgentID: 1000002, + } + msg.Markdown.Content = "# Hello\nWorld" + + if msg.Markdown.Content != "# Hello\nWorld" { + t.Errorf("Markdown.Content = %q, want %q", msg.Markdown.Content, "# Hello\nWorld") + } + + // Test JSON marshaling + jsonData, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal JSON: %v", err) + } + + if !bytes.Contains(jsonData, []byte("markdown")) { + t.Error("JSON should contain 'markdown' field") + } + }) + + t.Run("WeComAccessTokenResponse structure", func(t *testing.T) { + jsonData := `{ + "errcode": 0, + "errmsg": "ok", + "access_token": "test_access_token", + "expires_in": 7200 + }` + + var resp WeComAccessTokenResponse + err := json.Unmarshal([]byte(jsonData), &resp) + if err != nil { + t.Fatalf("failed to unmarshal JSON: %v", err) + } + + if resp.ErrCode != 0 { + t.Errorf("ErrCode = %d, want %d", resp.ErrCode, 0) + } + if resp.ErrMsg != "ok" { + t.Errorf("ErrMsg = %q, want %q", resp.ErrMsg, "ok") + } + if resp.AccessToken != "test_access_token" { + t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "test_access_token") + } + if resp.ExpiresIn != 7200 { + t.Errorf("ExpiresIn = %d, want %d", resp.ExpiresIn, 7200) + } + }) + + t.Run("WeComSendMessageResponse structure", func(t *testing.T) { + jsonData := `{ + "errcode": 0, + "errmsg": "ok", + "invaliduser": "", + "invalidparty": "", + "invalidtag": "" + }` + + var resp WeComSendMessageResponse + err := 
json.Unmarshal([]byte(jsonData), &resp) + if err != nil { + t.Fatalf("failed to unmarshal JSON: %v", err) + } + + if resp.ErrCode != 0 { + t.Errorf("ErrCode = %d, want %d", resp.ErrCode, 0) + } + if resp.ErrMsg != "ok" { + t.Errorf("ErrMsg = %q, want %q", resp.ErrMsg, "ok") + } + }) +} + +func TestWeComAppXMLMessageStructure(t *testing.T) { + xmlData := ` + + + + 1234567890 + + + 1234567890123456 + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.ToUserName != "corp_id" { + t.Errorf("ToUserName = %q, want %q", msg.ToUserName, "corp_id") + } + if msg.FromUserName != "user123" { + t.Errorf("FromUserName = %q, want %q", msg.FromUserName, "user123") + } + if msg.CreateTime != 1234567890 { + t.Errorf("CreateTime = %d, want %d", msg.CreateTime, 1234567890) + } + if msg.MsgType != "text" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") + } + if msg.Content != "Hello World" { + t.Errorf("Content = %q, want %q", msg.Content, "Hello World") + } + if msg.MsgId != 1234567890123456 { + t.Errorf("MsgId = %d, want %d", msg.MsgId, 1234567890123456) + } + if msg.AgentID != 1000002 { + t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) + } +} + +func TestWeComAppXMLMessageImage(t *testing.T) { + xmlData := ` + + + + 1234567890 + + + + 1234567890123456 + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "image" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "image") + } + if msg.PicUrl != "https://example.com/image.jpg" { + t.Errorf("PicUrl = %q, want %q", msg.PicUrl, "https://example.com/image.jpg") + } + if msg.MediaId != "media_123" { + t.Errorf("MediaId = %q, want %q", msg.MediaId, "media_123") + } +} + +func TestWeComAppXMLMessageVoice(t *testing.T) { + xmlData := ` + + + + 1234567890 + + + + 1234567890123456 + 1000002 +` 
+ + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "voice" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "voice") + } + if msg.Format != "amr" { + t.Errorf("Format = %q, want %q", msg.Format, "amr") + } +} + +func TestWeComAppXMLMessageLocation(t *testing.T) { + xmlData := ` + + + + 1234567890 + + 39.9042 + 116.4074 + 16 + + 1234567890123456 + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "location" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "location") + } + if msg.LocationX != 39.9042 { + t.Errorf("LocationX = %f, want %f", msg.LocationX, 39.9042) + } + if msg.LocationY != 116.4074 { + t.Errorf("LocationY = %f, want %f", msg.LocationY, 116.4074) + } + if msg.Scale != 16 { + t.Errorf("Scale = %d, want %d", msg.Scale, 16) + } + if msg.Label != "Beijing" { + t.Errorf("Label = %q, want %q", msg.Label, "Beijing") + } +} + +func TestWeComAppXMLMessageLink(t *testing.T) { + xmlData := ` + + + + 1234567890 + + <![CDATA[Link Title]]> + + + 1234567890123456 + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "link" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "link") + } + if msg.Title != "Link Title" { + t.Errorf("Title = %q, want %q", msg.Title, "Link Title") + } + if msg.Description != "Link Description" { + t.Errorf("Description = %q, want %q", msg.Description, "Link Description") + } + if msg.Url != "https://example.com" { + t.Errorf("Url = %q, want %q", msg.Url, "https://example.com") + } +} + +func TestWeComAppXMLMessageEvent(t *testing.T) { + xmlData := ` + + + + 1234567890 + + + + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { 
+ t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "event" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "event") + } + if msg.Event != "subscribe" { + t.Errorf("Event = %q, want %q", msg.Event, "subscribe") + } + if msg.EventKey != "event_key_123" { + t.Errorf("EventKey = %q, want %q", msg.EventKey, "event_key_123") + } +} diff --git a/pkg/channels/wecom_test.go b/pkg/channels/wecom_test.go new file mode 100644 index 000000000..8afa7e8c3 --- /dev/null +++ b/pkg/channels/wecom_test.go @@ -0,0 +1,785 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// WeCom Bot (企业微信智能机器人) channel tests + +package channels + +import ( + "bytes" + "context" + "crypto/aes" + "crypto/cipher" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "encoding/json" + "encoding/xml" + "fmt" + "net/http" + "net/http/httptest" + "sort" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +// generateTestAESKey generates a valid test AES key +func generateTestAESKey() string { + // AES key needs to be 32 bytes (256 bits) for AES-256 + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + // Return base64 encoded key without padding + return base64.StdEncoding.EncodeToString(key)[:43] +} + +// encryptTestMessage encrypts a message for testing (AIBOT JSON format) +func encryptTestMessage(message, aesKey string) (string, error) { + // Decode AES key + key, err := base64.StdEncoding.DecodeString(aesKey + "=") + if err != nil { + return "", err + } + + // Prepare message: random(16) + msg_len(4) + msg + receiveid + random := make([]byte, 0, 16) + for i := 0; i < 16; i++ { + random = append(random, byte(i)) + } + + msgBytes := []byte(message) + receiveID := []byte("test_aibot_id") + + msgLen := uint32(len(msgBytes)) + lenBytes := make([]byte, 4) + binary.BigEndian.PutUint32(lenBytes, msgLen) + + plainText := append(random, lenBytes...) + plainText = append(plainText, msgBytes...) 
+ plainText = append(plainText, receiveID...) + + // PKCS7 padding + blockSize := aes.BlockSize + padding := blockSize - len(plainText)%blockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + plainText = append(plainText, padText...) + + // Encrypt + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize]) + cipherText := make([]byte, len(plainText)) + mode.CryptBlocks(cipherText, plainText) + + return base64.StdEncoding.EncodeToString(cipherText), nil +} + +// generateSignature generates a signature for testing +func generateSignature(token, timestamp, nonce, msgEncrypt string) string { + params := []string{token, timestamp, nonce, msgEncrypt} + sort.Strings(params) + str := strings.Join(params, "") + hash := sha1.Sum([]byte(str)) + return fmt.Sprintf("%x", hash) +} + +func TestNewWeComBotChannel(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("missing token", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + _, err := NewWeComBotChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing token, got nil") + } + }) + + t.Run("missing webhook_url", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "", + } + _, err := NewWeComBotChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing webhook_url, got nil") + } + }) + + t.Run("valid config", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + AllowFrom: []string{"user1", "user2"}, + } + ch, err := NewWeComBotChannel(cfg, msgBus) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ch.Name() != "wecom" { + t.Errorf("Name() = %q, want %q", ch.Name(), "wecom") + } + if ch.IsRunning() { + t.Error("new channel should not be running") + } + }) +} + 
+func TestWeComBotChannelIsAllowed(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("empty allowlist allows all", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + AllowFrom: []string{}, + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + if !ch.IsAllowed("any_user") { + t.Error("empty allowlist should allow all users") + } + }) + + t.Run("allowlist restricts users", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + AllowFrom: []string{"allowed_user"}, + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + if !ch.IsAllowed("allowed_user") { + t.Error("allowed user should pass allowlist check") + } + if ch.IsAllowed("blocked_user") { + t.Error("non-allowed user should be blocked") + } + }) +} + +func TestWeComBotVerifySignature(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("valid signature", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + msgEncrypt := "test_message" + expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt) + + if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { + t.Error("valid signature should pass verification") + } + }) + + t.Run("invalid signature", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + msgEncrypt := "test_message" + + if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { + t.Error("invalid signature should fail verification") + } + }) + + t.Run("empty token skips verification", func(t *testing.T) { + // Create a channel manually with empty token to test the behavior + cfgEmpty := config.WeComConfig{ + Token: "", + 
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + chEmpty := &WeComBotChannel{ + config: cfgEmpty, + } + + if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + t.Error("empty token should skip verification and return true") + } + }) +} + +func TestWeComBotDecryptMessage(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("decrypt without AES key", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + EncodingAESKey: "", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + // Without AES key, message should be base64 decoded only + plainText := "hello world" + encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) + + result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != plainText { + t.Errorf("decryptMessage() = %q, want %q", result, plainText) + } + }) + + t.Run("decrypt with AES key", func(t *testing.T) { + aesKey := generateTestAESKey() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + EncodingAESKey: aesKey, + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + originalMsg := "Hello" + encrypted, err := encryptTestMessage(originalMsg, aesKey) + if err != nil { + t.Fatalf("failed to encrypt test message: %v", err) + } + + result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != originalMsg { + t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg) + } + }) + + t.Run("invalid base64", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + EncodingAESKey: "", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + _, err := 
WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) + if err == nil { + t.Error("expected error for invalid base64, got nil") + } + }) + + t.Run("invalid AES key", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + EncodingAESKey: "invalid_key", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) + if err == nil { + t.Error("expected error for invalid AES key, got nil") + } + }) +} + +func TestWeComBotPKCS7Unpad(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + }{ + { + name: "empty input", + input: []byte{}, + expected: []byte{}, + }, + { + name: "valid padding 3 bytes", + input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...), + expected: []byte("hello"), + }, + { + name: "valid padding 16 bytes (full block)", + input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...), + expected: []byte("123456789012345"), + }, + { + name: "invalid padding larger than data", + input: []byte{20}, + expected: nil, // should return error + }, + { + name: "invalid padding zero", + input: append([]byte("test"), byte(0)), + expected: nil, // should return error + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := pkcs7UnpadWeCom(tt.input) + if tt.expected == nil { + // This case should return an error + if err == nil { + t.Errorf("pkcs7UnpadWeCom() expected error for invalid padding, got result: %v", result) + } + return + } + if err != nil { + t.Errorf("pkcs7UnpadWeCom() unexpected error: %v", err) + return + } + if !bytes.Equal(result, tt.expected) { + t.Errorf("pkcs7UnpadWeCom() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestWeComBotHandleVerification(t *testing.T) { + msgBus := bus.NewMessageBus() + aesKey := generateTestAESKey() + cfg := 
config.WeComConfig{ + Token: "test_token", + EncodingAESKey: aesKey, + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("valid verification request", func(t *testing.T) { + echostr := "test_echostr_123" + encryptedEchostr, _ := encryptTestMessage(echostr, aesKey) + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encryptedEchostr) + + req := httptest.NewRequest( + http.MethodGet, + "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, + nil, + ) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != echostr { + t.Errorf("response body = %q, want %q", w.Body.String(), echostr) + } + }) + + t.Run("missing parameters", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature=sig×tamp=ts", nil) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid signature", func(t *testing.T) { + echostr := "test_echostr" + encryptedEchostr, _ := encryptTestMessage(echostr, aesKey) + timestamp := "1234567890" + nonce := "test_nonce" + + req := httptest.NewRequest( + http.MethodGet, + "/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, + nil, + ) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) + } + }) +} + +func TestWeComBotHandleMessageCallback(t *testing.T) { + msgBus := bus.NewMessageBus() + aesKey := generateTestAESKey() + cfg 
:= config.WeComConfig{ + Token: "test_token", + EncodingAESKey: aesKey, + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("valid direct message callback", func(t *testing.T) { + // Create JSON message for direct chat (single) + jsonMsg := `{ + "msgid": "test_msg_id_123", + "aibotid": "test_aibot_id", + "chattype": "single", + "from": {"userid": "user123"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello World"} + }` + + // Encrypt message + encrypted, _ := encryptTestMessage(jsonMsg, aesKey) + + // Create encrypted XML wrapper + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: encrypted, + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encrypted) + + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, + bytes.NewReader(wrapperData), + ) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != "success" { + t.Errorf("response body = %q, want %q", w.Body.String(), "success") + } + }) + + t.Run("valid group message callback", func(t *testing.T) { + // Create JSON message for group chat + jsonMsg := `{ + "msgid": "test_msg_id_456", + "aibotid": "test_aibot_id", + "chatid": "group_chat_id_123", + "chattype": "group", + "from": {"userid": "user456"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello Group"} + }` + + // Encrypt message + encrypted, _ := encryptTestMessage(jsonMsg, aesKey) + + // Create encrypted XML wrapper + 
encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: encrypted, + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encrypted) + + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, + bytes.NewReader(wrapperData), + ) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != "success" { + t.Errorf("response body = %q, want %q", w.Body.String(), "success") + } + }) + + t.Run("missing parameters", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature=sig", nil) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid XML", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, "") + + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, + strings.NewReader("invalid xml"), + ) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid signature", func(t *testing.T) { + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: "encrypted_data", + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + + req := httptest.NewRequest( + 
http.MethodPost, + "/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce, + bytes.NewReader(wrapperData), + ) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) + } + }) +} + +func TestWeComBotProcessMessage(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("process direct text message", func(t *testing.T) { + msg := WeComBotMessage{ + MsgID: "test_msg_id_123", + AIBotID: "test_aibot_id", + ChatType: "single", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "text", + } + msg.From.UserID = "user123" + msg.Text.Content = "Hello World" + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process group text message", func(t *testing.T) { + msg := WeComBotMessage{ + MsgID: "test_msg_id_456", + AIBotID: "test_aibot_id", + ChatID: "group_chat_id_123", + ChatType: "group", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "text", + } + msg.From.UserID = "user456" + msg.Text.Content = "Hello Group" + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process voice message", func(t *testing.T) { + msg := WeComBotMessage{ + MsgID: "test_msg_id_789", + AIBotID: "test_aibot_id", + ChatType: "single", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "voice", + } + msg.From.UserID = "user123" + msg.Voice.Content = "Voice message text" + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("skip unsupported message type", func(t *testing.T) { + msg := WeComBotMessage{ + MsgID: "test_msg_id_000", + AIBotID: "test_aibot_id", + 
ChatType: "single", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "video", + } + msg.From.UserID = "user123" + + // Should not panic + ch.processMessage(context.Background(), msg) + }) +} + +func TestWeComBotHandleWebhook(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("GET request calls verification", func(t *testing.T) { + echostr := "test_echostr" + encoded := base64.StdEncoding.EncodeToString([]byte(echostr)) + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encoded) + + req := httptest.NewRequest( + http.MethodGet, + "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, + nil, + ) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + }) + + t.Run("POST request calls message callback", func(t *testing.T) { + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: base64.StdEncoding.EncodeToString([]byte("test")), + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encryptedWrapper.Encrypt) + + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, + bytes.NewReader(wrapperData), + ) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + // Should not be method not allowed + if w.Code == http.StatusMethodNotAllowed { + t.Error("POST request should not return Method Not Allowed") + } + }) + + t.Run("unsupported method", func(t *testing.T) { + req := 
httptest.NewRequest(http.MethodPut, "/webhook/wecom", nil) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed) + } + }) +} + +func TestWeComBotHandleHealth(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + req := httptest.NewRequest(http.MethodGet, "/health/wecom", nil) + w := httptest.NewRecorder() + + ch.handleHealth(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + + contentType := w.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Content-Type = %q, want %q", contentType, "application/json") + } + + body := w.Body.String() + if !strings.Contains(body, "status") || !strings.Contains(body, "running") { + t.Errorf("response body should contain status and running fields, got: %s", body) + } +} + +func TestWeComBotReplyMessage(t *testing.T) { + msg := WeComBotReplyMessage{ + MsgType: "text", + } + msg.Text.Content = "Hello World" + + if msg.MsgType != "text" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") + } + if msg.Text.Content != "Hello World" { + t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") + } +} + +func TestWeComBotMessageStructure(t *testing.T) { + jsonData := `{ + "msgid": "test_msg_id_123", + "aibotid": "test_aibot_id", + "chatid": "group_chat_id_123", + "chattype": "group", + "from": {"userid": "user123"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello World"} + }` + + var msg WeComBotMessage + err := json.Unmarshal([]byte(jsonData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal JSON: %v", err) + } + + if msg.MsgID != "test_msg_id_123" 
{ + t.Errorf("MsgID = %q, want %q", msg.MsgID, "test_msg_id_123") + } + if msg.AIBotID != "test_aibot_id" { + t.Errorf("AIBotID = %q, want %q", msg.AIBotID, "test_aibot_id") + } + if msg.ChatID != "group_chat_id_123" { + t.Errorf("ChatID = %q, want %q", msg.ChatID, "group_chat_id_123") + } + if msg.ChatType != "group" { + t.Errorf("ChatType = %q, want %q", msg.ChatType, "group") + } + if msg.From.UserID != "user123" { + t.Errorf("From.UserID = %q, want %q", msg.From.UserID, "user123") + } + if msg.MsgType != "text" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") + } + if msg.Text.Content != "Hello World" { + t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") + } +} diff --git a/pkg/channels/whatsapp.go b/pkg/channels/whatsapp.go index c95e59578..958d850bb 100644 --- a/pkg/channels/whatsapp.go +++ b/pkg/channels/whatsapp.go @@ -86,7 +86,7 @@ func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("whatsapp connection not established") } - payload := map[string]interface{}{ + payload := map[string]any{ "type": "message", "to": msg.ChatID, "content": msg.Content, @@ -126,7 +126,7 @@ func (c *WhatsAppChannel) listen(ctx context.Context) { continue } - var msg map[string]interface{} + var msg map[string]any if err := json.Unmarshal(message, &msg); err != nil { log.Printf("Failed to unmarshal WhatsApp message: %v", err) continue @@ -144,7 +144,7 @@ func (c *WhatsAppChannel) listen(ctx context.Context) { } } -func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) { +func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) { senderID, ok := msg["from"].(string) if !ok { return @@ -161,7 +161,7 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) { } var mediaPaths []string - if mediaData, ok := msg["media"].([]interface{}); ok { + if mediaData, ok := msg["media"].([]any); ok { mediaPaths = make([]string, 0, len(mediaData)) for _, m := range 
mediaData { if path, ok := m.(string); ok { @@ -178,6 +178,14 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) { metadata["user_name"] = userName } + if chatID == senderID { + metadata["peer_kind"] = "direct" + metadata["peer_id"] = senderID + } else { + metadata["peer_kind"] = "group" + metadata["peer_id"] = chatID + } + log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50)) c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) diff --git a/pkg/config/config.go b/pkg/config/config.go index 197b95973..2e9e9e784 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -6,10 +6,15 @@ import ( "os" "path/filepath" "sync" + "sync/atomic" + "time" "github.com/caarlos0/env/v11" ) +// rrCounter is a global counter for round-robin load balancing across models. +var rrCounter atomic.Uint64 + // FlexibleStringSlice is a []string that also accepts JSON numbers, // so allow_from can contain both "123" and 123. type FlexibleStringSlice []string @@ -23,7 +28,7 @@ func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error { } // Try []interface{} to handle mixed types - var raw []interface{} + var raw []any if err := json.Unmarshal(data, &raw); err != nil { return err } @@ -45,26 +50,137 @@ func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error { type Config struct { Agents AgentsConfig `json:"agents"` + Bindings []AgentBinding `json:"bindings,omitempty"` + Session SessionConfig `json:"session,omitempty"` Channels ChannelsConfig `json:"channels"` - Providers ProvidersConfig `json:"providers"` + Providers ProvidersConfig `json:"providers,omitempty"` + ModelList []ModelConfig `json:"model_list"` // New model-centric provider configuration Gateway GatewayConfig `json:"gateway"` Tools ToolsConfig `json:"tools"` Heartbeat HeartbeatConfig `json:"heartbeat"` + Devices DevicesConfig `json:"devices"` + Swarm SwarmConfig `json:"swarm"` mu sync.RWMutex } +// MarshalJSON implements custom JSON 
marshaling for Config +// to omit providers section when empty and session when empty +func (c Config) MarshalJSON() ([]byte, error) { + type Alias Config + aux := &struct { + Providers *ProvidersConfig `json:"providers,omitempty"` + Session *SessionConfig `json:"session,omitempty"` + *Alias + }{ + Alias: (*Alias)(&c), + } + + // Only include providers if not empty + if !c.Providers.IsEmpty() { + aux.Providers = &c.Providers + } + + // Only include session if not empty + if c.Session.DMScope != "" || len(c.Session.IdentityLinks) > 0 { + aux.Session = &c.Session + } + + return json.Marshal(aux) +} + type AgentsConfig struct { Defaults AgentDefaults `json:"defaults"` + List []AgentConfig `json:"list,omitempty"` +} + +// AgentModelConfig supports both string and structured model config. +// String format: "gpt-4" (just primary, no fallbacks) +// Object format: {"primary": "gpt-4", "fallbacks": ["claude-haiku"]} +type AgentModelConfig struct { + Primary string `json:"primary,omitempty"` + Fallbacks []string `json:"fallbacks,omitempty"` +} + +func (m *AgentModelConfig) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err == nil { + m.Primary = s + m.Fallbacks = nil + return nil + } + type raw struct { + Primary string `json:"primary"` + Fallbacks []string `json:"fallbacks"` + } + var r raw + if err := json.Unmarshal(data, &r); err != nil { + return err + } + m.Primary = r.Primary + m.Fallbacks = r.Fallbacks + return nil +} + +func (m AgentModelConfig) MarshalJSON() ([]byte, error) { + if len(m.Fallbacks) == 0 && m.Primary != "" { + return json.Marshal(m.Primary) + } + type raw struct { + Primary string `json:"primary,omitempty"` + Fallbacks []string `json:"fallbacks,omitempty"` + } + return json.Marshal(raw{Primary: m.Primary, Fallbacks: m.Fallbacks}) +} + +type AgentConfig struct { + ID string `json:"id"` + Default bool `json:"default,omitempty"` + Name string `json:"name,omitempty"` + Workspace string `json:"workspace,omitempty"` 
+ Model *AgentModelConfig `json:"model,omitempty"` + Skills []string `json:"skills,omitempty"` + Subagents *SubagentsConfig `json:"subagents,omitempty"` +} + +type SubagentsConfig struct { + AllowAgents []string `json:"allow_agents,omitempty"` + Model *AgentModelConfig `json:"model,omitempty"` +} + +type PeerMatch struct { + Kind string `json:"kind"` + ID string `json:"id"` +} + +type BindingMatch struct { + Channel string `json:"channel"` + AccountID string `json:"account_id,omitempty"` + Peer *PeerMatch `json:"peer,omitempty"` + GuildID string `json:"guild_id,omitempty"` + TeamID string `json:"team_id,omitempty"` +} + +type AgentBinding struct { + AgentID string `json:"agent_id"` + Match BindingMatch `json:"match"` +} + +type SessionConfig struct { + DMScope string `json:"dm_scope,omitempty"` + IdentityLinks map[string][]string `json:"identity_links,omitempty"` } type AgentDefaults struct { - Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` - RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` - Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` - Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` - MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` - Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` - MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` + RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` + Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` + Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` + ModelFallbacks []string `json:"model_fallbacks,omitempty"` + ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` + 
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` + MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` + Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` + MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` } type ChannelsConfig struct { @@ -76,86 +192,223 @@ type ChannelsConfig struct { QQ QQConfig `json:"qq"` DingTalk DingTalkConfig `json:"dingtalk"` Slack SlackConfig `json:"slack"` + LINE LINEConfig `json:"line"` + OneBot OneBotConfig `json:"onebot"` + WeCom WeComConfig `json:"wecom"` + WeComApp WeComAppConfig `json:"wecom_app"` } type WhatsAppConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"` BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"` AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"` } type TelegramConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` - Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` + Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` } type FeishuConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"` - AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"` - AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"` - EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"` + AppID string 
`json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"` + AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"` + EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"` VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` } type DiscordConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` + MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"` } type MaixCamConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"` - Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"` - Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"` + Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"` + Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"` AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"` } type QQConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"` - AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"` + AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"` AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"` AllowFrom 
FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"` } type DingTalkConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` - ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` + ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` } type SlackConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` - BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` - AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` + BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` + AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` +} + +type LINEConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"` + ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"` + ChannelAccessToken string `json:"channel_access_token" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_ACCESS_TOKEN"` + WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_HOST"` + WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PORT"` + WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"` +} + 
+type OneBotConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"` + WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"` + AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"` + ReconnectInterval int `json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"` + GroupTriggerPrefix []string `json:"group_trigger_prefix" env:"PICOCLAW_CHANNELS_ONEBOT_GROUP_TRIGGER_PREFIX"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"` +} + +type WeComConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"` + EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"` + WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"` + WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"` + WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"` + WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"` + ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"` +} + +type WeComAppConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"` + CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"` + CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"` + AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"` + EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"` + WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"` + WebhookPort int `json:"webhook_port" 
env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"` + WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"` + ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"` } type HeartbeatConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"` + Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"` Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5 } +type DevicesConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_DEVICES_ENABLED"` + MonitorUSB bool `json:"monitor_usb" env:"PICOCLAW_DEVICES_MONITOR_USB"` +} + type ProvidersConfig struct { - Anthropic ProviderConfig `json:"anthropic"` - OpenAI ProviderConfig `json:"openai"` - OpenRouter ProviderConfig `json:"openrouter"` - Groq ProviderConfig `json:"groq"` - Zhipu ProviderConfig `json:"zhipu"` - VLLM ProviderConfig `json:"vllm"` - Gemini ProviderConfig `json:"gemini"` - Nvidia ProviderConfig `json:"nvidia"` - Moonshot ProviderConfig `json:"moonshot"` + Anthropic ProviderConfig `json:"anthropic"` + OpenAI OpenAIProviderConfig `json:"openai"` + OpenRouter ProviderConfig `json:"openrouter"` + Groq ProviderConfig `json:"groq"` + Zhipu ProviderConfig `json:"zhipu"` + VLLM ProviderConfig `json:"vllm"` + Gemini ProviderConfig `json:"gemini"` + Nvidia ProviderConfig `json:"nvidia"` + Ollama ProviderConfig `json:"ollama"` + Moonshot ProviderConfig `json:"moonshot"` + ShengSuanYun ProviderConfig `json:"shengsuanyun"` + DeepSeek ProviderConfig `json:"deepseek"` + Cerebras ProviderConfig `json:"cerebras"` + VolcEngine ProviderConfig `json:"volcengine"` + GitHubCopilot ProviderConfig `json:"github_copilot"` + Antigravity ProviderConfig `json:"antigravity"` + Qwen ProviderConfig `json:"qwen"` +} + +// IsEmpty checks if all provider configs are empty (no API keys or API bases set) +// Note: WebSearch is an 
optimization option and doesn't count as "non-empty" +func (p ProvidersConfig) IsEmpty() bool { + return p.Anthropic.APIKey == "" && p.Anthropic.APIBase == "" && + p.OpenAI.APIKey == "" && p.OpenAI.APIBase == "" && + p.OpenRouter.APIKey == "" && p.OpenRouter.APIBase == "" && + p.Groq.APIKey == "" && p.Groq.APIBase == "" && + p.Zhipu.APIKey == "" && p.Zhipu.APIBase == "" && + p.VLLM.APIKey == "" && p.VLLM.APIBase == "" && + p.Gemini.APIKey == "" && p.Gemini.APIBase == "" && + p.Nvidia.APIKey == "" && p.Nvidia.APIBase == "" && + p.Ollama.APIKey == "" && p.Ollama.APIBase == "" && + p.Moonshot.APIKey == "" && p.Moonshot.APIBase == "" && + p.ShengSuanYun.APIKey == "" && p.ShengSuanYun.APIBase == "" && + p.DeepSeek.APIKey == "" && p.DeepSeek.APIBase == "" && + p.Cerebras.APIKey == "" && p.Cerebras.APIBase == "" && + p.VolcEngine.APIKey == "" && p.VolcEngine.APIBase == "" && + p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" && + p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" && + p.Qwen.APIKey == "" && p.Qwen.APIBase == "" +} + +// MarshalJSON implements custom JSON marshaling for ProvidersConfig +// to omit the entire section when empty +func (p ProvidersConfig) MarshalJSON() ([]byte, error) { + if p.IsEmpty() { + return []byte("null"), nil + } + type Alias ProvidersConfig + return json.Marshal((*Alias)(&p)) } type ProviderConfig struct { - APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` - APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` - Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"` - AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"` + APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` + APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"` + AuthMethod string `json:"auth_method,omitempty" 
env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"` + ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` // only for Github Copilot, `stdio` or `grpc` +} + +type OpenAIProviderConfig struct { + ProviderConfig + WebSearch bool `json:"web_search" env:"PICOCLAW_PROVIDERS_OPENAI_WEB_SEARCH"` +} + +// ModelConfig represents a model-centric provider configuration. +// It allows adding new providers (especially OpenAI-compatible ones) via configuration only. +// The model field uses protocol prefix format: [protocol/]model-identifier +// Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot +// Default protocol is "openai" if no prefix is specified. +type ModelConfig struct { + // Required fields + ModelName string `json:"model_name"` // User-facing alias for the model + Model string `json:"model"` // Protocol/model-identifier (e.g., "openai/gpt-4o", "anthropic/claude-sonnet-4.6") + + // HTTP-based providers + APIBase string `json:"api_base,omitempty"` // API endpoint URL + APIKey string `json:"api_key"` // API authentication key + Proxy string `json:"proxy,omitempty"` // HTTP proxy URL + + // Special providers (CLI-based, OAuth, etc.) + AuthMethod string `json:"auth_method,omitempty"` // Authentication method: oauth, token + ConnectMode string `json:"connect_mode,omitempty"` // Connection mode: stdio, grpc + Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers + + // Optional optimizations + RPM int `json:"rpm,omitempty"` // Requests per minute limit + MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens") +} + +// Validate checks if the ModelConfig has all required fields. 
+func (c *ModelConfig) Validate() error { + if c.ModelName == "" { + return fmt.Errorf("model_name is required") + } + if c.Model == "" { + return fmt.Errorf("model is required") + } + return nil } type GatewayConfig struct { @@ -163,111 +416,177 @@ type GatewayConfig struct { Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` } -type WebSearchConfig struct { - APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_SEARCH_API_KEY"` - MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_SEARCH_MAX_RESULTS"` +type BraveConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"` + APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"` +} + +type DuckDuckGoConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_ENABLED"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_MAX_RESULTS"` +} + +type PerplexityConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"` + APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"` } type WebToolsConfig struct { - Search WebSearchConfig `json:"search"` + Brave BraveConfig `json:"brave"` + DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"` + Perplexity PerplexityConfig `json:"perplexity"` +} + +type CronToolsConfig struct { + ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout +} + +type ExecConfig struct { + EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"` + CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"` } type ToolsConfig struct { - Web WebToolsConfig `json:"web"` -} - -func DefaultConfig() *Config { - return &Config{ - Agents: AgentsConfig{ - 
Defaults: AgentDefaults{ - Workspace: "~/.picoclaw/workspace", - RestrictToWorkspace: true, - Provider: "", - Model: "glm-4.7", - MaxTokens: 8192, - Temperature: 0.7, - MaxToolIterations: 20, - }, - }, - Channels: ChannelsConfig{ - WhatsApp: WhatsAppConfig{ - Enabled: false, - BridgeURL: "ws://localhost:3001", - AllowFrom: FlexibleStringSlice{}, - }, - Telegram: TelegramConfig{ - Enabled: false, - Token: "", - AllowFrom: FlexibleStringSlice{}, - }, - Feishu: FeishuConfig{ - Enabled: false, - AppID: "", - AppSecret: "", - EncryptKey: "", - VerificationToken: "", - AllowFrom: FlexibleStringSlice{}, - }, - Discord: DiscordConfig{ - Enabled: false, - Token: "", - AllowFrom: FlexibleStringSlice{}, - }, - MaixCam: MaixCamConfig{ - Enabled: false, - Host: "0.0.0.0", - Port: 18790, - AllowFrom: FlexibleStringSlice{}, - }, - QQ: QQConfig{ - Enabled: false, - AppID: "", - AppSecret: "", - AllowFrom: FlexibleStringSlice{}, - }, - DingTalk: DingTalkConfig{ - Enabled: false, - ClientID: "", - ClientSecret: "", - AllowFrom: FlexibleStringSlice{}, - }, - Slack: SlackConfig{ - Enabled: false, - BotToken: "", - AppToken: "", - AllowFrom: []string{}, - }, - }, - Providers: ProvidersConfig{ - Anthropic: ProviderConfig{}, - OpenAI: ProviderConfig{}, - OpenRouter: ProviderConfig{}, - Groq: ProviderConfig{}, - Zhipu: ProviderConfig{}, - VLLM: ProviderConfig{}, - Gemini: ProviderConfig{}, - Nvidia: ProviderConfig{}, - Moonshot: ProviderConfig{}, - }, - Gateway: GatewayConfig{ - Host: "0.0.0.0", - Port: 18790, - }, - Tools: ToolsConfig{ - Web: WebToolsConfig{ - Search: WebSearchConfig{ - APIKey: "", - MaxResults: 5, - }, - }, - }, - Heartbeat: HeartbeatConfig{ - Enabled: true, - Interval: 30, // default 30 minutes - }, + Web WebToolsConfig `json:"web"` + Cron CronToolsConfig `json:"cron"` + Exec ExecConfig `json:"exec"` + Skills SkillsToolsConfig `json:"skills"` +} + +type SkillsToolsConfig struct { + Registries SkillsRegistriesConfig `json:"registries"` + MaxConcurrentSearches int 
`json:"max_concurrent_searches" env:"PICOCLAW_SKILLS_MAX_CONCURRENT_SEARCHES"` + SearchCache SearchCacheConfig `json:"search_cache"` +} + +type SearchCacheConfig struct { + MaxSize int `json:"max_size" env:"PICOCLAW_SKILLS_SEARCH_CACHE_MAX_SIZE"` + TTLSeconds int `json:"ttl_seconds" env:"PICOCLAW_SKILLS_SEARCH_CACHE_TTL_SECONDS"` +} + +type SkillsRegistriesConfig struct { + ClawHub ClawHubRegistryConfig `json:"clawhub"` +} + +type ClawHubRegistryConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_ENABLED"` + BaseURL string `json:"base_url" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_BASE_URL"` + AuthToken string `json:"auth_token" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_AUTH_TOKEN"` + SearchPath string `json:"search_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_SEARCH_PATH"` + SkillsPath string `json:"skills_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_SKILLS_PATH"` + DownloadPath string `json:"download_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_DOWNLOAD_PATH"` + Timeout int `json:"timeout" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_TIMEOUT"` + MaxZipSize int `json:"max_zip_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_ZIP_SIZE"` + MaxResponseSize int `json:"max_response_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_RESPONSE_SIZE"` +} + +// SwarmConfig configures the swarm mode for multi-instance collaboration +type SwarmConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_SWARM_ENABLED"` + NodeID string `json:"node_id" env:"PICOCLAW_SWARM_NODE_ID"` + Role string `json:"role" env:"PICOCLAW_SWARM_ROLE"` // coordinator/worker/specialist + Capabilities []string `json:"capabilities" env:"PICOCLAW_SWARM_CAPABILITIES"` + MaxConcurrent int `json:"max_concurrent_tasks" env:"PICOCLAW_SWARM_MAX_CONCURRENT"` + HID string `json:"hid" env:"PICOCLAW_SWARM_HID"` // Host/tenant ID (multi-tenancy) + SID string `json:"sid" env:"PICOCLAW_SWARM_SID"` // Service/instance ID + NATS NATSConfig `json:"nats"` + Temporal TemporalConfig 
`json:"temporal"` +} + +// NATSConfig configures NATS connection +type NATSConfig struct { + URLs []string `json:"urls" env:"PICOCLAW_SWARM_NATS_URLS"` + Credentials string `json:"credentials" env:"PICOCLAW_SWARM_NATS_CREDENTIALS"` + HeartbeatInterval string `json:"heartbeat_interval" env:"PICOCLAW_SWARM_NATS_HEARTBEAT_INTERVAL"` + NodeTimeout string `json:"node_timeout" env:"PICOCLAW_SWARM_NATS_NODE_TIMEOUT"` + Embedded bool `json:"embedded" env:"PICOCLAW_SWARM_NATS_EMBEDDED"` // Use embedded server + EmbeddedPort int `json:"embedded_port" env:"PICOCLAW_SWARM_NATS_EMBEDDED_PORT"` + EmbeddedHost string `json:"embedded_host" env:"PICOCLAW_SWARM_NATS_EMBEDDED_HOST"` // Embedded server listen address (0.0.0.0 for external access) +} + +// TemporalConfig configures Temporal connection +type TemporalConfig struct { + Host string `json:"host" env:"PICOCLAW_SWARM_TEMPORAL_HOST"` + Namespace string `json:"namespace" env:"PICOCLAW_SWARM_TEMPORAL_NAMESPACE"` + TaskQueue string `json:"task_queue" env:"PICOCLAW_SWARM_TEMPORAL_TASK_QUEUE"` + WorkflowTimeout string `json:"workflow_timeout" env:"PICOCLAW_SWARM_TEMPORAL_WORKFLOW_TIMEOUT"` + ActivityTimeout string `json:"activity_timeout" env:"PICOCLAW_SWARM_TEMPORAL_ACTIVITY_TIMEOUT"` + Model string `json:"model" env:"PICOCLAW_SWARM_TEMPORAL_MODEL"` // LLM model for swarm tasks (e.g., glm-4, gpt-4, claude-sonnet-4-20250515) +} + +// GetHeartbeatInterval returns the heartbeat interval as a duration +func (c *SwarmConfig) GetHeartbeatInterval() time.Duration { + d, err := time.ParseDuration(c.NATS.HeartbeatInterval) + if err != nil { + return 10 * time.Second + } + return d +} + +// GetNodeTimeout returns the node timeout as a duration +func (c *SwarmConfig) GetNodeTimeout() time.Duration { + d, err := time.ParseDuration(c.NATS.NodeTimeout) + if err != nil { + return 30 * time.Second } + return d } +// GetWorkflowTimeout returns the workflow timeout as a duration +func (c *SwarmConfig) GetWorkflowTimeout() time.Duration { + d, err 
:= time.ParseDuration(c.Temporal.WorkflowTimeout) + if err != nil { + return 30 * time.Minute + } + return d +} + +// GetActivityTimeout returns the activity timeout as a duration +func (c *SwarmConfig) GetActivityTimeout() time.Duration { + d, err := time.ParseDuration(c.Temporal.ActivityTimeout) + if err != nil { + return 10 * time.Minute + } + return d +} + +// Validate validates the SwarmConfig +func (c *SwarmConfig) Validate() error { + if !c.Enabled { + return nil + } + + if c.MaxConcurrent <= 0 { + return fmt.Errorf("swarm: max_concurrent must be > 0, got %d", c.MaxConcurrent) + } + + if len(c.NATS.URLs) == 0 && !c.NATS.Embedded { + return fmt.Errorf("swarm: NATS.URLs required when embedded=false") + } + + if c.NATS.HeartbeatInterval != "" { + if _, err := time.ParseDuration(c.NATS.HeartbeatInterval); err != nil { + return fmt.Errorf("swarm: invalid heartbeat_interval: %w", err) + } + } + + if c.NATS.NodeTimeout != "" { + if _, err := time.ParseDuration(c.NATS.NodeTimeout); err != nil { + return fmt.Errorf("swarm: invalid node_timeout: %w", err) + } + } + + validRoles := map[string]bool{"coordinator": true, "worker": true, "specialist": true} + if !validRoles[c.Role] { + return fmt.Errorf("swarm: invalid role %q, must be coordinator/worker/specialist", c.Role) + } + + return nil +} + + func LoadConfig(path string) (*Config, error) { cfg := DefaultConfig() @@ -287,35 +606,38 @@ func LoadConfig(path string) (*Config, error) { return nil, err } + // Auto-migrate: if only legacy providers config exists, convert to model_list + if len(cfg.ModelList) == 0 && cfg.HasProvidersConfig() { + cfg.ModelList = ConvertProvidersToModelList(cfg) + } + + // Validate model_list for uniqueness and required fields + if err := cfg.ValidateModelList(); err != nil { + return nil, err + } + return cfg, nil } func SaveConfig(path string, cfg *Config) error { - cfg.mu.RLock() - defer cfg.mu.RUnlock() - data, err := json.MarshalIndent(cfg, "", " ") if err != nil { return err } dir := 
filepath.Dir(path) - if err := os.MkdirAll(dir, 0755); err != nil { + if err := os.MkdirAll(dir, 0o755); err != nil { return err } - return os.WriteFile(path, data, 0644) + return os.WriteFile(path, data, 0o600) } func (c *Config) WorkspacePath() string { - c.mu.RLock() - defer c.mu.RUnlock() return expandHome(c.Agents.Defaults.Workspace) } func (c *Config) GetAPIKey() string { - c.mu.RLock() - defer c.mu.RUnlock() if c.Providers.OpenRouter.APIKey != "" { return c.Providers.OpenRouter.APIKey } @@ -337,12 +659,16 @@ func (c *Config) GetAPIKey() string { if c.Providers.VLLM.APIKey != "" { return c.Providers.VLLM.APIKey } + if c.Providers.ShengSuanYun.APIKey != "" { + return c.Providers.ShengSuanYun.APIKey + } + if c.Providers.Cerebras.APIKey != "" { + return c.Providers.Cerebras.APIKey + } return "" } func (c *Config) GetAPIBase() string { - c.mu.RLock() - defer c.mu.RUnlock() if c.Providers.OpenRouter.APIKey != "" { if c.Providers.OpenRouter.APIBase != "" { return c.Providers.OpenRouter.APIBase @@ -371,3 +697,65 @@ func expandHome(path string) string { } return path } + +// GetModelConfig returns the ModelConfig for the given model name. +// If multiple configs exist with the same model_name, it uses round-robin +// selection for load balancing. Returns an error if the model is not found. +func (c *Config) GetModelConfig(modelName string) (*ModelConfig, error) { + matches := c.findMatches(modelName) + if len(matches) == 0 { + return nil, fmt.Errorf("model %q not found in model_list or providers", modelName) + } + if len(matches) == 1 { + return &matches[0], nil + } + + // Multiple configs - use round-robin for load balancing + idx := rrCounter.Add(1) % uint64(len(matches)) + return &matches[idx], nil +} + +// findMatches finds all ModelConfig entries with the given model_name. 
+func (c *Config) findMatches(modelName string) []ModelConfig { + var matches []ModelConfig + for i := range c.ModelList { + if c.ModelList[i].ModelName == modelName { + matches = append(matches, c.ModelList[i]) + } + } + return matches +} + +// HasProvidersConfig checks if any provider in the old providers config has configuration. +func (c *Config) HasProvidersConfig() bool { + v := c.Providers + return v.Anthropic.APIKey != "" || v.Anthropic.APIBase != "" || + v.OpenAI.APIKey != "" || v.OpenAI.APIBase != "" || + v.OpenRouter.APIKey != "" || v.OpenRouter.APIBase != "" || + v.Groq.APIKey != "" || v.Groq.APIBase != "" || + v.Zhipu.APIKey != "" || v.Zhipu.APIBase != "" || + v.VLLM.APIKey != "" || v.VLLM.APIBase != "" || + v.Gemini.APIKey != "" || v.Gemini.APIBase != "" || + v.Nvidia.APIKey != "" || v.Nvidia.APIBase != "" || + v.Ollama.APIKey != "" || v.Ollama.APIBase != "" || + v.Moonshot.APIKey != "" || v.Moonshot.APIBase != "" || + v.ShengSuanYun.APIKey != "" || v.ShengSuanYun.APIBase != "" || + v.DeepSeek.APIKey != "" || v.DeepSeek.APIBase != "" || + v.Cerebras.APIKey != "" || v.Cerebras.APIBase != "" || + v.VolcEngine.APIKey != "" || v.VolcEngine.APIBase != "" || + v.GitHubCopilot.APIKey != "" || v.GitHubCopilot.APIBase != "" || + v.Antigravity.APIKey != "" || v.Antigravity.APIBase != "" || + v.Qwen.APIKey != "" || v.Qwen.APIBase != "" +} + +// ValidateModelList validates all ModelConfig entries in the model_list. +// It checks that each model config is valid. +// Note: Multiple entries with the same model_name are allowed for load balancing. 
+func (c *Config) ValidateModelList() error { + for i := range c.ModelList { + if err := c.ModelList[i].Validate(); err != nil { + return fmt.Errorf("model_list[%d]: %w", i, err) + } + } + return nil +} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 0a5e7b56f..0898217d6 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -1,9 +1,193 @@ package config import ( + "encoding/json" + "os" + "path/filepath" + "runtime" "testing" ) +func TestAgentModelConfig_UnmarshalString(t *testing.T) { + var m AgentModelConfig + if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil { + t.Fatalf("unmarshal string: %v", err) + } + if m.Primary != "gpt-4" { + t.Errorf("Primary = %q, want 'gpt-4'", m.Primary) + } + if m.Fallbacks != nil { + t.Errorf("Fallbacks = %v, want nil", m.Fallbacks) + } +} + +func TestAgentModelConfig_UnmarshalObject(t *testing.T) { + var m AgentModelConfig + data := `{"primary": "claude-opus", "fallbacks": ["gpt-4o-mini", "haiku"]}` + if err := json.Unmarshal([]byte(data), &m); err != nil { + t.Fatalf("unmarshal object: %v", err) + } + if m.Primary != "claude-opus" { + t.Errorf("Primary = %q, want 'claude-opus'", m.Primary) + } + if len(m.Fallbacks) != 2 { + t.Fatalf("Fallbacks len = %d, want 2", len(m.Fallbacks)) + } + if m.Fallbacks[0] != "gpt-4o-mini" || m.Fallbacks[1] != "haiku" { + t.Errorf("Fallbacks = %v", m.Fallbacks) + } +} + +func TestAgentModelConfig_MarshalString(t *testing.T) { + m := AgentModelConfig{Primary: "gpt-4"} + data, err := json.Marshal(m) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if string(data) != `"gpt-4"` { + t.Errorf("marshal = %s, want '\"gpt-4\"'", string(data)) + } +} + +func TestAgentModelConfig_MarshalObject(t *testing.T) { + m := AgentModelConfig{Primary: "claude-opus", Fallbacks: []string{"haiku"}} + data, err := json.Marshal(m) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var result map[string]any + json.Unmarshal(data, &result) + if result["primary"] 
!= "claude-opus" { + t.Errorf("primary = %v", result["primary"]) + } +} + +func TestAgentConfig_FullParse(t *testing.T) { + jsonData := `{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "max_tool_iterations": 20 + }, + "list": [ + { + "id": "sales", + "default": true, + "name": "Sales Bot", + "model": "gpt-4" + }, + { + "id": "support", + "name": "Support Bot", + "model": { + "primary": "claude-opus", + "fallbacks": ["haiku"] + }, + "subagents": { + "allow_agents": ["sales"] + } + } + ] + }, + "bindings": [ + { + "agent_id": "support", + "match": { + "channel": "telegram", + "account_id": "*", + "peer": {"kind": "direct", "id": "user123"} + } + } + ], + "session": { + "dm_scope": "per-peer", + "identity_links": { + "john": ["telegram:123", "discord:john#1234"] + } + } + }` + + cfg := DefaultConfig() + if err := json.Unmarshal([]byte(jsonData), cfg); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(cfg.Agents.List) != 2 { + t.Fatalf("agents.list len = %d, want 2", len(cfg.Agents.List)) + } + + sales := cfg.Agents.List[0] + if sales.ID != "sales" || !sales.Default || sales.Name != "Sales Bot" { + t.Errorf("sales = %+v", sales) + } + if sales.Model == nil || sales.Model.Primary != "gpt-4" { + t.Errorf("sales.Model = %+v", sales.Model) + } + + support := cfg.Agents.List[1] + if support.ID != "support" || support.Name != "Support Bot" { + t.Errorf("support = %+v", support) + } + if support.Model == nil || support.Model.Primary != "claude-opus" { + t.Errorf("support.Model = %+v", support.Model) + } + if len(support.Model.Fallbacks) != 1 || support.Model.Fallbacks[0] != "haiku" { + t.Errorf("support.Model.Fallbacks = %v", support.Model.Fallbacks) + } + if support.Subagents == nil || len(support.Subagents.AllowAgents) != 1 { + t.Errorf("support.Subagents = %+v", support.Subagents) + } + + if len(cfg.Bindings) != 1 { + t.Fatalf("bindings len = %d, want 1", len(cfg.Bindings)) + } + binding := 
cfg.Bindings[0] + if binding.AgentID != "support" || binding.Match.Channel != "telegram" { + t.Errorf("binding = %+v", binding) + } + if binding.Match.Peer == nil || binding.Match.Peer.Kind != "direct" || binding.Match.Peer.ID != "user123" { + t.Errorf("binding.Match.Peer = %+v", binding.Match.Peer) + } + + if cfg.Session.DMScope != "per-peer" { + t.Errorf("Session.DMScope = %q", cfg.Session.DMScope) + } + if len(cfg.Session.IdentityLinks) != 1 { + t.Errorf("Session.IdentityLinks = %v", cfg.Session.IdentityLinks) + } + links := cfg.Session.IdentityLinks["john"] + if len(links) != 2 { + t.Errorf("john links = %v", links) + } +} + +func TestConfig_BackwardCompat_NoAgentsList(t *testing.T) { + jsonData := `{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "max_tool_iterations": 20 + } + } + }` + + cfg := DefaultConfig() + if err := json.Unmarshal([]byte(jsonData), cfg); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(cfg.Agents.List) != 0 { + t.Errorf("agents.list should be empty for backward compat, got %d", len(cfg.Agents.List)) + } + if len(cfg.Bindings) != 0 { + t.Errorf("bindings should be empty, got %d", len(cfg.Bindings)) + } +} + // TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default func TestDefaultConfig_HeartbeatEnabled(t *testing.T) { cfg := DefaultConfig() @@ -17,8 +201,6 @@ func TestDefaultConfig_HeartbeatEnabled(t *testing.T) { func TestDefaultConfig_WorkspacePath(t *testing.T) { cfg := DefaultConfig() - // Just verify the workspace is set, don't compare exact paths - // since expandHome behavior may differ based on environment if cfg.Agents.Defaults.Workspace == "" { t.Error("Workspace should not be empty") } @@ -55,8 +237,8 @@ func TestDefaultConfig_MaxToolIterations(t *testing.T) { func TestDefaultConfig_Temperature(t *testing.T) { cfg := DefaultConfig() - if cfg.Agents.Defaults.Temperature == 0 { - t.Error("Temperature should not be zero") + if 
cfg.Agents.Defaults.Temperature != nil { + t.Error("Temperature should be nil when not provided") } } @@ -76,7 +258,6 @@ func TestDefaultConfig_Gateway(t *testing.T) { func TestDefaultConfig_Providers(t *testing.T) { cfg := DefaultConfig() - // Verify all providers are empty by default if cfg.Providers.Anthropic.APIKey != "" { t.Error("Anthropic API key should be empty by default") } @@ -86,46 +267,18 @@ func TestDefaultConfig_Providers(t *testing.T) { if cfg.Providers.OpenRouter.APIKey != "" { t.Error("OpenRouter API key should be empty by default") } - if cfg.Providers.Groq.APIKey != "" { - t.Error("Groq API key should be empty by default") - } - if cfg.Providers.Zhipu.APIKey != "" { - t.Error("Zhipu API key should be empty by default") - } - if cfg.Providers.VLLM.APIKey != "" { - t.Error("VLLM API key should be empty by default") - } - if cfg.Providers.Gemini.APIKey != "" { - t.Error("Gemini API key should be empty by default") - } } // TestDefaultConfig_Channels verifies channels are disabled by default func TestDefaultConfig_Channels(t *testing.T) { cfg := DefaultConfig() - // Verify all channels are disabled by default - if cfg.Channels.WhatsApp.Enabled { - t.Error("WhatsApp should be disabled by default") - } if cfg.Channels.Telegram.Enabled { t.Error("Telegram should be disabled by default") } - if cfg.Channels.Feishu.Enabled { - t.Error("Feishu should be disabled by default") - } if cfg.Channels.Discord.Enabled { t.Error("Discord should be disabled by default") } - if cfg.Channels.MaixCam.Enabled { - t.Error("MaixCam should be disabled by default") - } - if cfg.Channels.QQ.Enabled { - t.Error("QQ should be disabled by default") - } - if cfg.Channels.DingTalk.Enabled { - t.Error("DingTalk should be disabled by default") - } if cfg.Channels.Slack.Enabled { t.Error("Slack should be disabled by default") } @@ -136,11 +289,38 @@ func TestDefaultConfig_WebTools(t *testing.T) { cfg := DefaultConfig() // Verify web tools defaults - if 
cfg.Tools.Web.Search.MaxResults != 5 { - t.Error("Expected MaxResults 5, got ", cfg.Tools.Web.Search.MaxResults) + if cfg.Tools.Web.Brave.MaxResults != 5 { + t.Error("Expected Brave MaxResults 5, got ", cfg.Tools.Web.Brave.MaxResults) + } + if cfg.Tools.Web.Brave.APIKey != "" { + t.Error("Brave API key should be empty by default") } - if cfg.Tools.Web.Search.APIKey != "" { - t.Error("Search API key should be empty by default") + if cfg.Tools.Web.DuckDuckGo.MaxResults != 5 { + t.Error("Expected DuckDuckGo MaxResults 5, got ", cfg.Tools.Web.DuckDuckGo.MaxResults) + } +} + +func TestSaveConfig_FilePermissions(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("file permission bits are not enforced on Windows") + } + + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "config.json") + + cfg := DefaultConfig() + if err := SaveConfig(path, cfg); err != nil { + t.Fatalf("SaveConfig failed: %v", err) + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat failed: %v", err) + } + + perm := info.Mode().Perm() + if perm != 0o600 { + t.Errorf("config file has permission %04o, want 0600", perm) } } @@ -148,15 +328,14 @@ func TestDefaultConfig_WebTools(t *testing.T) { func TestConfig_Complete(t *testing.T) { cfg := DefaultConfig() - // Verify complete config structure if cfg.Agents.Defaults.Workspace == "" { t.Error("Workspace should not be empty") } if cfg.Agents.Defaults.Model == "" { t.Error("Model should not be empty") } - if cfg.Agents.Defaults.Temperature == 0 { - t.Error("Temperature should have default value") + if cfg.Agents.Defaults.Temperature != nil { + t.Error("Temperature should be nil when not provided") } if cfg.Agents.Defaults.MaxTokens == 0 { t.Error("MaxTokens should not be zero") @@ -174,3 +353,42 @@ func TestConfig_Complete(t *testing.T) { t.Error("Heartbeat should be enabled by default") } } + +func TestDefaultConfig_OpenAIWebSearchEnabled(t *testing.T) { + cfg := DefaultConfig() + if !cfg.Providers.OpenAI.WebSearch { + 
t.Fatal("DefaultConfig().Providers.OpenAI.WebSearch should be true") + } +} + +func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"api_base":""}}}`), 0o600); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if !cfg.Providers.OpenAI.WebSearch { + t.Fatal("OpenAI codex web search should remain true when unset in config file") + } +} + +func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"web_search":false}}}`), 0o600); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Providers.OpenAI.WebSearch { + t.Fatal("OpenAI codex web search should be false when disabled in config file") + } +} diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go new file mode 100644 index 000000000..7654326e7 --- /dev/null +++ b/pkg/config/defaults.go @@ -0,0 +1,316 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package config + +// DefaultConfig returns the default configuration for PicoClaw. 
+func DefaultConfig() *Config { + return &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Workspace: "~/.picoclaw/workspace", + RestrictToWorkspace: true, + Provider: "", + Model: "glm-4.7", + MaxTokens: 8192, + Temperature: nil, // nil means use provider default + MaxToolIterations: 20, + }, + }, + Bindings: []AgentBinding{}, + Session: SessionConfig{ + DMScope: "main", + }, + Channels: ChannelsConfig{ + WhatsApp: WhatsAppConfig{ + Enabled: false, + BridgeURL: "ws://localhost:3001", + AllowFrom: FlexibleStringSlice{}, + }, + Telegram: TelegramConfig{ + Enabled: false, + Token: "", + AllowFrom: FlexibleStringSlice{}, + }, + Feishu: FeishuConfig{ + Enabled: false, + AppID: "", + AppSecret: "", + EncryptKey: "", + VerificationToken: "", + AllowFrom: FlexibleStringSlice{}, + }, + Discord: DiscordConfig{ + Enabled: false, + Token: "", + AllowFrom: FlexibleStringSlice{}, + MentionOnly: false, + }, + MaixCam: MaixCamConfig{ + Enabled: false, + Host: "0.0.0.0", + Port: 18790, + AllowFrom: FlexibleStringSlice{}, + }, + QQ: QQConfig{ + Enabled: false, + AppID: "", + AppSecret: "", + AllowFrom: FlexibleStringSlice{}, + }, + DingTalk: DingTalkConfig{ + Enabled: false, + ClientID: "", + ClientSecret: "", + AllowFrom: FlexibleStringSlice{}, + }, + Slack: SlackConfig{ + Enabled: false, + BotToken: "", + AppToken: "", + AllowFrom: FlexibleStringSlice{}, + }, + LINE: LINEConfig{ + Enabled: false, + ChannelSecret: "", + ChannelAccessToken: "", + WebhookHost: "0.0.0.0", + WebhookPort: 18791, + WebhookPath: "/webhook/line", + AllowFrom: FlexibleStringSlice{}, + }, + OneBot: OneBotConfig{ + Enabled: false, + WSUrl: "ws://127.0.0.1:3001", + AccessToken: "", + ReconnectInterval: 5, + GroupTriggerPrefix: []string{}, + AllowFrom: FlexibleStringSlice{}, + }, + WeCom: WeComConfig{ + Enabled: false, + Token: "", + EncodingAESKey: "", + WebhookURL: "", + WebhookHost: "0.0.0.0", + WebhookPort: 18793, + WebhookPath: "/webhook/wecom", + AllowFrom: FlexibleStringSlice{}, + 
ReplyTimeout: 5, + }, + WeComApp: WeComAppConfig{ + Enabled: false, + CorpID: "", + CorpSecret: "", + AgentID: 0, + Token: "", + EncodingAESKey: "", + WebhookHost: "0.0.0.0", + WebhookPort: 18792, + WebhookPath: "/webhook/wecom-app", + AllowFrom: FlexibleStringSlice{}, + ReplyTimeout: 5, + }, + }, + Providers: ProvidersConfig{ + OpenAI: OpenAIProviderConfig{WebSearch: true}, + }, + ModelList: []ModelConfig{ + // ============================================ + // Add your API key to the model you want to use + // ============================================ + + // Zhipu AI (智谱) - https://open.bigmodel.cn/usercenter/apikeys + { + ModelName: "glm-4.7", + Model: "zhipu/glm-4.7", + APIBase: "https://open.bigmodel.cn/api/paas/v4", + APIKey: "", + }, + + // OpenAI - https://platform.openai.com/api-keys + { + ModelName: "gpt-5.2", + Model: "openai/gpt-5.2", + APIBase: "https://api.openai.com/v1", + APIKey: "", + }, + + // Anthropic Claude - https://console.anthropic.com/settings/keys + { + ModelName: "claude-sonnet-4.6", + Model: "anthropic/claude-sonnet-4.6", + APIBase: "https://api.anthropic.com/v1", + APIKey: "", + }, + + // DeepSeek - https://platform.deepseek.com/ + { + ModelName: "deepseek-chat", + Model: "deepseek/deepseek-chat", + APIBase: "https://api.deepseek.com/v1", + APIKey: "", + }, + + // Google Gemini - https://ai.google.dev/ + { + ModelName: "gemini-2.0-flash", + Model: "gemini/gemini-2.0-flash-exp", + APIBase: "https://generativelanguage.googleapis.com/v1beta", + APIKey: "", + }, + + // Qwen (通义千问) - https://dashscope.console.aliyun.com/apiKey + { + ModelName: "qwen-plus", + Model: "qwen/qwen-plus", + APIBase: "https://dashscope.aliyuncs.com/compatible-mode/v1", + APIKey: "", + }, + + // Moonshot (月之暗面) - https://platform.moonshot.cn/console/api-keys + { + ModelName: "moonshot-v1-8k", + Model: "moonshot/moonshot-v1-8k", + APIBase: "https://api.moonshot.cn/v1", + APIKey: "", + }, + + // Groq - https://console.groq.com/keys + { + ModelName: "llama-3.3-70b", 
+ Model: "groq/llama-3.3-70b-versatile", + APIBase: "https://api.groq.com/openai/v1", + APIKey: "", + }, + + // OpenRouter (100+ models) - https://openrouter.ai/keys + { + ModelName: "openrouter-auto", + Model: "openrouter/auto", + APIBase: "https://openrouter.ai/api/v1", + APIKey: "", + }, + { + ModelName: "openrouter-gpt-5.2", + Model: "openrouter/openai/gpt-5.2", + APIBase: "https://openrouter.ai/api/v1", + APIKey: "", + }, + + // NVIDIA - https://build.nvidia.com/ + { + ModelName: "nemotron-4-340b", + Model: "nvidia/nemotron-4-340b-instruct", + APIBase: "https://integrate.api.nvidia.com/v1", + APIKey: "", + }, + + // Cerebras - https://inference.cerebras.ai/ + { + ModelName: "cerebras-llama-3.3-70b", + Model: "cerebras/llama-3.3-70b", + APIBase: "https://api.cerebras.ai/v1", + APIKey: "", + }, + + // Volcengine (火山引擎) - https://console.volcengine.com/ark + { + ModelName: "doubao-pro", + Model: "volcengine/doubao-pro-32k", + APIBase: "https://ark.cn-beijing.volces.com/api/v3", + APIKey: "", + }, + + // ShengsuanYun (神算云) + { + ModelName: "deepseek-v3", + Model: "shengsuanyun/deepseek-v3", + APIBase: "https://api.shengsuanyun.com/v1", + APIKey: "", + }, + + // Antigravity (Google Cloud Code Assist) - OAuth only + { + ModelName: "gemini-flash", + Model: "antigravity/gemini-3-flash", + AuthMethod: "oauth", + }, + + // GitHub Copilot - https://github.com/settings/tokens + { + ModelName: "copilot-gpt-5.2", + Model: "github-copilot/gpt-5.2", + APIBase: "http://localhost:4321", + AuthMethod: "oauth", + }, + + // Ollama (local) - https://ollama.com + { + ModelName: "llama3", + Model: "ollama/llama3", + APIBase: "http://localhost:11434/v1", + APIKey: "ollama", + }, + + // VLLM (local) - http://localhost:8000 + { + ModelName: "local-model", + Model: "vllm/custom-model", + APIBase: "http://localhost:8000/v1", + APIKey: "", + }, + }, + Gateway: GatewayConfig{ + Host: "0.0.0.0", + Port: 18790, + }, + Tools: ToolsConfig{ + Web: WebToolsConfig{ + Brave: BraveConfig{ + Enabled: 
false, + APIKey: "", + MaxResults: 5, + }, + DuckDuckGo: DuckDuckGoConfig{ + Enabled: true, + MaxResults: 5, + }, + Perplexity: PerplexityConfig{ + Enabled: false, + APIKey: "", + MaxResults: 5, + }, + }, + Cron: CronToolsConfig{ + ExecTimeoutMinutes: 5, + }, + Exec: ExecConfig{ + EnableDenyPatterns: true, + }, + Skills: SkillsToolsConfig{ + Registries: SkillsRegistriesConfig{ + ClawHub: ClawHubRegistryConfig{ + Enabled: true, + BaseURL: "https://clawhub.ai", + }, + }, + MaxConcurrentSearches: 2, + SearchCache: SearchCacheConfig{ + MaxSize: 50, + TTLSeconds: 300, + }, + }, + }, + Heartbeat: HeartbeatConfig{ + Enabled: true, + Interval: 30, + }, + Devices: DevicesConfig{ + Enabled: false, + MonitorUSB: true, + }, + } +} diff --git a/pkg/config/migration.go b/pkg/config/migration.go new file mode 100644 index 000000000..689e2312f --- /dev/null +++ b/pkg/config/migration.go @@ -0,0 +1,353 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package config + +import ( + "slices" + "strings" +) + +// buildModelWithProtocol constructs a model string with protocol prefix. +// If the model already contains a "/" (indicating it has a protocol prefix), it is returned as-is. +// Otherwise, the protocol prefix is added. +func buildModelWithProtocol(protocol, model string) string { + if strings.Contains(model, "/") { + // Model already has a protocol prefix, return as-is + return model + } + return protocol + "/" + model +} + +// providerMigrationConfig defines how to migrate a provider from old config to new format. 
+type providerMigrationConfig struct { + // providerNames are the possible names used in agents.defaults.provider + providerNames []string + // protocol is the protocol prefix for the model field + protocol string + // buildConfig creates the ModelConfig from ProviderConfig + buildConfig func(p ProvidersConfig) (ModelConfig, bool) +} + +// ConvertProvidersToModelList converts the old ProvidersConfig to a slice of ModelConfig. +// This enables backward compatibility with existing configurations. +// It preserves the user's configured model from agents.defaults.model when possible. +func ConvertProvidersToModelList(cfg *Config) []ModelConfig { + if cfg == nil { + return nil + } + + // Get user's configured provider and model + userProvider := strings.ToLower(cfg.Agents.Defaults.Provider) + userModel := cfg.Agents.Defaults.Model + + p := cfg.Providers + + var result []ModelConfig + + // Track if we've applied the legacy model name fix (only for first provider) + legacyModelNameApplied := false + + // Define migration rules for each provider + migrations := []providerMigrationConfig{ + { + providerNames: []string{"openai", "gpt"}, + protocol: "openai", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.OpenAI.APIKey == "" && p.OpenAI.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "openai", + Model: "openai/gpt-5.2", + APIKey: p.OpenAI.APIKey, + APIBase: p.OpenAI.APIBase, + Proxy: p.OpenAI.Proxy, + AuthMethod: p.OpenAI.AuthMethod, + }, true + }, + }, + { + providerNames: []string{"anthropic", "claude"}, + protocol: "anthropic", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Anthropic.APIKey == "" && p.Anthropic.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "anthropic", + Model: "anthropic/claude-sonnet-4.6", + APIKey: p.Anthropic.APIKey, + APIBase: p.Anthropic.APIBase, + Proxy: p.Anthropic.Proxy, + AuthMethod: p.Anthropic.AuthMethod, + }, true + }, + }, + { 
+ providerNames: []string{"openrouter"}, + protocol: "openrouter", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.OpenRouter.APIKey == "" && p.OpenRouter.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "openrouter", + Model: "openrouter/auto", + APIKey: p.OpenRouter.APIKey, + APIBase: p.OpenRouter.APIBase, + Proxy: p.OpenRouter.Proxy, + }, true + }, + }, + { + providerNames: []string{"groq"}, + protocol: "groq", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Groq.APIKey == "" && p.Groq.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "groq", + Model: "groq/llama-3.1-70b-versatile", + APIKey: p.Groq.APIKey, + APIBase: p.Groq.APIBase, + Proxy: p.Groq.Proxy, + }, true + }, + }, + { + providerNames: []string{"zhipu", "glm"}, + protocol: "zhipu", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Zhipu.APIKey == "" && p.Zhipu.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "zhipu", + Model: "zhipu/glm-4", + APIKey: p.Zhipu.APIKey, + APIBase: p.Zhipu.APIBase, + Proxy: p.Zhipu.Proxy, + }, true + }, + }, + { + providerNames: []string{"vllm"}, + protocol: "vllm", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.VLLM.APIKey == "" && p.VLLM.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "vllm", + Model: "vllm/auto", + APIKey: p.VLLM.APIKey, + APIBase: p.VLLM.APIBase, + Proxy: p.VLLM.Proxy, + }, true + }, + }, + { + providerNames: []string{"gemini", "google"}, + protocol: "gemini", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Gemini.APIKey == "" && p.Gemini.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "gemini", + Model: "gemini/gemini-pro", + APIKey: p.Gemini.APIKey, + APIBase: p.Gemini.APIBase, + Proxy: p.Gemini.Proxy, + }, true + }, + }, + { + providerNames: []string{"nvidia"}, + protocol: 
"nvidia", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Nvidia.APIKey == "" && p.Nvidia.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "nvidia", + Model: "nvidia/meta/llama-3.1-8b-instruct", + APIKey: p.Nvidia.APIKey, + APIBase: p.Nvidia.APIBase, + Proxy: p.Nvidia.Proxy, + }, true + }, + }, + { + providerNames: []string{"ollama"}, + protocol: "ollama", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Ollama.APIKey == "" && p.Ollama.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "ollama", + Model: "ollama/llama3", + APIKey: p.Ollama.APIKey, + APIBase: p.Ollama.APIBase, + Proxy: p.Ollama.Proxy, + }, true + }, + }, + { + providerNames: []string{"moonshot", "kimi"}, + protocol: "moonshot", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Moonshot.APIKey == "" && p.Moonshot.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "moonshot", + Model: "moonshot/kimi", + APIKey: p.Moonshot.APIKey, + APIBase: p.Moonshot.APIBase, + Proxy: p.Moonshot.Proxy, + }, true + }, + }, + { + providerNames: []string{"shengsuanyun"}, + protocol: "shengsuanyun", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.ShengSuanYun.APIKey == "" && p.ShengSuanYun.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "shengsuanyun", + Model: "shengsuanyun/auto", + APIKey: p.ShengSuanYun.APIKey, + APIBase: p.ShengSuanYun.APIBase, + Proxy: p.ShengSuanYun.Proxy, + }, true + }, + }, + { + providerNames: []string{"deepseek"}, + protocol: "deepseek", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.DeepSeek.APIKey == "" && p.DeepSeek.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "deepseek", + Model: "deepseek/deepseek-chat", + APIKey: p.DeepSeek.APIKey, + APIBase: p.DeepSeek.APIBase, + Proxy: p.DeepSeek.Proxy, + }, true + }, + }, + { + 
providerNames: []string{"cerebras"}, + protocol: "cerebras", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Cerebras.APIKey == "" && p.Cerebras.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "cerebras", + Model: "cerebras/llama-3.3-70b", + APIKey: p.Cerebras.APIKey, + APIBase: p.Cerebras.APIBase, + Proxy: p.Cerebras.Proxy, + }, true + }, + }, + { + providerNames: []string{"volcengine", "doubao"}, + protocol: "volcengine", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.VolcEngine.APIKey == "" && p.VolcEngine.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "volcengine", + Model: "volcengine/doubao-pro", + APIKey: p.VolcEngine.APIKey, + APIBase: p.VolcEngine.APIBase, + Proxy: p.VolcEngine.Proxy, + }, true + }, + }, + { + providerNames: []string{"github_copilot", "copilot"}, + protocol: "github-copilot", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" && p.GitHubCopilot.ConnectMode == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "github-copilot", + Model: "github-copilot/gpt-5.2", + APIBase: p.GitHubCopilot.APIBase, + ConnectMode: p.GitHubCopilot.ConnectMode, + }, true + }, + }, + { + providerNames: []string{"antigravity"}, + protocol: "antigravity", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Antigravity.APIKey == "" && p.Antigravity.AuthMethod == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "antigravity", + Model: "antigravity/gemini-2.0-flash", + APIKey: p.Antigravity.APIKey, + AuthMethod: p.Antigravity.AuthMethod, + }, true + }, + }, + { + providerNames: []string{"qwen", "tongyi"}, + protocol: "qwen", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Qwen.APIKey == "" && p.Qwen.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "qwen", + 
Model: "qwen/qwen-max", + APIKey: p.Qwen.APIKey, + APIBase: p.Qwen.APIBase, + Proxy: p.Qwen.Proxy, + }, true + }, + }, + } + + // Process each provider migration + for _, m := range migrations { + mc, ok := m.buildConfig(p) + if !ok { + continue + } + + // Check if this is the user's configured provider + if slices.Contains(m.providerNames, userProvider) && userModel != "" { + // Use the user's configured model instead of default + mc.Model = buildModelWithProtocol(m.protocol, userModel) + } else if userProvider == "" && userModel != "" && !legacyModelNameApplied { + // Legacy config: no explicit provider field but model is specified + // Use userModel as ModelName for the FIRST provider so GetModelConfig(model) can find it + // This maintains backward compatibility with old configs that relied on implicit provider selection + mc.ModelName = userModel + mc.Model = buildModelWithProtocol(m.protocol, userModel) + legacyModelNameApplied = true + } + + result = append(result, mc) + } + + return result +} diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go new file mode 100644 index 000000000..1e8139e68 --- /dev/null +++ b/pkg/config/migration_test.go @@ -0,0 +1,561 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package config + +import ( + "strings" + "testing" +) + +func TestConvertProvidersToModelList_OpenAI(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + OpenAI: OpenAIProviderConfig{ + ProviderConfig: ProviderConfig{ + APIKey: "sk-test-key", + APIBase: "https://custom.api.com/v1", + }, + }, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].ModelName != "openai" { + t.Errorf("ModelName = %q, want %q", result[0].ModelName, "openai") + } + if result[0].Model != "openai/gpt-5.2" { + t.Errorf("Model = %q, want %q", result[0].Model, "openai/gpt-5.2") + } + 
if result[0].APIKey != "sk-test-key" { + t.Errorf("APIKey = %q, want %q", result[0].APIKey, "sk-test-key") + } +} + +func TestConvertProvidersToModelList_Anthropic(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + Anthropic: ProviderConfig{ + APIKey: "ant-key", + APIBase: "https://custom.anthropic.com", + }, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].ModelName != "anthropic" { + t.Errorf("ModelName = %q, want %q", result[0].ModelName, "anthropic") + } + if result[0].Model != "anthropic/claude-sonnet-4.6" { + t.Errorf("Model = %q, want %q", result[0].Model, "anthropic/claude-sonnet-4.6") + } +} + +func TestConvertProvidersToModelList_Multiple(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "openai-key"}}, + Groq: ProviderConfig{APIKey: "groq-key"}, + Zhipu: ProviderConfig{APIKey: "zhipu-key"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 3 { + t.Fatalf("len(result) = %d, want 3", len(result)) + } + + // Check that all providers are present + found := make(map[string]bool) + for _, mc := range result { + found[mc.ModelName] = true + } + + for _, name := range []string{"openai", "groq", "zhipu"} { + if !found[name] { + t.Errorf("Missing provider %q in result", name) + } + } +} + +func TestConvertProvidersToModelList_Empty(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{}, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } +} + +func TestConvertProvidersToModelList_Nil(t *testing.T) { + result := ConvertProvidersToModelList(nil) + + if result != nil { + t.Errorf("result = %v, want nil", result) + } +} + +func TestConvertProvidersToModelList_AllProviders(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + OpenAI: 
OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "key1"}}, + Anthropic: ProviderConfig{APIKey: "key2"}, + OpenRouter: ProviderConfig{APIKey: "key3"}, + Groq: ProviderConfig{APIKey: "key4"}, + Zhipu: ProviderConfig{APIKey: "key5"}, + VLLM: ProviderConfig{APIKey: "key6"}, + Gemini: ProviderConfig{APIKey: "key7"}, + Nvidia: ProviderConfig{APIKey: "key8"}, + Ollama: ProviderConfig{APIKey: "key9"}, + Moonshot: ProviderConfig{APIKey: "key10"}, + ShengSuanYun: ProviderConfig{APIKey: "key11"}, + DeepSeek: ProviderConfig{APIKey: "key12"}, + Cerebras: ProviderConfig{APIKey: "key13"}, + VolcEngine: ProviderConfig{APIKey: "key14"}, + GitHubCopilot: ProviderConfig{ConnectMode: "grpc"}, + Antigravity: ProviderConfig{AuthMethod: "oauth"}, + Qwen: ProviderConfig{APIKey: "key17"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + // All 17 providers should be converted + if len(result) != 17 { + t.Errorf("len(result) = %d, want 17", len(result)) + } +} + +func TestConvertProvidersToModelList_Proxy(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + OpenAI: OpenAIProviderConfig{ + ProviderConfig: ProviderConfig{ + APIKey: "key", + Proxy: "http://proxy:8080", + }, + }, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].Proxy != "http://proxy:8080" { + t.Errorf("Proxy = %q, want %q", result[0].Proxy, "http://proxy:8080") + } +} + +func TestConvertProvidersToModelList_AuthMethod(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + OpenAI: OpenAIProviderConfig{ + ProviderConfig: ProviderConfig{ + AuthMethod: "oauth", + }, + }, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0 (AuthMethod alone should not create entry)", len(result)) + } +} + +// Tests for preserving user's configured model during migration + +func 
TestConvertProvidersToModelList_PreservesUserModel_DeepSeek(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "deepseek", + Model: "deepseek-reasoner", + }, + }, + Providers: ProvidersConfig{ + DeepSeek: ProviderConfig{APIKey: "sk-deepseek"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + // Should use user's model, not default + if result[0].Model != "deepseek/deepseek-reasoner" { + t.Errorf("Model = %q, want %q (user's configured model)", result[0].Model, "deepseek/deepseek-reasoner") + } +} + +func TestConvertProvidersToModelList_PreservesUserModel_OpenAI(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "openai", + Model: "gpt-4-turbo", + }, + }, + Providers: ProvidersConfig{ + OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "sk-openai"}}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].Model != "openai/gpt-4-turbo" { + t.Errorf("Model = %q, want %q", result[0].Model, "openai/gpt-4-turbo") + } +} + +func TestConvertProvidersToModelList_PreservesUserModel_Anthropic(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "claude", // alternative name + Model: "claude-opus-4-20250514", + }, + }, + Providers: ProvidersConfig{ + Anthropic: ProviderConfig{APIKey: "sk-ant"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].Model != "anthropic/claude-opus-4-20250514" { + t.Errorf("Model = %q, want %q", result[0].Model, "anthropic/claude-opus-4-20250514") + } +} + +func TestConvertProvidersToModelList_PreservesUserModel_Qwen(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + 
Provider: "qwen", + Model: "qwen-plus", + }, + }, + Providers: ProvidersConfig{ + Qwen: ProviderConfig{APIKey: "sk-qwen"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].Model != "qwen/qwen-plus" { + t.Errorf("Model = %q, want %q", result[0].Model, "qwen/qwen-plus") + } +} + +func TestConvertProvidersToModelList_UsesDefaultWhenNoUserModel(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "deepseek", + Model: "", // no model specified + }, + }, + Providers: ProvidersConfig{ + DeepSeek: ProviderConfig{APIKey: "sk-deepseek"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + // Should use default model + if result[0].Model != "deepseek/deepseek-chat" { + t.Errorf("Model = %q, want %q (default)", result[0].Model, "deepseek/deepseek-chat") + } +} + +func TestConvertProvidersToModelList_MultipleProviders_PreservesUserModel(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "deepseek", + Model: "deepseek-reasoner", + }, + }, + Providers: ProvidersConfig{ + OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "sk-openai"}}, + DeepSeek: ProviderConfig{APIKey: "sk-deepseek"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 2 { + t.Fatalf("len(result) = %d, want 2", len(result)) + } + + // Find each provider and verify model + for _, mc := range result { + switch mc.ModelName { + case "openai": + if mc.Model != "openai/gpt-5.2" { + t.Errorf("OpenAI Model = %q, want %q (default)", mc.Model, "openai/gpt-5.2") + } + case "deepseek": + if mc.Model != "deepseek/deepseek-reasoner" { + t.Errorf("DeepSeek Model = %q, want %q (user's)", mc.Model, "deepseek/deepseek-reasoner") + } + } + } +} + +func TestConvertProvidersToModelList_ProviderNameAliases(t 
*testing.T) { + tests := []struct { + providerAlias string + expectedModel string + provider ProviderConfig + }{ + {"gpt", "openai/gpt-4-custom", ProviderConfig{APIKey: "key"}}, + {"claude", "anthropic/claude-custom", ProviderConfig{APIKey: "key"}}, + {"doubao", "volcengine/doubao-custom", ProviderConfig{APIKey: "key"}}, + {"tongyi", "qwen/qwen-custom", ProviderConfig{APIKey: "key"}}, + {"kimi", "moonshot/kimi-custom", ProviderConfig{APIKey: "key"}}, + } + + for _, tt := range tests { + t.Run(tt.providerAlias, func(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: tt.providerAlias, + Model: strings.TrimPrefix( + tt.expectedModel, + tt.expectedModel[:strings.Index(tt.expectedModel, "/")+1], + ), + }, + }, + Providers: ProvidersConfig{}, + } + + // Set the appropriate provider config + switch tt.providerAlias { + case "gpt": + cfg.Providers.OpenAI = OpenAIProviderConfig{ProviderConfig: tt.provider} + case "claude": + cfg.Providers.Anthropic = tt.provider + case "doubao": + cfg.Providers.VolcEngine = tt.provider + case "tongyi": + cfg.Providers.Qwen = tt.provider + case "kimi": + cfg.Providers.Moonshot = tt.provider + } + + // Need to fix the model name in config + cfg.Agents.Defaults.Model = strings.TrimPrefix( + tt.expectedModel, + tt.expectedModel[:strings.Index(tt.expectedModel, "/")+1], + ) + + result := ConvertProvidersToModelList(cfg) + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + // Extract just the model ID part (after the first /) + expectedModelID := tt.expectedModel + if result[0].Model != expectedModelID { + t.Errorf("Model = %q, want %q", result[0].Model, expectedModelID) + } + }) + } +} + +// Test for backward compatibility: single provider without explicit provider field +// This matches the legacy config pattern where users only set model, not provider + +func TestConvertProvidersToModelList_NoProviderField_SingleProvider(t *testing.T) { + // This matches the user's 
actual config: + // - No provider field set + // - model = "glm-4.7" + // - Only zhipu has API key configured + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "", // Not set + Model: "glm-4.7", + }, + }, + Providers: ProvidersConfig{ + Zhipu: ProviderConfig{APIKey: "test-zhipu-key"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + // ModelName should be the user's model value for backward compatibility + if result[0].ModelName != "glm-4.7" { + t.Errorf("ModelName = %q, want %q (user's model for backward compatibility)", result[0].ModelName, "glm-4.7") + } + + // Model should use the user's model with protocol prefix + if result[0].Model != "zhipu/glm-4.7" { + t.Errorf("Model = %q, want %q", result[0].Model, "zhipu/glm-4.7") + } +} + +func TestConvertProvidersToModelList_NoProviderField_MultipleProviders(t *testing.T) { + // When multiple providers are configured but no provider field is set, + // the FIRST provider (in migration order) will use userModel as ModelName + // for backward compatibility with legacy implicit provider selection + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "", // Not set + Model: "some-model", + }, + }, + Providers: ProvidersConfig{ + OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "openai-key"}}, + Zhipu: ProviderConfig{APIKey: "zhipu-key"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 2 { + t.Fatalf("len(result) = %d, want 2", len(result)) + } + + // The first provider (OpenAI in migration order) should use userModel as ModelName + // This ensures GetModelConfig("some-model") will find it + if result[0].ModelName != "some-model" { + t.Errorf("First provider ModelName = %q, want %q", result[0].ModelName, "some-model") + } + + // Other providers should use provider name as ModelName + if result[1].ModelName != "zhipu" { + 
t.Errorf("Second provider ModelName = %q, want %q", result[1].ModelName, "zhipu") + } +} + +func TestConvertProvidersToModelList_NoProviderField_NoModel(t *testing.T) { + // Edge case: no provider, no model + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "", + Model: "", + }, + }, + Providers: ProvidersConfig{ + Zhipu: ProviderConfig{APIKey: "zhipu-key"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + // Should use default provider name since no model is specified + if result[0].ModelName != "zhipu" { + t.Errorf("ModelName = %q, want %q", result[0].ModelName, "zhipu") + } +} + +// Tests for buildModelWithProtocol helper function + +func TestBuildModelWithProtocol_NoPrefix(t *testing.T) { + result := buildModelWithProtocol("openai", "gpt-5.2") + if result != "openai/gpt-5.2" { + t.Errorf("buildModelWithProtocol(openai, gpt-5.2) = %q, want %q", result, "openai/gpt-5.2") + } +} + +func TestBuildModelWithProtocol_AlreadyHasPrefix(t *testing.T) { + result := buildModelWithProtocol("openrouter", "openrouter/auto") + if result != "openrouter/auto" { + t.Errorf("buildModelWithProtocol(openrouter, openrouter/auto) = %q, want %q", result, "openrouter/auto") + } +} + +func TestBuildModelWithProtocol_DifferentPrefix(t *testing.T) { + result := buildModelWithProtocol("anthropic", "openrouter/claude-sonnet-4.6") + if result != "openrouter/claude-sonnet-4.6" { + t.Errorf( + "buildModelWithProtocol(anthropic, openrouter/claude-sonnet-4.6) = %q, want %q", + result, + "openrouter/claude-sonnet-4.6", + ) + } +} + +// Test for legacy config with protocol prefix in model name +func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "", // No explicit provider + Model: "openrouter/auto", // Model already has protocol prefix + }, + }, + Providers: 
ProvidersConfig{ + OpenRouter: ProviderConfig{APIKey: "sk-or-test"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) < 1 { + t.Fatalf("len(result) = %d, want at least 1", len(result)) + } + + // First provider should use userModel as ModelName for backward compatibility + if result[0].ModelName != "openrouter/auto" { + t.Errorf("ModelName = %q, want %q", result[0].ModelName, "openrouter/auto") + } + + // Model should NOT have duplicated prefix + if result[0].Model != "openrouter/auto" { + t.Errorf("Model = %q, want %q (should not duplicate prefix)", result[0].Model, "openrouter/auto") + } +} diff --git a/pkg/config/model_config_test.go b/pkg/config/model_config_test.go new file mode 100644 index 000000000..3c411dc0f --- /dev/null +++ b/pkg/config/model_config_test.go @@ -0,0 +1,235 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package config + +import ( + "strings" + "sync" + "testing" +) + +func TestGetModelConfig_Found(t *testing.T) { + cfg := &Config{ + ModelList: []ModelConfig{ + {ModelName: "test-model", Model: "openai/gpt-4o", APIKey: "key1"}, + {ModelName: "other-model", Model: "anthropic/claude", APIKey: "key2"}, + }, + } + + result, err := cfg.GetModelConfig("test-model") + if err != nil { + t.Fatalf("GetModelConfig() error = %v", err) + } + if result.Model != "openai/gpt-4o" { + t.Errorf("Model = %q, want %q", result.Model, "openai/gpt-4o") + } +} + +func TestGetModelConfig_NotFound(t *testing.T) { + cfg := &Config{ + ModelList: []ModelConfig{ + {ModelName: "test-model", Model: "openai/gpt-4o", APIKey: "key1"}, + }, + } + + _, err := cfg.GetModelConfig("nonexistent") + if err == nil { + t.Fatal("GetModelConfig() expected error for nonexistent model") + } +} + +func TestGetModelConfig_EmptyList(t *testing.T) { + cfg := &Config{ + ModelList: []ModelConfig{}, + } + + _, err := cfg.GetModelConfig("any-model") + if err == nil { + t.Fatal("GetModelConfig() 
expected error for empty model list") + } +} + +func TestGetModelConfig_RoundRobin(t *testing.T) { + cfg := &Config{ + ModelList: []ModelConfig{ + {ModelName: "lb-model", Model: "openai/gpt-4o-1", APIKey: "key1"}, + {ModelName: "lb-model", Model: "openai/gpt-4o-2", APIKey: "key2"}, + {ModelName: "lb-model", Model: "openai/gpt-4o-3", APIKey: "key3"}, + }, + } + + // Test round-robin distribution + results := make(map[string]int) + for i := 0; i < 30; i++ { + result, err := cfg.GetModelConfig("lb-model") + if err != nil { + t.Fatalf("GetModelConfig() error = %v", err) + } + results[result.Model]++ + } + + // Each model should appear roughly 10 times (30 calls / 3 models) + for model, count := range results { + if count < 5 || count > 15 { + t.Errorf("Model %s appeared %d times, expected ~10", model, count) + } + } +} + +func TestGetModelConfig_Concurrent(t *testing.T) { + cfg := &Config{ + ModelList: []ModelConfig{ + {ModelName: "concurrent-model", Model: "openai/gpt-4o-1", APIKey: "key1"}, + {ModelName: "concurrent-model", Model: "openai/gpt-4o-2", APIKey: "key2"}, + }, + } + + const goroutines = 100 + const iterations = 10 + + var wg sync.WaitGroup + errors := make(chan error, goroutines*iterations) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < iterations; j++ { + _, err := cfg.GetModelConfig("concurrent-model") + if err != nil { + errors <- err + } + } + }() + } + + wg.Wait() + close(errors) + + for err := range errors { + t.Errorf("Concurrent GetModelConfig() error: %v", err) + } +} + +func TestModelConfig_Validate(t *testing.T) { + tests := []struct { + name string + config ModelConfig + wantErr bool + }{ + { + name: "valid config", + config: ModelConfig{ + ModelName: "test", + Model: "openai/gpt-4o", + }, + wantErr: false, + }, + { + name: "missing model_name", + config: ModelConfig{ + Model: "openai/gpt-4o", + }, + wantErr: true, + }, + { + name: "missing model", + config: ModelConfig{ + ModelName: "test", 
+ }, + wantErr: true, + }, + { + name: "empty config", + config: ModelConfig{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestConfig_ValidateModelList(t *testing.T) { + tests := []struct { + name string + config *Config + wantErr bool + errMsg string // partial error message to check + }{ + { + name: "valid list", + config: &Config{ + ModelList: []ModelConfig{ + {ModelName: "test1", Model: "openai/gpt-4o"}, + {ModelName: "test2", Model: "anthropic/claude"}, + }, + }, + wantErr: false, + }, + { + name: "invalid entry", + config: &Config{ + ModelList: []ModelConfig{ + {ModelName: "test1", Model: "openai/gpt-4o"}, + {ModelName: "", Model: "anthropic/claude"}, // missing model_name + }, + }, + wantErr: true, + errMsg: "model_name is required", + }, + { + name: "empty list", + config: &Config{ + ModelList: []ModelConfig{}, + }, + wantErr: false, + }, + { + // Load balancing: multiple entries with same model_name are allowed + name: "duplicate model_name for load balancing", + config: &Config{ + ModelList: []ModelConfig{ + {ModelName: "gpt-4", Model: "openai/gpt-4o", APIKey: "key1"}, + {ModelName: "gpt-4", Model: "openai/gpt-4-turbo", APIKey: "key2"}, + }, + }, + wantErr: false, // Changed: duplicates are allowed for load balancing + }, + { + // Load balancing: non-adjacent entries with same model_name are also allowed + name: "duplicate model_name non-adjacent for load balancing", + config: &Config{ + ModelList: []ModelConfig{ + {ModelName: "model-a", Model: "openai/gpt-4o"}, + {ModelName: "model-b", Model: "anthropic/claude"}, + {ModelName: "model-a", Model: "openai/gpt-4-turbo"}, + }, + }, + wantErr: false, // Changed: duplicates are allowed for load balancing + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := 
tt.config.ValidateModelList() + if (err != nil) != tt.wantErr { + t.Errorf("ValidateModelList() error = %v, wantErr %v", err, tt.wantErr) + } + if err != nil && tt.errMsg != "" { + if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("ValidateModelList() error = %v, want error containing %q", err, tt.errMsg) + } + } + }) + } +} diff --git a/pkg/constants/channels.go b/pkg/constants/channels.go index 3e3df3839..0a46e6cd9 100644 --- a/pkg/constants/channels.go +++ b/pkg/constants/channels.go @@ -1,15 +1,16 @@ // Package constants provides shared constants across the codebase. package constants -// InternalChannels defines channels that are used for internal communication +// internalChannels defines channels that are used for internal communication // and should not be exposed to external users or recorded as last active channel. -var InternalChannels = map[string]bool{ - "cli": true, - "system": true, - "subagent": true, +var internalChannels = map[string]struct{}{ + "cli": {}, + "system": {}, + "subagent": {}, } // IsInternalChannel returns true if the channel is an internal channel. 
func IsInternalChannel(channel string) bool { - return InternalChannels[channel] + _, found := internalChannels[channel] + return found } diff --git a/pkg/cron/service.go b/pkg/cron/service.go index 841db0ff6..e699a44b5 100644 --- a/pkg/cron/service.go +++ b/pkg/cron/service.go @@ -71,7 +71,6 @@ func NewCronService(storePath string, onJob JobHandler) *CronService { cs := &CronService{ storePath: storePath, onJob: onJob, - stopChan: make(chan struct{}), gronx: gronx.New(), } // Initialize and load store on creation @@ -96,8 +95,9 @@ func (cs *CronService) Start() error { return fmt.Errorf("failed to save store: %w", err) } + cs.stopChan = make(chan struct{}) cs.running = true - go cs.runLoop() + go cs.runLoop(cs.stopChan) return nil } @@ -111,16 +111,19 @@ func (cs *CronService) Stop() { } cs.running = false - close(cs.stopChan) + if cs.stopChan != nil { + close(cs.stopChan) + cs.stopChan = nil + } } -func (cs *CronService) runLoop() { +func (cs *CronService) runLoop(stopChan chan struct{}) { ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() for { select { - case <-cs.stopChan: + case <-stopChan: return case <-ticker.C: cs.checkJobs() @@ -137,27 +140,23 @@ func (cs *CronService) checkJobs() { } now := time.Now().UnixMilli() - var dueJobs []*CronJob + var dueJobIDs []string // Collect jobs that are due (we need to copy them to execute outside lock) for i := range cs.store.Jobs { job := &cs.store.Jobs[i] if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now { - // Create a shallow copy of the job for execution - jobCopy := *job - dueJobs = append(dueJobs, &jobCopy) + dueJobIDs = append(dueJobIDs, job.ID) } } - // Update next run times for due jobs immediately (before executing) - // Use map for O(n) lookup instead of O(n²) nested loop - dueMap := make(map[string]bool, len(dueJobs)) - for _, job := range dueJobs { - dueMap[job.ID] = true + // Reset next run for due jobs before unlocking to avoid duplicate execution. 
+ dueMap := make(map[string]bool, len(dueJobIDs)) + for _, jobID := range dueJobIDs { + dueMap[jobID] = true } for i := range cs.store.Jobs { if dueMap[cs.store.Jobs[i].ID] { - // Reset NextRunAtMS temporarily so we don't re-execute cs.store.Jobs[i].State.NextRunAtMS = nil } } @@ -168,52 +167,74 @@ func (cs *CronService) checkJobs() { cs.mu.Unlock() - // Execute jobs outside the lock - for _, job := range dueJobs { - cs.executeJob(job) + // Execute jobs outside lock. + for _, jobID := range dueJobIDs { + cs.executeJobByID(jobID) } } -func (cs *CronService) executeJob(job *CronJob) { +func (cs *CronService) executeJobByID(jobID string) { startTime := time.Now().UnixMilli() + cs.mu.RLock() + var callbackJob *CronJob + for i := range cs.store.Jobs { + job := &cs.store.Jobs[i] + if job.ID == jobID { + jobCopy := *job + callbackJob = &jobCopy + break + } + } + cs.mu.RUnlock() + + if callbackJob == nil { + return + } + var err error if cs.onJob != nil { - _, err = cs.onJob(job) + _, err = cs.onJob(callbackJob) } // Now acquire lock to update state cs.mu.Lock() defer cs.mu.Unlock() - // Find the job in store and update it + var job *CronJob for i := range cs.store.Jobs { - if cs.store.Jobs[i].ID == job.ID { - cs.store.Jobs[i].State.LastRunAtMS = &startTime - cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli() + if cs.store.Jobs[i].ID == jobID { + job = &cs.store.Jobs[i] + break + } + } + if job == nil { + log.Printf("[cron] job %s disappeared before state update", jobID) + return + } - if err != nil { - cs.store.Jobs[i].State.LastStatus = "error" - cs.store.Jobs[i].State.LastError = err.Error() - } else { - cs.store.Jobs[i].State.LastStatus = "ok" - cs.store.Jobs[i].State.LastError = "" - } + job.State.LastRunAtMS = &startTime + job.UpdatedAtMS = time.Now().UnixMilli() - // Compute next run time - if cs.store.Jobs[i].Schedule.Kind == "at" { - if cs.store.Jobs[i].DeleteAfterRun { - cs.removeJobUnsafe(job.ID) - } else { - cs.store.Jobs[i].Enabled = false - 
cs.store.Jobs[i].State.NextRunAtMS = nil - } - } else { - nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli()) - cs.store.Jobs[i].State.NextRunAtMS = nextRun - } - break + if err != nil { + job.State.LastStatus = "error" + job.State.LastError = err.Error() + } else { + job.State.LastStatus = "ok" + job.State.LastError = "" + } + + // Compute next run time + if job.Schedule.Kind == "at" { + if job.DeleteAfterRun { + cs.removeJobUnsafe(job.ID) + } else { + job.Enabled = false + job.State.NextRunAtMS = nil } + } else { + nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli()) + job.State.NextRunAtMS = nextRun } if err := cs.saveStoreUnsafe(); err != nil { @@ -310,7 +331,7 @@ func (cs *CronService) loadStore() error { func (cs *CronService) saveStoreUnsafe() error { dir := filepath.Dir(cs.storePath) - if err := os.MkdirAll(dir, 0755); err != nil { + if err := os.MkdirAll(dir, 0o755); err != nil { return err } @@ -319,10 +340,16 @@ func (cs *CronService) saveStoreUnsafe() error { return err } - return os.WriteFile(cs.storePath, data, 0644) + return os.WriteFile(cs.storePath, data, 0o600) } -func (cs *CronService) AddJob(name string, schedule CronSchedule, message string, deliver bool, channel, to string) (*CronJob, error) { +func (cs *CronService) AddJob( + name string, + schedule CronSchedule, + message string, + deliver bool, + channel, to string, +) (*CronJob, error) { cs.mu.Lock() defer cs.mu.Unlock() @@ -444,7 +471,7 @@ func (cs *CronService) ListJobs(includeDisabled bool) []CronJob { return enabled } -func (cs *CronService) Status() map[string]interface{} { +func (cs *CronService) Status() map[string]any { cs.mu.RLock() defer cs.mu.RUnlock() @@ -455,7 +482,7 @@ func (cs *CronService) Status() map[string]interface{} { } } - return map[string]interface{}{ + return map[string]any{ "enabled": cs.running, "jobs": len(cs.store.Jobs), "nextWakeAtMS": cs.getNextWakeMS(), diff --git a/pkg/cron/service_test.go 
b/pkg/cron/service_test.go new file mode 100644 index 000000000..1a0dd1829 --- /dev/null +++ b/pkg/cron/service_test.go @@ -0,0 +1,38 @@ +package cron + +import ( + "os" + "path/filepath" + "runtime" + "testing" +) + +func TestSaveStore_FilePermissions(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("file permission bits are not enforced on Windows") + } + + tmpDir := t.TempDir() + storePath := filepath.Join(tmpDir, "cron", "jobs.json") + + cs := NewCronService(storePath, nil) + + _, err := cs.AddJob("test", CronSchedule{Kind: "every", EveryMS: int64Ptr(60000)}, "hello", false, "cli", "direct") + if err != nil { + t.Fatalf("AddJob failed: %v", err) + } + + info, err := os.Stat(storePath) + if err != nil { + t.Fatalf("Stat failed: %v", err) + } + + perm := info.Mode().Perm() + if perm != 0o600 { + t.Errorf("cron store has permission %04o, want 0600", perm) + } +} + +func int64Ptr(v int64) *int64 { + return &v +} diff --git a/pkg/devices/events/events.go b/pkg/devices/events/events.go new file mode 100644 index 000000000..01226179c --- /dev/null +++ b/pkg/devices/events/events.go @@ -0,0 +1,57 @@ +package events + +import "context" + +type EventSource interface { + Kind() Kind + Start(ctx context.Context) (<-chan *DeviceEvent, error) + Stop() error +} + +type Action string + +const ( + ActionAdd Action = "add" + ActionRemove Action = "remove" + ActionChange Action = "change" +) + +type Kind string + +const ( + KindUSB Kind = "usb" + KindBluetooth Kind = "bluetooth" + KindPCI Kind = "pci" + KindGeneric Kind = "generic" +) + +type DeviceEvent struct { + Action Action + Kind Kind + DeviceID string // e.g. 
"1-2" for USB bus 1 dev 2 + Vendor string // Vendor name or ID + Product string // Product name or ID + Serial string // Serial number if available + Capabilities string // Human-readable capability description + Raw map[string]string // Raw properties for extensibility +} + +func (e *DeviceEvent) FormatMessage() string { + actionEmoji := "🔌" + actionText := "Connected" + if e.Action == ActionRemove { + actionEmoji = "🔌" + actionText = "Disconnected" + } + + msg := actionEmoji + " Device " + actionText + "\n\n" + msg += "Type: " + string(e.Kind) + "\n" + msg += "Device: " + e.Vendor + " " + e.Product + "\n" + if e.Capabilities != "" { + msg += "Capabilities: " + e.Capabilities + "\n" + } + if e.Serial != "" { + msg += "Serial: " + e.Serial + "\n" + } + return msg +} diff --git a/pkg/devices/service.go b/pkg/devices/service.go new file mode 100644 index 000000000..1541d3c57 --- /dev/null +++ b/pkg/devices/service.go @@ -0,0 +1,152 @@ +package devices + +import ( + "context" + "strings" + "sync" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/constants" + "github.com/sipeed/picoclaw/pkg/devices/events" + "github.com/sipeed/picoclaw/pkg/devices/sources" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/state" +) + +type Service struct { + bus *bus.MessageBus + state *state.Manager + sources []events.EventSource + enabled bool + ctx context.Context + cancel context.CancelFunc + mu sync.RWMutex +} + +type Config struct { + Enabled bool + MonitorUSB bool // When true, monitor USB hotplug (Linux only) + // Future: MonitorBluetooth, MonitorPCI, etc. 
+} + +func NewService(cfg Config, stateMgr *state.Manager) *Service { + s := &Service{ + state: stateMgr, + enabled: cfg.Enabled, + sources: make([]EventSource, 0), + } + + if cfg.Enabled && cfg.MonitorUSB { + s.sources = append(s.sources, sources.NewUSBMonitor()) + } + + return s +} + +func (s *Service) SetBus(msgBus *bus.MessageBus) { + s.mu.Lock() + defer s.mu.Unlock() + s.bus = msgBus +} + +func (s *Service) Start(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.enabled || len(s.sources) == 0 { + logger.InfoC("devices", "Device event service disabled or no sources") + return nil + } + + s.ctx, s.cancel = context.WithCancel(ctx) + + for _, src := range s.sources { + eventCh, err := src.Start(s.ctx) + if err != nil { + logger.ErrorCF("devices", "Failed to start source", map[string]any{ + "kind": src.Kind(), + "error": err.Error(), + }) + continue + } + go s.handleEvents(src.Kind(), eventCh) + logger.InfoCF("devices", "Device source started", map[string]any{ + "kind": src.Kind(), + }) + } + + logger.InfoC("devices", "Device event service started") + return nil +} + +func (s *Service) Stop() { + s.mu.Lock() + defer s.mu.Unlock() + + if s.cancel != nil { + s.cancel() + s.cancel = nil + } + + for _, src := range s.sources { + src.Stop() + } + + logger.InfoC("devices", "Device event service stopped") +} + +func (s *Service) handleEvents(kind events.Kind, eventCh <-chan *events.DeviceEvent) { + for ev := range eventCh { + if ev == nil { + continue + } + s.sendNotification(ev) + } +} + +func (s *Service) sendNotification(ev *events.DeviceEvent) { + s.mu.RLock() + msgBus := s.bus + s.mu.RUnlock() + + if msgBus == nil { + return + } + + lastChannel := s.state.GetLastChannel() + if lastChannel == "" { + logger.DebugCF("devices", "No last channel, skipping notification", map[string]any{ + "event": ev.FormatMessage(), + }) + return + } + + platform, userID := parseLastChannel(lastChannel) + if platform == "" || userID == "" || 
constants.IsInternalChannel(platform) { + return + } + + msg := ev.FormatMessage() + msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: platform, + ChatID: userID, + Content: msg, + }) + + logger.InfoCF("devices", "Device notification sent", map[string]any{ + "kind": ev.Kind, + "action": ev.Action, + "to": platform, + }) +} + +func parseLastChannel(lastChannel string) (platform, userID string) { + if lastChannel == "" { + return "", "" + } + parts := strings.SplitN(lastChannel, ":", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return "", "" + } + return parts[0], parts[1] +} diff --git a/pkg/devices/source.go b/pkg/devices/source.go new file mode 100644 index 000000000..cbf0a7d88 --- /dev/null +++ b/pkg/devices/source.go @@ -0,0 +1,5 @@ +package devices + +import "github.com/sipeed/picoclaw/pkg/devices/events" + +type EventSource = events.EventSource diff --git a/pkg/devices/sources/usb_linux.go b/pkg/devices/sources/usb_linux.go new file mode 100644 index 000000000..be0193cfb --- /dev/null +++ b/pkg/devices/sources/usb_linux.go @@ -0,0 +1,198 @@ +//go:build linux + +package sources + +import ( + "bufio" + "context" + "fmt" + "os/exec" + "strings" + "sync" + + "github.com/sipeed/picoclaw/pkg/devices/events" + "github.com/sipeed/picoclaw/pkg/logger" +) + +var usbClassToCapability = map[string]string{ + "00": "Interface Definition (by interface)", + "01": "Audio", + "02": "CDC Communication (Network Card/Modem)", + "03": "HID (Keyboard/Mouse/Gamepad)", + "05": "Physical Interface", + "06": "Image (Scanner/Camera)", + "07": "Printer", + "08": "Mass Storage (USB Flash Drive/Hard Disk)", + "09": "USB Hub", + "0a": "CDC Data", + "0b": "Smart Card", + "0e": "Video (Camera)", + "dc": "Diagnostic Device", + "e0": "Wireless Controller (Bluetooth)", + "ef": "Miscellaneous", + "fe": "Application Specific", + "ff": "Vendor Specific", +} + +type USBMonitor struct { + cmd *exec.Cmd + cancel context.CancelFunc + mu sync.Mutex +} + +func NewUSBMonitor() 
*USBMonitor { + return &USBMonitor{} +} + +func (m *USBMonitor) Kind() events.Kind { + return events.KindUSB +} + +func (m *USBMonitor) Start(ctx context.Context) (<-chan *events.DeviceEvent, error) { + m.mu.Lock() + defer m.mu.Unlock() + + // udevadm monitor outputs: UDEV/KERNEL [timestamp] action devpath (subsystem) + // Followed by KEY=value lines, empty line separates events + // Use -s/--subsystem-match (eudev) or --udev-subsystem-match (systemd udev) + cmd := exec.CommandContext(ctx, "udevadm", "monitor", "--property", "--subsystem-match=usb") + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("udevadm stdout pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("udevadm start: %w (is udevadm installed?)", err) + } + + m.cmd = cmd + eventCh := make(chan *events.DeviceEvent, 16) + + go func() { + defer close(eventCh) + scanner := bufio.NewScanner(stdout) + var props map[string]string + var action string + isUdev := false // Only UDEV events have complete info (ID_VENDOR, ID_MODEL); KERNEL events come first with less info + + for scanner.Scan() { + line := scanner.Text() + if line == "" { + // End of event block - only process UDEV events (skip KERNEL to avoid duplicate/incomplete notifications) + if isUdev && props != nil && (action == "add" || action == "remove") { + if ev := parseUSBEvent(action, props); ev != nil { + select { + case eventCh <- ev: + case <-ctx.Done(): + return + } + } + } + props = nil + action = "" + isUdev = false + continue + } + + idx := strings.Index(line, "=") + // First line of block: "UDEV [ts] action devpath" or "KERNEL[ts] action devpath" - no KEY=value + if idx <= 0 { + isUdev = strings.HasPrefix(strings.TrimSpace(line), "UDEV") + continue + } + + // Parse KEY=value + key := line[:idx] + val := line[idx+1:] + if props == nil { + props = make(map[string]string) + } + props[key] = val + + if key == "ACTION" { + action = val + } + } + + if err := scanner.Err(); err != nil { + 
logger.ErrorCF("devices", "udevadm scan error", map[string]any{"error": err.Error()}) + } + cmd.Wait() + }() + + return eventCh, nil +} + +func (m *USBMonitor) Stop() error { + m.mu.Lock() + defer m.mu.Unlock() + if m.cmd != nil && m.cmd.Process != nil { + m.cmd.Process.Kill() + m.cmd = nil + } + return nil +} + +func parseUSBEvent(action string, props map[string]string) *events.DeviceEvent { + // Only care about add/remove for physical devices (not interfaces) + subsystem := props["SUBSYSTEM"] + if subsystem != "usb" { + return nil + } + // Skip interface events - we want device-level only to avoid duplicates + devType := props["DEVTYPE"] + if devType == "usb_interface" { + return nil + } + // Prefer usb_device, but accept if DEVTYPE not set (varies by udev version) + if devType != "" && devType != "usb_device" { + return nil + } + + ev := &events.DeviceEvent{ + Raw: props, + } + switch action { + case "add": + ev.Action = events.ActionAdd + case "remove": + ev.Action = events.ActionRemove + default: + return nil + } + ev.Kind = events.KindUSB + + ev.Vendor = props["ID_VENDOR"] + if ev.Vendor == "" { + ev.Vendor = props["ID_VENDOR_ID"] + } + if ev.Vendor == "" { + ev.Vendor = "Unknown Vendor" + } + + ev.Product = props["ID_MODEL"] + if ev.Product == "" { + ev.Product = props["ID_MODEL_ID"] + } + if ev.Product == "" { + ev.Product = "Unknown Device" + } + + ev.Serial = props["ID_SERIAL_SHORT"] + ev.DeviceID = props["DEVPATH"] + if bus := props["BUSNUM"]; bus != "" { + if dev := props["DEVNUM"]; dev != "" { + ev.DeviceID = bus + ":" + dev + } + } + + // Map USB class to capability + if class := props["ID_USB_CLASS"]; class != "" { + ev.Capabilities = usbClassToCapability[strings.ToLower(class)] + } + if ev.Capabilities == "" { + ev.Capabilities = "USB Device" + } + + return ev +} diff --git a/pkg/devices/sources/usb_stub.go b/pkg/devices/sources/usb_stub.go new file mode 100644 index 000000000..f08c2d406 --- /dev/null +++ b/pkg/devices/sources/usb_stub.go @@ -0,0 
+1,29 @@ +//go:build !linux + +package sources + +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/devices/events" +) + +type USBMonitor struct{} + +func NewUSBMonitor() *USBMonitor { + return &USBMonitor{} +} + +func (m *USBMonitor) Kind() events.Kind { + return events.KindUSB +} + +func (m *USBMonitor) Start(ctx context.Context) (<-chan *events.DeviceEvent, error) { + ch := make(chan *events.DeviceEvent) + close(ch) // Immediately close, no events + return ch, nil +} + +func (m *USBMonitor) Stop() error { + return nil +} diff --git a/pkg/health/server.go b/pkg/health/server.go new file mode 100644 index 000000000..77b36034d --- /dev/null +++ b/pkg/health/server.go @@ -0,0 +1,164 @@ +package health + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "sync" + "time" +) + +type Server struct { + server *http.Server + mu sync.RWMutex + ready bool + checks map[string]Check + startTime time.Time +} + +type Check struct { + Name string `json:"name"` + Status string `json:"status"` + Message string `json:"message,omitempty"` + Timestamp time.Time `json:"timestamp"` +} + +type StatusResponse struct { + Status string `json:"status"` + Uptime string `json:"uptime"` + Checks map[string]Check `json:"checks,omitempty"` +} + +func NewServer(host string, port int) *Server { + mux := http.NewServeMux() + s := &Server{ + ready: false, + checks: make(map[string]Check), + startTime: time.Now(), + } + + mux.HandleFunc("/health", s.healthHandler) + mux.HandleFunc("/ready", s.readyHandler) + + addr := fmt.Sprintf("%s:%d", host, port) + s.server = &http.Server{ + Addr: addr, + Handler: mux, + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + } + + return s +} + +func (s *Server) Start() error { + s.mu.Lock() + s.ready = true + s.mu.Unlock() + return s.server.ListenAndServe() +} + +func (s *Server) StartContext(ctx context.Context) error { + s.mu.Lock() + s.ready = true + s.mu.Unlock() + + errCh := make(chan error, 1) + go func() { + errCh <- 
s.server.ListenAndServe() + }() + + select { + case err := <-errCh: + return err + case <-ctx.Done(): + return s.server.Shutdown(context.Background()) + } +} + +func (s *Server) Stop(ctx context.Context) error { + s.mu.Lock() + s.ready = false + s.mu.Unlock() + return s.server.Shutdown(ctx) +} + +func (s *Server) SetReady(ready bool) { + s.mu.Lock() + s.ready = ready + s.mu.Unlock() +} + +func (s *Server) RegisterCheck(name string, checkFn func() (bool, string)) { + s.mu.Lock() + defer s.mu.Unlock() + + status, msg := checkFn() + s.checks[name] = Check{ + Name: name, + Status: statusString(status), + Message: msg, + Timestamp: time.Now(), + } +} + +func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + uptime := time.Since(s.startTime) + resp := StatusResponse{ + Status: "ok", + Uptime: uptime.String(), + } + + json.NewEncoder(w).Encode(resp) +} + +func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + s.mu.RLock() + ready := s.ready + checks := make(map[string]Check) + for k, v := range s.checks { + checks[k] = v + } + s.mu.RUnlock() + + if !ready { + w.WriteHeader(http.StatusServiceUnavailable) + json.NewEncoder(w).Encode(StatusResponse{ + Status: "not ready", + Checks: checks, + }) + return + } + + for _, check := range checks { + if check.Status == "fail" { + w.WriteHeader(http.StatusServiceUnavailable) + json.NewEncoder(w).Encode(StatusResponse{ + Status: "not ready", + Checks: checks, + }) + return + } + } + + w.WriteHeader(http.StatusOK) + uptime := time.Since(s.startTime) + json.NewEncoder(w).Encode(StatusResponse{ + Status: "ready", + Uptime: uptime.String(), + Checks: checks, + }) +} + +func statusString(ok bool) string { + if ok { + return "ok" + } + return "fail" +} diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index a090cdac0..75d6248b9 100644 --- 
a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -40,7 +40,6 @@ type HeartbeatService struct { interval time.Duration enabled bool mu sync.RWMutex - started bool stopChan chan struct{} } @@ -60,7 +59,6 @@ func NewHeartbeatService(workspace string, intervalMinutes int, enabled bool) *H interval: time.Duration(intervalMinutes) * time.Minute, enabled: enabled, state: state.NewManager(workspace), - stopChan: make(chan struct{}), } } @@ -83,7 +81,7 @@ func (hs *HeartbeatService) Start() error { hs.mu.Lock() defer hs.mu.Unlock() - if hs.started { + if hs.stopChan != nil { logger.InfoC("heartbeat", "Heartbeat service already running") return nil } @@ -93,10 +91,8 @@ func (hs *HeartbeatService) Start() error { return nil } - hs.started = true hs.stopChan = make(chan struct{}) - - go hs.runLoop() + go hs.runLoop(hs.stopChan) logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{ "interval_minutes": hs.interval.Minutes(), @@ -110,24 +106,24 @@ func (hs *HeartbeatService) Stop() { hs.mu.Lock() defer hs.mu.Unlock() - if !hs.started { + if hs.stopChan == nil { return } logger.InfoC("heartbeat", "Stopping heartbeat service") close(hs.stopChan) - hs.started = false + hs.stopChan = nil } // IsRunning returns whether the service is running func (hs *HeartbeatService) IsRunning() bool { hs.mu.RLock() defer hs.mu.RUnlock() - return hs.started + return hs.stopChan != nil } // runLoop runs the heartbeat ticker -func (hs *HeartbeatService) runLoop() { +func (hs *HeartbeatService) runLoop(stopChan chan struct{}) { ticker := time.NewTicker(hs.interval) defer ticker.Stop() @@ -138,7 +134,7 @@ func (hs *HeartbeatService) runLoop() { for { select { - case <-hs.stopChan: + case <-stopChan: return case <-ticker.C: hs.executeHeartbeat() @@ -149,8 +145,12 @@ func (hs *HeartbeatService) runLoop() { // executeHeartbeat performs a single heartbeat check func (hs *HeartbeatService) executeHeartbeat() { hs.mu.RLock() - enabled := hs.enabled && hs.started + enabled := 
hs.enabled handler := hs.handler + if !hs.enabled || hs.stopChan == nil { + hs.mu.RUnlock() + return + } hs.mu.RUnlock() if !enabled { @@ -193,7 +193,7 @@ func (hs *HeartbeatService) executeHeartbeat() { if result.Async { hs.logInfo("Async task started: %s", result.ForLLM) logger.InfoCF("heartbeat", "Async heartbeat task started", - map[string]interface{}{ + map[string]any{ "message": result.ForLLM, }) return @@ -275,7 +275,7 @@ This file contains tasks for the heartbeat service to check periodically. Add your heartbeat tasks below this line: ` - if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0644); err != nil { + if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0o644); err != nil { hs.logError("Failed to create default HEARTBEAT.md: %v", err) } else { hs.logInfo("Created default HEARTBEAT.md template") @@ -354,7 +354,7 @@ func (hs *HeartbeatService) logError(format string, args ...any) { // log writes a message to the heartbeat log file func (hs *HeartbeatService) log(level, format string, args ...any) { logFile := filepath.Join(hs.workspace, "heartbeat.log") - f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { return } diff --git a/pkg/heartbeat/service_test.go b/pkg/heartbeat/service_test.go index d7aed15f1..a4dfa7a72 100644 --- a/pkg/heartbeat/service_test.go +++ b/pkg/heartbeat/service_test.go @@ -17,7 +17,7 @@ func TestExecuteHeartbeat_Async(t *testing.T) { defer os.RemoveAll(tmpDir) hs := NewHeartbeatService(tmpDir, 30, true) - hs.started = true // Enable for testing + hs.stopChan = make(chan struct{}) // Enable for testing asyncCalled := false asyncResult := &tools.ToolResult{ @@ -37,7 +37,7 @@ func TestExecuteHeartbeat_Async(t *testing.T) { }) // Create HEARTBEAT.md - os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test 
task"), 0o644) // Execute heartbeat directly (internal method for testing) hs.executeHeartbeat() @@ -55,7 +55,7 @@ func TestExecuteHeartbeat_Error(t *testing.T) { defer os.RemoveAll(tmpDir) hs := NewHeartbeatService(tmpDir, 30, true) - hs.started = true // Enable for testing + hs.stopChan = make(chan struct{}) // Enable for testing hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { return &tools.ToolResult{ @@ -68,7 +68,7 @@ func TestExecuteHeartbeat_Error(t *testing.T) { }) // Create HEARTBEAT.md - os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644) hs.executeHeartbeat() @@ -93,7 +93,7 @@ func TestExecuteHeartbeat_Silent(t *testing.T) { defer os.RemoveAll(tmpDir) hs := NewHeartbeatService(tmpDir, 30, true) - hs.started = true // Enable for testing + hs.stopChan = make(chan struct{}) // Enable for testing hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { return &tools.ToolResult{ @@ -106,7 +106,7 @@ func TestExecuteHeartbeat_Silent(t *testing.T) { }) // Create HEARTBEAT.md - os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644) hs.executeHeartbeat() @@ -167,14 +167,14 @@ func TestExecuteHeartbeat_NilResult(t *testing.T) { defer os.RemoveAll(tmpDir) hs := NewHeartbeatService(tmpDir, 30, true) - hs.started = true // Enable for testing + hs.stopChan = make(chan struct{}) // Enable for testing hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { return nil }) // Create HEARTBEAT.md - os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644) // Should not panic with nil result hs.executeHeartbeat() diff --git a/pkg/identity/generator.go b/pkg/identity/generator.go new file mode 100644 index 
000000000..2ef269537 --- /dev/null +++ b/pkg/identity/generator.go @@ -0,0 +1,279 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package identity + +import ( + "fmt" + "sync" +) + +const ( + + DefaultIDLength = 8 +) + +var ( + // Default ULID generator + defaultULIDGen *ULIDGenerator + defaultULIDGenOnce sync.Once +) + +func getDefaultULIDGenerator() *ULIDGenerator { + defaultULIDGenOnce.Do(func() { + defaultULIDGen = NewULIDGenerator() + }) + return defaultULIDGen +} + +// Generator handles identity generation using ULID +type Generator struct { + hidPrefix string + sidPrefix string + idLength int // For legacy mode + ulidGen *ULIDGenerator + ulidVariant ULIDVariant + ulidType ULIDType + ulidEnabled bool + useLegacyMode bool // For backwards compatibility +} + +// NewGenerator creates a new identity generator +func NewGenerator() *Generator { + return &Generator{ + hidPrefix: DefaultHIDPrefix, + sidPrefix: DefaultSIDPrefix, + idLength: DefaultIDLength, + ulidGen: getDefaultULIDGenerator(), + ulidVariant: ULIDVariantNano, // Default to 21-char Nano ULIDs + ulidType: ULIDTypeNode, + ulidEnabled: true, + useLegacyMode: false, + } +} + +// WithHIDPrefix sets the H-id prefix for generation +func (g *Generator) WithHIDPrefix(prefix string) *Generator { + g.hidPrefix = normalizeID(prefix) + return g +} + +// WithSIDPrefix sets the S-id prefix for generation +func (g *Generator) WithSIDPrefix(prefix string) *Generator { + g.sidPrefix = normalizeID(prefix) + return g +} + +// WithIDLength sets the random ID length (DEPRECATED - use WithULIDVariant) +func (g *Generator) WithIDLength(length int) *Generator { + g.idLength = length + g.useLegacyMode = true + return g +} + +// WithULIDVariant sets the ULID variant for generation +func (g *Generator) WithULIDVariant(variant ULIDVariant) *Generator { + g.ulidVariant = variant + g.useLegacyMode = false + g.ulidEnabled = true + return g +} + +// 
WithULIDType sets the ULID type for generation +func (g *Generator) WithULIDType(ulidType ULIDType) *Generator { + g.ulidType = ulidType + return g +} + +// WithULIDGenerator sets a custom ULID generator +func (g *Generator) WithULIDGenerator(gen *ULIDGenerator) *Generator { + g.ulidGen = gen + g.ulidEnabled = true + return g +} + +// WithLegacyMode enables legacy random string generation (non-ULID) +func (g *Generator) WithLegacyMode() *Generator { + g.useLegacyMode = true + g.ulidEnabled = false + return g +} + +// GenerateHID generates a new H-id with the configured prefix and ULID suffix +func (g *Generator) GenerateHID() string { + if g.useLegacyMode { + suffix := randomString(g.idLength) + return fmt.Sprintf("%s-%s", g.hidPrefix, suffix) + } + ulid := g.ulidGen.GenerateWithType(g.ulidType, g.ulidVariant) + return fmt.Sprintf("%s-%s", g.hidPrefix, ulid.Value) +} + +// GenerateSID generates a new S-id with the configured prefix and ULID suffix +func (g *Generator) GenerateSID() string { + if g.useLegacyMode { + suffix := randomString(g.idLength) + return fmt.Sprintf("%s-%s", g.sidPrefix, suffix) + } + ulid := g.ulidGen.GenerateWithType(g.ulidType, g.ulidVariant) + return fmt.Sprintf("%s-%s", g.sidPrefix, ulid.Value) +} + +// Generate generates a complete new Identity +func (g *Generator) Generate() *Identity { + hid := g.GenerateHID() + sid := g.GenerateSID() + return NewIdentity(hid, sid) +} + +// GenerateWithHID generates an Identity with a specific H-id and random S-id +func (g *Generator) GenerateWithHID(hid string) *Identity { + sid := g.GenerateSID() + return NewIdentity(hid, sid) +} + +// GenerateWithSID generates an Identity with a specific S-id and random H-id +func (g *Generator) GenerateWithSID(sid string) *Identity { + hid := g.GenerateHID() + return NewIdentity(hid, sid) +} + +// GenerateULID generates a typed ULID +func (g *Generator) GenerateULID(ulidType ULIDType) *ULID { + return g.ulidGen.GenerateWithType(ulidType, g.ulidVariant) +} + +// 
randomString generates a random string of the given length +// DEPRECATED: Use ULID generation instead +func randomString(length int) string { + // Use ULID short variant for legacy compatibility + ulid := getDefaultULIDGenerator().GenerateWithType(ULIDTypeCustom, ULIDVariantShort) + if len(ulid.Value) >= length { + return ulid.Value[:length] + } + return ulid.Value +} + +// GenerateShortIDL generates a short random ID with specified length +// Note: Use GenerateShortID() from ulid.go for default 8-char IDs +func GenerateShortIDL(length int) string { + if length <= 0 { + length = DefaultIDLength + } + if length <= 8 { + // Use short ULID (8 chars) + ulid := getDefaultULIDGenerator().GenerateWithType(ULIDTypeCustom, ULIDVariantShort) + if len(ulid.Value) >= length { + return ulid.Value[:length] + } + return ulid.Value + } + // Use Nano ULID for longer requests + ulid := getDefaultULIDGenerator().GenerateWithType(ULIDTypeCustom, ULIDVariantNano) + if len(ulid.Value) >= length { + return ulid.Value[:length] + } + return ulid.Value +} + +// GenerateNodeID generates a node ID using ULID: "claw-{ulid}" +func GenerateNodeID() string { + ulid := getDefaultULIDGenerator().GenerateWithType(ULIDTypeNode, ULIDVariantNano) + return fmt.Sprintf("claw-%s", ulid.Value) +} + +// GenerateTaskID generates a task ID using ULID: "task-{ulid}" +func GenerateTaskID() string { + ulid := getDefaultULIDGenerator().GenerateWithType(ULIDTypeTask, ULIDVariantNano) + return fmt.Sprintf("task-%s", ulid.Value) +} + +// GenerateSessionID generates a session ID using ULID: "session-{ulid}" +func GenerateSessionID() string { + ulid := getDefaultULIDGenerator().GenerateWithType(ULIDTypeSession, ULIDVariantNano) + return fmt.Sprintf("session-%s", ulid.Value) +} + +// GenerateSwarmID generates a swarm ID using ULID: "swarm-{ulid}" +func GenerateSwarmID() string { + ulid := getDefaultULIDGenerator().GenerateWithType(ULIDTypeCustom, ULIDVariantNano) + return fmt.Sprintf("swarm-%s", ulid.Value) +} + +// 
GenerateTypedULID generates a ULID with the specified type and variant +// This is a typed wrapper around the ULIDGenerator +func GenerateTypedULID(ulidType ULIDType, variant ULIDVariant) *ULID { + return getDefaultULIDGenerator().GenerateWithType(ulidType, variant) +} + +// ParseOrGenerate parses an identity string or generates a new one if parsing fails +func ParseOrGenerate(s string) *Identity { + id, err := ParseIdentityString(s) + if err != nil { + return NewGenerator().Generate() + } + if !id.IsValid() { + return NewGenerator().Generate() + } + return id +} + +// Pool manages a pool of pre-generated identities for performance +// Now uses ULIDPool internally +type Pool struct { + hids chan string + sids chan string + ulidPool *ULIDPool +} + +// NewPool creates a new identity pool +func NewPool(size int) *Pool { + ulidGen := getDefaultULIDGenerator() + p := &Pool{ + hids: make(chan string, size), + sids: make(chan string, size), + ulidPool: NewULIDPool(size, ulidGen), + } + p.fill(size) + return p +} + +func (p *Pool) fill(count int) { + gen := NewGenerator() + for i := 0; i < count; i++ { + p.hids <- gen.GenerateHID() + p.sids <- gen.GenerateSID() + } +} + +// GetHID gets an H-id from the pool or generates a new one if empty +func (p *Pool) GetHID() string { + select { + case hid := <-p.hids: + return hid + default: + return NewGenerator().GenerateHID() + } +} + +// GetSID gets an S-id from the pool or generates a new one if empty +func (p *Pool) GetSID() string { + select { + case sid := <-p.sids: + return sid + default: + return NewGenerator().GenerateSID() + } +} + +// GenerateFromPool generates an Identity using IDs from the pool +func (p *Pool) GenerateFromPool() *Identity { + return NewIdentity(p.GetHID(), p.GetSID()) +} + +// GetULID gets a ULID from the internal pool +func (p *Pool) GetULID(ulidType ULIDType) *ULID { + return p.ulidPool.GetWithType(ulidType) +} diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go new file mode 100644 index 
000000000..0d6b8e4ef --- /dev/null +++ b/pkg/identity/identity.go @@ -0,0 +1,213 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package identity + +import ( + "fmt" + "strings" + "sync" +) + +// Identity represents a two-level identity model: +// - H-id (Host ID): Represents a tenant/user/group for multi-tenancy isolation +// - S-id (Service ID): Represents a specific instance/node within the H-id +// +// Examples: +// - H-id: "user-alice", "org-company", "group-team1" +// - S-id: "node-01", "worker-primary", "coordinator-main" +// +// Full identity: "user-alice/node-01" or "org-company/worker-1" +type Identity struct { + // HID is the host/tenant identifier (e.g., "user-alice", "org-company") + HID string `json:"hid"` + + // SID is the service/instance identifier (e.g., "node-01", "worker-primary") + SID string `json:"sid"` + + // DisplayName is a human-readable name + DisplayName string `json:"display_name,omitempty"` + + // Metadata contains additional identity information + Metadata map[string]string `json:"metadata,omitempty"` + + mu sync.RWMutex +} + +// NewIdentity creates a new Identity with the given H-id and S-id +func NewIdentity(hid, sid string) *Identity { + return &Identity{ + HID: normalizeID(hid), + SID: normalizeID(sid), + Metadata: make(map[string]string), + } +} + +// NewIdentityFromString parses an identity string in format "hid/sid" +func NewIdentityFromString(s string) (*Identity, error) { + parts := strings.SplitN(s, "/", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid identity format: %q (expected \"hid/sid\")", s) + } + + hid := strings.TrimSpace(parts[0]) + sid := strings.TrimSpace(parts[1]) + + if hid == "" || sid == "" { + return nil, fmt.Errorf("invalid identity: hid and sid must not be empty") + } + + return NewIdentity(hid, sid), nil +} + +// String returns the identity in "hid/sid" format +func (id *Identity) String() string { + id.mu.RLock() + 
defer id.mu.RUnlock() + return fmt.Sprintf("%s/%s", id.HID, id.SID) +} + +// FullID returns the full identity string +func (id *Identity) FullID() string { + return id.String() +} + +// IsValid returns true if the identity is valid +func (id *Identity) IsValid() bool { + id.mu.RLock() + defer id.mu.RUnlock() + + if id.HID == "" || id.SID == "" { + return false + } + + // Check for invalid characters + for _, c := range id.HID + id.SID { + if !isValidIDChar(c) { + return false + } + } + + return true +} + +// Clone returns a deep copy of the identity +func (id *Identity) Clone() *Identity { + id.mu.RLock() + defer id.mu.RUnlock() + + clone := &Identity{ + HID: id.HID, + SID: id.SID, + DisplayName: id.DisplayName, + Metadata: make(map[string]string, len(id.Metadata)), + } + + for k, v := range id.Metadata { + clone.Metadata[k] = v + } + + return clone +} + +// SetMetadata sets a metadata key-value pair +func (id *Identity) SetMetadata(key, value string) { + id.mu.Lock() + defer id.mu.Unlock() + if id.Metadata == nil { + id.Metadata = make(map[string]string) + } + id.Metadata[key] = value +} + +// GetMetadata gets a metadata value by key +func (id *Identity) GetMetadata(key string) (string, bool) { + id.mu.RLock() + defer id.mu.RUnlock() + v, ok := id.Metadata[key] + return v, ok +} + +// IsSameTenant checks if two identities belong to the same tenant (same H-id) +func (id *Identity) IsSameTenant(other *Identity) bool { + if other == nil { + return false + } + id.mu.RLock() + defer id.mu.RUnlock() + other.mu.RLock() + defer other.mu.RUnlock() + return id.HID == other.HID +} + +// Equals checks if two identities are exactly the same +func (id *Identity) Equals(other *Identity) bool { + if other == nil { + return false + } + id.mu.RLock() + defer id.mu.RUnlock() + other.mu.RLock() + defer other.mu.RUnlock() + return id.HID == other.HID && id.SID == other.SID +} + +// normalizeID cleans up an ID by trimming whitespace and lowercasing +func normalizeID(id string) string { 
+ return strings.TrimSpace(strings.ToLower(id)) +} + +// isValidIDChar checks if a character is valid for an ID +func isValidIDChar(c rune) bool { + return (c >= 'a' && c <= 'z') || + (c >= '0' && c <= '9') || + c == '-' || c == '_' || c == '.' +} + +// Source defines where an identity was loaded from +type Source int + +const ( + // SourceUnknown is when the source is unknown + SourceUnknown Source = iota + // SourceCLI is from command-line arguments + SourceCLI + // SourceEnv is from environment variables + SourceEnv + // SourceConfig is from configuration file + SourceConfig + // SourceAuto is auto-generated + SourceAuto +) + +// String returns the string representation of the source +func (s Source) String() string { + switch s { + case SourceCLI: + return "cli" + case SourceEnv: + return "env" + case SourceConfig: + return "config" + case SourceAuto: + return "auto" + default: + return "unknown" + } +} + +// LoadedIdentity includes the identity and its source +type LoadedIdentity struct { + *Identity + Source Source `json:"source"` +} + +// NewLoadedIdentity creates a new LoadedIdentity +func NewLoadedIdentity(id *Identity, source Source) *LoadedIdentity { + return &LoadedIdentity{ + Identity: id, + Source: source, + } +} diff --git a/pkg/identity/identity_test.go b/pkg/identity/identity_test.go new file mode 100644 index 000000000..e16857c81 --- /dev/null +++ b/pkg/identity/identity_test.go @@ -0,0 +1,840 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package identity + +import ( + "os" + "testing" + "time" +) + +func TestNewIdentity(t *testing.T) { + tests := []struct { + name string + hid string + sid string + want string + wantErr bool + }{ + { + name: "simple identity", + hid: "user-alice", + sid: "node-01", + want: "user-alice/node-01", + wantErr: false, + }, + { + name: "normalizes lowercase", + hid: "USER-BOB", + sid: "NODE-02", + want: "user-bob/node-02", + wantErr: false, + 
}, + { + name: "trims whitespace", + hid: " user-carol ", + sid: " node-03 ", + want: "user-carol/node-03", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id := NewIdentity(tt.hid, tt.sid) + if got := id.String(); got != tt.want { + t.Errorf("Identity.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewIdentityFromString(t *testing.T) { + tests := []struct { + name string + s string + wantHID string + wantSID string + wantErr bool + }{ + { + name: "valid format", + s: "user-alice/node-01", + wantHID: "user-alice", + wantSID: "node-01", + wantErr: false, + }, + { + name: "with spaces", + s: " user-alice / node-01 ", + wantHID: "user-alice", + wantSID: "node-01", + wantErr: false, + }, + { + name: "invalid format - no slash", + s: "user-alice", + wantErr: true, + }, + { + name: "empty hid", + s: "/node-01", + wantErr: true, + }, + { + name: "empty sid", + s: "user-alice/", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id, err := NewIdentityFromString(tt.s) + if (err != nil) != tt.wantErr { + t.Errorf("NewIdentityFromString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if id.HID != tt.wantHID { + t.Errorf("Identity.HID = %v, want %v", id.HID, tt.wantHID) + } + if id.SID != tt.wantSID { + t.Errorf("Identity.SID = %v, want %v", id.SID, tt.wantSID) + } + } + }) + } +} + +func TestIdentity_IsValid(t *testing.T) { + tests := []struct { + name string + hid string + sid string + valid bool + }{ + {"valid simple", "user-alice", "node-01", true}, + {"valid with dots", "user.alice", "node.01", true}, + {"valid with underscores", "user_alice", "node_01", true}, + {"empty hid", "", "node-01", false}, + {"empty sid", "user-alice", "", false}, + {"uppercase is normalized", "user-Alice", "node-01", true}, // Auto-normalized to lowercase + {"invalid chars special", "user@alice", "node-01", false}, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + id := NewIdentity(tt.hid, tt.sid) + if got := id.IsValid(); got != tt.valid { + t.Errorf("Identity.IsValid() = %v, want %v", got, tt.valid) + } + }) + } +} + +func TestIdentity_IsSameTenant(t *testing.T) { + user1_node1 := NewIdentity("user-alice", "node-01") + user1_node2 := NewIdentity("user-alice", "node-02") + user2_node1 := NewIdentity("user-bob", "node-01") + + tests := []struct { + name string + id *Identity + other *Identity + expected bool + }{ + {"same tenant different node", user1_node1, user1_node2, true}, + {"different tenant", user1_node1, user2_node1, false}, + {"same identity", user1_node1, user1_node1, true}, + {"nil other", user1_node1, nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.id.IsSameTenant(tt.other); got != tt.expected { + t.Errorf("Identity.IsSameTenant() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestIdentity_Equals(t *testing.T) { + user1_node1 := NewIdentity("user-alice", "node-01") + user1_node1_copy := NewIdentity("user-alice", "node-01") + user1_node2 := NewIdentity("user-alice", "node-02") + + tests := []struct { + name string + id *Identity + other *Identity + expected bool + }{ + {"exact match", user1_node1, user1_node1_copy, true}, + {"same instance", user1_node1, user1_node1, true}, + {"different node", user1_node1, user1_node2, false}, + {"nil other", user1_node1, nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.id.Equals(tt.other); got != tt.expected { + t.Errorf("Identity.Equals() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestIdentity_Clone(t *testing.T) { + original := NewIdentity("user-alice", "node-01") + original.SetMetadata("key1", "value1") + original.DisplayName = "Alice's Node" + + clone := original.Clone() + + // Verify clone has same values + if clone.HID != original.HID { + t.Errorf("Clone.HID = %v, want %v", clone.HID, original.HID) + } + 
if clone.SID != original.SID { + t.Errorf("Clone.SID = %v, want %v", clone.SID, original.SID) + } + if clone.DisplayName != original.DisplayName { + t.Errorf("Clone.DisplayName = %v, want %v", clone.DisplayName, original.DisplayName) + } + + v, ok := clone.GetMetadata("key1") + if !ok || v != "value1" { + t.Errorf("Clone metadata not copied correctly") + } + + // Modify clone and ensure original is unchanged + clone.SetMetadata("key2", "value2") + _, ok = original.GetMetadata("key2") + if ok { + t.Errorf("Modifying clone affected original") + } +} + +func TestIdentity_Metadata(t *testing.T) { + id := NewIdentity("user-alice", "node-01") + + // Test set and get + id.SetMetadata("key1", "value1") + v, ok := id.GetMetadata("key1") + if !ok || v != "value1" { + t.Errorf("Metadata not set correctly") + } + + // Test missing key + _, ok = id.GetMetadata("missing") + if ok { + t.Errorf("Expected false for missing key") + } +} + +func TestLoader_Load(t *testing.T) { + // Save and restore environment + oldEnvHID := os.Getenv(EnvHID) + oldEnvSID := os.Getenv(EnvSID) + oldEnvIdentity := os.Getenv(EnvIdentity) + defer func() { + os.Setenv(EnvHID, oldEnvHID) + os.Setenv(EnvSID, oldEnvSID) + os.Setenv(EnvIdentity, oldEnvIdentity) + }() + + // Clean environment + os.Unsetenv(EnvHID) + os.Unsetenv(EnvSID) + os.Unsetenv(EnvIdentity) + + t.Run("CLI priority", func(t *testing.T) { + l := NewLoader() + l.SetCLI("cli-user", "cli-node") + l.SetConfig("config-user", "config-node") + + loaded, err := l.Load() + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if loaded.Source != SourceCLI { + t.Errorf("Expected SourceCLI, got %v", loaded.Source) + } + if loaded.HID != "cli-user" || loaded.SID != "cli-node" { + t.Errorf("CLI identity not loaded correctly: %s/%s", loaded.HID, loaded.SID) + } + }) + + t.Run("environment priority", func(t *testing.T) { + l := NewLoader() + os.Setenv(EnvIdentity, "env-user/env-node") + defer os.Unsetenv(EnvIdentity) + + loaded, err := l.Load() + if 
err != nil { + t.Fatalf("Load() error = %v", err) + } + if loaded.Source != SourceEnv { + t.Errorf("Expected SourceEnv, got %v", loaded.Source) + } + if loaded.HID != "env-user" || loaded.SID != "env-node" { + t.Errorf("Env identity not loaded correctly: %s/%s", loaded.HID, loaded.SID) + } + }) + + t.Run("config priority", func(t *testing.T) { + // Ensure env vars are not set + os.Unsetenv(EnvHID) + os.Unsetenv(EnvSID) + os.Unsetenv(EnvIdentity) + + l := NewLoader() + l.SetConfig("config-user", "config-node") + + loaded, err := l.Load() + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if loaded.Source != SourceConfig { + t.Errorf("Expected SourceConfig, got %v", loaded.Source) + } + if loaded.HID != "config-user" || loaded.SID != "config-node" { + t.Errorf("Config identity not loaded correctly: %s/%s", loaded.HID, loaded.SID) + } + }) + + t.Run("auto generation", func(t *testing.T) { + l := NewLoader() + + loaded, err := l.Load() + if err != nil { + t.Fatalf("Load() error = %v", err) + } + if loaded.Source != SourceAuto { + t.Errorf("Expected SourceAuto, got %v", loaded.Source) + } + if !loaded.IsValid() { + t.Errorf("Auto-generated identity is invalid: %s/%s", loaded.HID, loaded.SID) + } + }) +} + +func TestGenerator_Generate(t *testing.T) { + g := NewGenerator() + + id := g.Generate() + if !id.IsValid() { + t.Errorf("Generated identity is invalid: %s", id.String()) + } + + // Test uniqueness + id2 := g.Generate() + if id.Equals(id2) { + t.Errorf("Generator produced duplicate identities") + } +} + +func TestGenerator_WithPrefixes(t *testing.T) { + g := NewGenerator(). + WithHIDPrefix("tenant"). + WithSIDPrefix("service"). 
+ WithIDLength(4) + + hid := g.GenerateHID() + sid := g.GenerateSID() + + if !startsWith(hid, "tenant-") { + t.Errorf("HID doesn't have correct prefix: %s", hid) + } + if !startsWith(sid, "service-") { + t.Errorf("SID doesn't have correct prefix: %s", sid) + } +} + +func TestGenerateShortID(t *testing.T) { + id1 := GenerateShortIDL(8) + id2 := GenerateShortIDL(8) + + if len(id1) != 8 { + t.Errorf("GenerateShortIDL() length = %d, want 8", len(id1)) + } + if id1 == id2 { + t.Errorf("GenerateShortIDL() produced duplicate IDs") + } +} + +func TestGenerateNodeID(t *testing.T) { + id := GenerateNodeID() + if !startsWith(id, "claw-") { + t.Errorf("GenerateNodeID() = %v, want prefix 'claw-'", id) + } +} + +func TestParseIdentityString(t *testing.T) { + tests := []struct { + name string + s string + wantHID string + wantSID string + wantError bool + }{ + {"slash format", "user-alice/node-01", "user-alice", "node-01", false}, + {"dot format", "user-alice.node-01", "user-alice", "node-01", false}, + {"single value", "user-alice", "user-alice", "", false}, // SID is auto-generated + {"empty string", "", "", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id, err := ParseIdentityString(tt.s) + if tt.wantError { + if err == nil { + t.Errorf("ParseIdentityString() expected error, got nil") + } + return + } + if err != nil { + t.Errorf("ParseIdentityString() error = %v", err) + return + } + if id.HID != tt.wantHID { + t.Errorf("HID = %v, want %v", id.HID, tt.wantHID) + } + if tt.wantSID != "" && id.SID != tt.wantSID { + t.Errorf("SID = %v, want %v", id.SID, tt.wantSID) + } + }) + } +} + +func TestPool(t *testing.T) { + pool := NewPool(5) + + hid1 := pool.GetHID() + hid2 := pool.GetHID() + + if hid1 == "" || hid2 == "" { + t.Errorf("Pool returned empty H-ID") + } + + id := pool.GenerateFromPool() + if !id.IsValid() { + t.Errorf("Pool generated invalid identity") + } +} + +func TestSource_String(t *testing.T) { + tests := []struct { + source 
Source + want string + }{ + {SourceCLI, "cli"}, + {SourceEnv, "env"}, + {SourceConfig, "config"}, + {SourceAuto, "auto"}, + {SourceUnknown, "unknown"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + if got := tt.source.String(); got != tt.want { + t.Errorf("Source.String() = %v, want %v", got, tt.want) + } + }) + } +} + +// Helper function +func startsWith(s, prefix string) bool { + return len(s) >= len(prefix) && s[:len(prefix)] == prefix +} + +// ULID Tests + +func TestULID_Generate(t *testing.T) { + gen := NewULIDGenerator() + + tests := []struct { + name string + variant ULIDVariant + ulidType ULIDType + wantLen int + }{ + {"UUID v4", ULIDVariantUUID, ULIDTypeNode, 36}, + {"ULID v7", ULIDVariantULIDv7, ULIDTypeTask, 26}, + {"ULID 26-char", ULIDVariantULID, ULIDTypeSession, 26}, + {"Nano 26-char", ULIDVariantNano, ULIDTypeMemory, 26}, + {"Short 8-char", ULIDVariantShort, ULIDTypeDevice, 8}, + {"Base62", ULIDVariantBase62, ULIDTypeCustom, 1}, // Base62 varies in length + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ulid := gen.GenerateWithType(tt.ulidType, tt.variant) + + if ulid.Type != tt.ulidType { + t.Errorf("ULID.Type = %v, want %v", ulid.Type, tt.ulidType) + } + if ulid.Variant != tt.variant { + t.Errorf("ULID.Variant = %v, want %v", ulid.Variant, tt.variant) + } + if tt.name != "Base62" && len(ulid.Value) != tt.wantLen { + t.Errorf("ULID.Value length = %d, want %d", len(ulid.Value), tt.wantLen) + } + if !ulid.IsValid() { + t.Errorf("ULID.IsValid() = false, want true") + } + }) + } +} + +func TestULID_Uniqueness(t *testing.T) { + gen := NewULIDGenerator() + seen := make(map[string]bool) + + // Generate 1000 ULIDs and ensure uniqueness + for i := 0; i < 1000; i++ { + ulid := gen.GenerateWithType(ULIDTypeNode, ULIDVariantNano) + if seen[ulid.Value] { + t.Errorf("Duplicate ULID generated: %s", ulid.Value) + } + seen[ulid.Value] = true + } +} + +func TestULID_Sortable(t *testing.T) { + gen := 
NewULIDGenerator() + var ulids []*ULID + + // Generate 100 ULIDs + for i := 0; i < 100; i++ { + ulids = append(ulids, gen.GenerateWithType(ULIDTypeNode, ULIDVariantNano)) + time.Sleep(time.Millisecond) + } + + // Check they're sorted by time + for i := 1; i < len(ulids); i++ { + if ulids[i].Time.Before(ulids[i-1].Time) { + t.Errorf("ULIDs not sorted by time: %s after %s", ulids[i].Value, ulids[i-1].Value) + } + } +} + +func TestULID_Parse(t *testing.T) { + gen := NewULIDGenerator() + + tests := []struct { + name string + variant ULIDVariant + ulidType ULIDType + wantError bool + }{ + {"Parse UUID", ULIDVariantUUID, ULIDTypeNode, false}, + {"Parse ULIDv7", ULIDVariantULIDv7, ULIDTypeTask, false}, + {"Parse ULID", ULIDVariantULID, ULIDTypeSession, false}, + {"Parse Nano", ULIDVariantNano, ULIDTypeMemory, false}, + {"Parse Short", ULIDVariantShort, ULIDTypeDevice, false}, + {"Invalid", ULIDVariantShort, ULIDTypeNode, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.name == "Invalid" { + // ParseULID is lenient - will return Custom variant for unknown formats + parsed, err := ParseULID("invalid-ulid") + if err != nil || parsed.Variant != ULIDVariantCustom { + t.Logf("ParseULID() returned variant %v for invalid input", parsed.Variant) + } + return + } + + original := gen.GenerateWithType(tt.ulidType, tt.variant) + parsed, err := ParseULID(original.Value) + + if err != nil { + t.Errorf("ParseULID() error = %v", err) + return + } + + if parsed.Value != original.Value { + t.Errorf("Parsed value = %v, want %v", parsed.Value, original.Value) + } + }) + } +} + +func TestULID_Type(t *testing.T) { + tests := []struct { + ulidType ULIDType + prefix string + }{ + {ULIDTypeNode, "n"}, + {ULIDTypeTask, "t"}, + {ULIDTypeSession, "s"}, + {ULIDTypeMemory, "m"}, + {ULIDTypeWorkflow, "w"}, + {ULIDTypeDevice, "d"}, + {ULIDTypeUser, "u"}, + {ULIDTypeCustom, "x"}, + } + + for _, tt := range tests { + t.Run(string(tt.ulidType), func(t *testing.T) { + 
ulid := &ULID{Type: tt.ulidType} + if ulid.GetPrefix() != tt.prefix { + t.Errorf("ULID.GetPrefix() = %v, want %v", ulid.GetPrefix(), tt.prefix) + } + }) + } +} + +func TestULIDSet(t *testing.T) { + set := NewULIDSet() + gen := NewULIDGenerator() + + // Add ULIDs + for i := 0; i < 10; i++ { + ulid := gen.GenerateWithType(ULIDTypeNode, ULIDVariantNano) + set.Add(ulid) + } + + if set.Len() != 10 { + t.Errorf("ULIDSet.Len() = %d, want 10", set.Len()) + } + + // Check contains + ulids := set.List() + if len(ulids) != 10 { + t.Errorf("ULIDSet.List() length = %d, want 10", len(ulids)) + } + + // Remove + set.Remove(ulids[0].String()) + if set.Len() != 9 { + t.Errorf("ULIDSet.Len() after remove = %d, want 9", set.Len()) + } + + // Clear + set.Clear() + if set.Len() != 0 { + t.Errorf("ULIDSet.Len() after clear = %d, want 0", set.Len()) + } +} + +func TestULIDPool(t *testing.T) { + gen := NewULIDGenerator() + pool := NewULIDPool(5, gen) + + // Get ULIDs from pool + var ulids []*ULID + for i := 0; i < 5; i++ { + ulid := pool.GetWithType(ULIDTypeNode) + if ulid == nil { + t.Errorf("ULIDPool.GetWithType() returned nil") + } + ulids = append(ulids, ulid) + } + + // Check pool size + if pool.Size() < 0 { + t.Errorf("ULIDPool.Size() returned negative value") + } + + // Close pool + pool.Close() + if !pool.IsClosed() { + t.Errorf("ULIDPool.IsClosed() = false after Close()") + } +} + +func TestGenerator_WithULID(t *testing.T) { + g := NewGenerator(). + WithULIDVariant(ULIDVariantNano). 
+ WithULIDType(ULIDTypeNode) + + hid := g.GenerateHID() + sid := g.GenerateSID() + + // Check that IDs contain ULIDs (longer than legacy 8-char) + if len(hid) <= len("tenant-xxxxxxxx") { + t.Errorf("HID with ULID should be longer: %s", hid) + } + if len(sid) <= len("service-xxxxxxxx") { + t.Errorf("SID with ULID should be longer: %s", sid) + } +} + +func TestGenerateULID(t *testing.T) { + ulid := GenerateTypedULID(ULIDTypeTask, ULIDVariantNano) + + if ulid.Type != ULIDTypeTask { + t.Errorf("GenerateTypedULID() type = %v, want %v", ulid.Type, ULIDTypeTask) + } + if ulid.Variant != ULIDVariantNano { + t.Errorf("GenerateTypedULID() variant = %v, want %v", ulid.Variant, ULIDVariantNano) + } + if !ulid.IsValid() { + t.Errorf("GenerateTypedULID() produced invalid ULID") + } +} + +func TestULID_IsValid(t *testing.T) { + tests := []struct { + name string + ulid *ULID + valid bool + }{ + {"Valid ULID", &ULID{Value: "01ARZ3NDEKTSV4RRFFQ69G5FAV", Variant: ULIDVariantULID}, true}, + {"Valid Nano", &ULID{Value: "01ARZ3NDEKTSV4RRFFQ69G5FAV", Variant: ULIDVariantNano}, true}, + {"Valid Short", &ULID{Value: "ARZ3NDEK", Variant: ULIDVariantShort}, true}, + {"Empty value", &ULID{Value: "", Variant: ULIDVariantNano}, false}, + {"Too short", &ULID{Value: "AB", Variant: ULIDVariantNano}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.ulid.IsValid(); got != tt.valid { + t.Errorf("ULID.IsValid() = %v, want %v", got, tt.valid) + } + }) + } +} + +func TestULID_String(t *testing.T) { + ulid := &ULID{ + Type: ULIDTypeNode, + Variant: ULIDVariantNano, + Value: "01ARZ3NDEKTSV4RRFFQ69G5FAV", + } + + str := ulid.String() + if str == "" { + t.Errorf("ULID.String() returned empty") + } + // Nano ULID with prefix should be "n-01ARZ3NDEKTSV4RRFFQ69G5FAV" + expectedValue := "01ARZ3NDEKTSV4RRFFQ69G5FAV" + if !contains(str, expectedValue) { + t.Errorf("ULID.String() = %v, should contain value %v", str, expectedValue) + } +} + +func TestULIDType_String(t 
*testing.T) { + tests := []struct { + ulidType ULIDType + want string + }{ + {ULIDTypeNode, "node"}, + {ULIDTypeTask, "task"}, + {ULIDTypeSession, "session"}, + {ULIDTypeMemory, "mem"}, + {ULIDTypeWorkflow, "workflow"}, + {ULIDTypeDevice, "device"}, + {ULIDTypeUser, "user"}, + {ULIDTypeCustom, "custom"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + if got := tt.ulidType.String(); got != tt.want { + t.Errorf("ULIDType.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestULIDVariant_String(t *testing.T) { + tests := []struct { + variant ULIDVariant + want string + }{ + {ULIDVariantUUID, "uuid"}, + {ULIDVariantULIDv7, "ulidv7"}, + {ULIDVariantULID, "ulid"}, + {ULIDVariantNano, "nano"}, + {ULIDVariantShort, "short"}, + {ULIDVariantBase62, "base62"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + if got := tt.variant.String(); got != tt.want { + t.Errorf("ULIDVariant.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestULIDSet_FilterByType(t *testing.T) { + set := NewULIDSet() + gen := NewULIDGenerator() + + // Add different types + set.Add(gen.GenerateWithType(ULIDTypeNode, ULIDVariantNano)) + set.Add(gen.GenerateWithType(ULIDTypeTask, ULIDVariantNano)) + set.Add(gen.GenerateWithType(ULIDTypeTask, ULIDVariantNano)) + set.Add(gen.GenerateWithType(ULIDTypeSession, ULIDVariantNano)) + + taskSlice := set.FilterByType(ULIDTypeTask) + if len(taskSlice) != 2 { + t.Errorf("FilterByType() length = %d, want 2", len(taskSlice)) + } +} + +func TestULIDSet_Merge(t *testing.T) { + set1 := NewULIDSet() + set2 := NewULIDSet() + gen := NewULIDGenerator() + + ulid1 := gen.GenerateWithType(ULIDTypeNode, ULIDVariantNano) + ulid2 := gen.GenerateWithType(ULIDTypeTask, ULIDVariantNano) + + set1.Add(ulid1) + set2.Add(ulid2) + + merged := set1.Merge(set2) + if merged.Len() != 2 { + t.Errorf("Merge() length = %d, want 2", merged.Len()) + } +} + +func TestULIDSet_Difference(t *testing.T) { + set1 := 
NewULIDSet() + set2 := NewULIDSet() + gen := NewULIDGenerator() + + ulid1 := gen.GenerateWithType(ULIDTypeNode, ULIDVariantNano) + ulid2 := gen.GenerateWithType(ULIDTypeTask, ULIDVariantNano) + + set1.Add(ulid1) + set1.Add(ulid2) + set2.Add(ulid2) + + diff := set1.Difference(set2) + if diff.Len() != 1 { + t.Errorf("Difference() length = %d, want 1", diff.Len()) + } + if !diff.Contains(ulid1) { + t.Errorf("Difference() should contain ulid1") + } +} + +// Helper function +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && ( + s[:len(substr)] == substr || + s[len(s)-len(substr):] == substr || + findInString(s, substr))) +} + +func findInString(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/pkg/identity/loader.go b/pkg/identity/loader.go new file mode 100644 index 000000000..202f1ada5 --- /dev/null +++ b/pkg/identity/loader.go @@ -0,0 +1,267 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package identity + +import ( + "fmt" + "os" + "strings" +) + +const ( + // DefaultHIDPrefix is the default prefix for auto-generated H-ids + DefaultHIDPrefix = "user" + // DefaultSIDPrefix is the default prefix for auto-generated S-ids + DefaultSIDPrefix = "node" +) + +// Environment variables for identity configuration +const ( + EnvHID = "PICOCLAW_IDENTITY_HID" + EnvSID = "PICOCLAW_IDENTITY_SID" + EnvIdentity = "PICOCLAW_IDENTITY" // Full identity in "hid/sid" format +) + +// Loader handles identity loading from multiple sources +type Loader struct { + // configHID is the H-id from config file + configHID string + // configSID is the S-id from config file + configSID string + // cliHID is the H-id from CLI arguments + cliHID string + // cliSID is the S-id from CLI arguments + cliSID string +} + +// NewLoader creates a new identity loader 
+func NewLoader() *Loader { + return &Loader{} +} + +// SetConfig sets the identity from config file +func (l *Loader) SetConfig(hid, sid string) { + l.configHID = hid + l.configSID = sid +} + +// SetCLI sets the identity from CLI arguments +func (l *Loader) SetCLI(hid, sid string) { + l.cliHID = hid + l.cliSID = sid +} + +// Load loads the identity from available sources in priority order: +// 1. CLI arguments (highest priority) +// 2. Environment variables +// 3. Config file +// 4. Auto-generation (lowest priority) +func (l *Loader) Load() (*LoadedIdentity, error) { + // Priority 1: CLI arguments + if l.cliHID != "" || l.cliSID != "" { + hid := l.cliHID + sid := l.cliSID + + // If only one is provided, try to get the other from env + if hid == "" { + hid = os.Getenv(EnvHID) + } + if sid == "" { + sid = os.Getenv(EnvSID) + } + + // Generate missing parts if needed + if hid == "" { + hid = generateDefaultHID() + } + if sid == "" { + sid = generateDefaultSID() + } + + id := NewIdentity(hid, sid) + if !id.IsValid() { + return nil, fmt.Errorf("invalid CLI identity: %s", id.String()) + } + return NewLoadedIdentity(id, SourceCLI), nil + } + + // Priority 2: Environment variables + envHID := os.Getenv(EnvHID) + envSID := os.Getenv(EnvSID) + envIdentity := os.Getenv(EnvIdentity) + + if envIdentity != "" { + id, err := NewIdentityFromString(envIdentity) + if err != nil { + return nil, fmt.Errorf("invalid environment identity: %w", err) + } + if id.IsValid() { + return NewLoadedIdentity(id, SourceEnv), nil + } + } + + if envHID != "" || envSID != "" { + hid := envHID + sid := envSID + + // Generate missing parts + if hid == "" { + hid = generateDefaultHID() + } + if sid == "" { + sid = generateDefaultSID() + } + + id := NewIdentity(hid, sid) + if id.IsValid() { + return NewLoadedIdentity(id, SourceEnv), nil + } + } + + // Priority 3: Config file + if l.configHID != "" || l.configSID != "" { + hid := l.configHID + sid := l.configSID + + // Generate missing parts + if hid == 
"" { + hid = generateDefaultHID() + } + if sid == "" { + sid = generateDefaultSID() + } + + id := NewIdentity(hid, sid) + if id.IsValid() { + return NewLoadedIdentity(id, SourceConfig), nil + } + } + + // Priority 4: Auto-generate + hid := generateDefaultHID() + sid := generateDefaultSID() + id := NewIdentity(hid, sid) + return NewLoadedIdentity(id, SourceAuto), nil +} + +// LoadOrGenerate loads the identity or generates one if loading fails +func (l *Loader) LoadOrGenerate() *LoadedIdentity { + id, err := l.Load() + if err != nil { + // Generate on any error + hid := generateDefaultHID() + sid := generateDefaultSID() + return NewLoadedIdentity(NewIdentity(hid, sid), SourceAuto) + } + return id +} + +// MustLoad loads the identity or panics +func (l *Loader) MustLoad() *LoadedIdentity { + id, err := l.Load() + if err != nil { + panic(fmt.Sprintf("failed to load identity: %v", err)) + } + return id +} + +// ParseIdentityString parses an identity from a string +// Supports formats: "hid/sid", "hid", "hid.sid" +func ParseIdentityString(s string) (*Identity, error) { + s = strings.TrimSpace(s) + if s == "" { + return nil, fmt.Errorf("empty identity string") + } + + // Try "hid/sid" format first + if strings.Contains(s, "/") { + return NewIdentityFromString(s) + } + + // Try "hid.sid" format + if strings.Contains(s, ".") { + parts := strings.SplitN(s, ".", 2) + if len(parts) == 2 { + return NewIdentity(parts[0], parts[1]), nil + } + } + + // Single value - treat as HID, auto-generate SID + return NewIdentity(s, generateDefaultSID()), nil +} + +// LoadFromConfigMap loads identity from a generic config map +func LoadFromConfigMap(cfg map[string]interface{}) (*Identity, error) { + var hid, sid string + + if v, ok := cfg["hid"].(string); ok { + hid = v + } + if v, ok := cfg["sid"].(string); ok { + sid = v + } + + // Check for nested identity config + if identityCfg, ok := cfg["identity"].(map[string]interface{}); ok { + if v, ok := identityCfg["hid"].(string); ok { + hid = 
v + } + if v, ok := identityCfg["sid"].(string); ok { + sid = v + } + } + + // At least HID should be present + if hid == "" { + return nil, fmt.Errorf("missing hid in config") + } + + if sid == "" { + sid = generateDefaultSID() + } + + id := NewIdentity(hid, sid) + if !id.IsValid() { + return nil, fmt.Errorf("invalid identity from config: %s/%s", hid, sid) + } + + return id, nil +} + +// GetUsername derives a username from the current environment +// Used as a default HID when none is provided +func GetUsername() string { + // Try common environment variables + for _, env := range []string{"USER", "USERNAME", "LOGNAME", "LNAME"} { + if user := os.Getenv(env); user != "" { + return normalizeID(user) + } + } + + return DefaultHIDPrefix +} + +// GetHostname derives a hostname from the current environment +// Used as a default SID when none is provided +func GetHostname() string { + if host, err := os.Hostname(); err == nil { + // Extract just the hostname part (remove domain) + host = strings.Split(host, ".")[0] + return normalizeID(host) + } + + return DefaultSIDPrefix +} + +// generateDefaultHID generates a default H-id +func generateDefaultHID() string { + return fmt.Sprintf("%s-%s", DefaultHIDPrefix, GetUsername()) +} + +// generateDefaultSID generates a default S-id +func generateDefaultSID() string { + return fmt.Sprintf("%s-%s", DefaultSIDPrefix, GetHostname()) +} diff --git a/pkg/identity/ulid.go b/pkg/identity/ulid.go new file mode 100644 index 000000000..02e53e4f8 --- /dev/null +++ b/pkg/identity/ulid.go @@ -0,0 +1,924 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package identity + +import ( + "crypto/rand" + "fmt" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/oklog/ulid/v2" +) + +// ULIDType represents the type of ULID +type ULIDType string + +const ( + ULIDTypeNode ULIDType = "node" + ULIDTypeTask ULIDType = "task" + ULIDTypeSession ULIDType = 
"session" + ULIDTypeMemory ULIDType = "mem" + ULIDTypeWorkflow ULIDType = "workflow" + ULIDTypeDevice ULIDType = "device" + ULIDTypeUser ULIDType = "user" + ULIDTypeCustom ULIDType = "custom" +) + +// String returns the string representation of ULIDType +func (t ULIDType) String() string { + return string(t) +} + +// ULIDVariant represents the ULID format variant +type ULIDVariant string + +const ( + ULIDVariantUUID ULIDVariant = "uuid" // Standard UUID v4: 8-4-4-4-4-4-7-4-4-4 + ULIDVariantULIDv7 ULIDVariant = "ulidv7" // ULID v7: 8-4-4-4-4-4-7 + ULIDVariantULID ULIDVariant = "ulid" // ULID: 26 chars Crockford base32 + ULIDVariantNano ULIDVariant = "nano" // Nano: 21 chars + ULIDVariantShort ULIDVariant = "short" // Short: 8 chars (first 8 of UUID) + ULIDVariantBase62 ULIDVariant = "base62" // Base62 encoded + ULIDVariantCustom ULIDVariant = "custom" // Custom format +) + +// String returns the string representation of ULIDVariant +func (v ULIDVariant) String() string { + return string(v) +} + +// ULID represents a ULID (Universal Unique IDentifier) +type ULID struct { + Type ULIDType + Variant ULIDVariant + Value string + Prefix string + Time time.Time + Metadata map[string]string +} + +// ULIDGenerator generates ULIDs +type ULIDGenerator struct { + mu sync.RWMutex + entropy *ulid.MonotonicEntropy + prefix string +} + +// NewULIDGenerator creates a new ULID generator +func NewULIDGenerator() *ULIDGenerator { + entropy := ulid.Monotonic(rand.Reader, 0) + return &ULIDGenerator{ + entropy: entropy, + } +} + +// WithPrefix sets the prefix for generated ULIDs +func (g *ULIDGenerator) WithPrefix(prefix string) *ULIDGenerator { + g.mu.Lock() + defer g.mu.Unlock() + g.prefix = prefix + return g +} + +// Generate generates a new ULID with default UUID variant +func (g *ULIDGenerator) Generate() *ULID { + return g.GenerateWithType(ULIDTypeNode, ULIDVariantUUID) +} + +// GenerateWithType generates a ULID with specific type and variant +func (g *ULIDGenerator) 
GenerateWithType(ulidType ULIDType, variant ULIDVariant) *ULID { + u := &ULID{ + Type: ulidType, + Variant: variant, + Time: time.Now(), + Metadata: make(map[string]string), + } + + switch variant { + case ULIDVariantUUID: + // UUID v4 format + id := uuid.New() + u.Value = id.String() + + case ULIDVariantULIDv7: + // ULID v7 format + id := ulid.Make() + u.Value = id.String() + + case ULIDVariantULID: + // ULID (26 chars, Crockford base32) + id := ulid.Make() + u.Value = id.String() + + case ULIDVariantNano: + // Nano format: 21 chars (21 chars from ULID without prefix) + // Generate a ULID and extract components + id, err := ulid.New(ulid.Timestamp(time.Now()), g.entropy) + if err != nil { + // Fallback to simple format + id = ulid.MustNew(ulid.Timestamp(time.Now()), ulid.DefaultEntropy()) + } + // Use full 26 chars from ULID, but store type prefix separately + u.Value = id.String() + u.Prefix = getNanoTypeChar(ulidType) + + case ULIDVariantShort: + // Short format: 8 chars (UUID prefix) + id := uuid.New() + u.Value = strings.ReplaceAll(id.String(), "-", "")[:8] + + case ULIDVariantBase62: + // Base62 encoded (using ULID as base) + id := ulid.Make() + // Encode the full ULID bytes as base62 + ulidBytes := id.Bytes() + // Convert to base62 + var n uint64 + for i := 0; i < 8; i++ { + n = n<<8 + uint64(ulidBytes[i]) + } + u.Value = base62Encode(n) + + default: + // Default to UUID v4 + id := uuid.New() + u.Value = id.String() + } + + // Add prefix if specified + if g.prefix != "" { + u.Prefix = g.prefix + } + + return u +} + +// String returns the full ULID string +func (u *ULID) String() string { + if u == nil { + return "" + } + + parts := []string{} + if u.Prefix != "" { + parts = append(parts, u.Prefix, "-") + } + parts = append(parts, u.Value) + + return strings.Join(parts, "-") +} + +// GenerateNodeID generates a node ULID +func (g *ULIDGenerator) GenerateNodeID() string { + u := g.GenerateWithType(ULIDTypeNode, ULIDVariantUUID) + return u.String() +} + +// 
GenerateTaskID generates a task ULID +func (g *ULIDGenerator) GenerateTaskID() string { + u := g.GenerateWithType(ULIDTypeTask, ULIDVariantUUID) + return u.String() +} + +// GenerateSessionID generates a session ULID +func (g *ULIDGenerator) GenerateSessionID() string { + u := g.GenerateWithType(ULIDTypeSession, ULIDVariantUUID) + return u.String() +} + +// GenerateMemoryID generates a memory ULID +func (g *ULIDGenerator) GenerateMemoryID() string { + u := g.GenerateWithType(ULIDTypeMemory, ULIDVariantUUID) + return u.String() +} + +// GetNanoTypeChar returns the type character for nano ULIDs +func getNanoTypeChar(ulidType ULIDType) string { + switch ulidType { + case ULIDTypeNode: + return "n" + case ULIDTypeTask: + return "t" + case ULIDTypeSession: + return "s" + case ULIDTypeMemory: + return "m" + case ULIDTypeWorkflow: + return "w" + case ULIDTypeDevice: + return "d" + case ULIDTypeUser: + return "u" + default: + return "x" + } +} + +// base62Encode encodes a uint64 to base62 string +func base62Encode(n uint64) string { + const charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + if n == 0 { + return "0" + } + + var result []byte + for n > 0 { + result = append(result, charset[n%62]) + n /= 62 + } + + // Reverse result + for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 { + result[i], result[j] = result[j], result[i] + } + + return string(result) +} + +// ParseULID parses a ULID string into its components +func ParseULID(s string) (*ULID, error) { + if s == "" { + return nil, fmt.Errorf("empty ULID") + } + + u := &ULID{ + Metadata: make(map[string]string), + } + + // Detect format by checking string characteristics + if strings.Contains(s, "-") { + parts := strings.Split(s, "-") + switch len(parts) { + case 5: + // Standard UUID: 8-4-4-4-4 + if len(parts[0]) == 8 && len(parts[1]) == 4 && len(parts[2]) == 4 && + len(parts[3]) == 4 && len(parts[4]) == 12 { + u.Variant = ULIDVariantUUID + u.Value = s + return u, nil + } + + case 4: + // 
ULIDv7 or ULID: 8-4-4-4-4-4 + if len(parts[0]) == 8 && len(parts[1]) == 4 && len(parts[2]) == 4 && + len(parts[3]) == 4 && len(parts[4]) == 4 { + // Check if it's ULIDv7 (version 7) or ULID (base32) + if parts[4][0] == '7' { + u.Variant = ULIDVariantULIDv7 + } else { + u.Variant = ULIDVariantULID + } + u.Value = s + return u, nil + } + + case 3: + // Nano format or with prefix + if len(parts[0]) == 1 && len(parts[2]) == 21 { + u.Variant = ULIDVariantNano + u.Value = parts[2] + u.Prefix = parts[0] + return u, nil + } + } + } + + // Try to parse as ULID + if _, err := ulid.Parse(s); err == nil { + u.Variant = ULIDVariantULID + u.Value = s + return u, nil + } + + // Try to parse as UUID + if _, err := uuid.Parse(s); err == nil { + u.Variant = ULIDVariantUUID + u.Value = s + return u, nil + } + + // Fallback: treat as custom value + u.Variant = ULIDVariantCustom + u.Value = s + return u, nil +} + +// IsStandard returns true if this is a standard ULID format +func (u *ULID) IsStandard() bool { + return u.Variant == ULIDVariantUUID || + u.Variant == ULIDVariantULIDv7 || + u.Variant == ULIDVariantULID +} + +// IsValid returns true if the ULID has a valid value +func (u *ULID) IsValid() bool { + if u == nil || u.Value == "" { + return false + } + switch u.Variant { + case ULIDVariantUUID: + // UUID should be 36 chars with dashes + return len(u.Value) == 36 + case ULIDVariantULIDv7, ULIDVariantULID: + // ULID should be 26 chars + return len(u.Value) == 26 + case ULIDVariantNano: + // Nano ULID uses 26 chars from ULID, prefix stored separately + return len(u.Value) == 26 + case ULIDVariantShort: + // Short ULID should be 8 chars + return len(u.Value) == 8 + case ULIDVariantBase62: + // Base62 should be 1-22 chars + return len(u.Value) >= 1 && len(u.Value) <= 22 + case ULIDVariantCustom: + // Custom can be any non-empty string + return len(u.Value) > 0 + default: + return len(u.Value) > 0 + } +} + +// GetPrefix returns the type prefix for this ULID +func (u *ULID) GetPrefix() 
string { + return getNanoTypeChar(u.Type) +} + +// ExtractType extracts the ULID type from the ULID string +func ExtractType(ulidStr string) ULIDType { + // Look for type prefix in nano format or custom prefix + if strings.HasPrefix(ulidStr, "n-") || strings.HasPrefix(ulidStr, "t-") { + return ULIDTypeTask + } + if strings.HasPrefix(ulidStr, "s-") { + return ULIDTypeSession + } + if strings.HasPrefix(ulidStr, "m-") { + return ULIDTypeMemory + } + if strings.HasPrefix(ulidStr, "w-") { + return ULIDTypeWorkflow + } + if strings.HasPrefix(ulidStr, "d-") { + return ULIDTypeDevice + } + if strings.HasPrefix(ulidStr, "u-") { + return ULIDTypeUser + } + return ULIDTypeNode +} + +// FormatULID formats a ULID with the specified variant +func FormatULID(id string, variant ULIDVariant) (string, error) { + switch variant { + case ULIDVariantUUID: + // Validate as UUID v4 + if _, err := uuid.Parse(id); err != nil { + return "", fmt.Errorf("invalid UUID: %w", err) + } + return id, nil + + case ULIDVariantULIDv7: + // Validate as ULID v7 + if _, err := ulid.Parse(id); err != nil { + return "", fmt.Errorf("invalid ULID: %w", err) + } + return id, nil + + case ULIDVariantULID: + // Validate as ULID + if _, err := ulid.Parse(id); err != nil { + return "", fmt.Errorf("invalid ULID: %w", err) + } + return id, nil + + case ULIDVariantNano: + // Validate as nano format: t-timestamp-random (21 chars) + if len(id) != 21 { + return "", fmt.Errorf("invalid nano ULID: wrong length") + } + // Check first char is type char + if id[0] < 'a' || id[0] > 'z' { + return "", fmt.Errorf("invalid nano ULID: wrong type char") + } + return id, nil + + case ULIDVariantShort: + // Validate as 8-char format + if len(id) != 8 { + return "", fmt.Errorf("invalid short ULID: wrong length") + } + // Check hex chars + for _, c := range id { + if !isHexChar(byte(c)) { + return "", fmt.Errorf("invalid short ULID: non-hex character") + } + } + return id, nil + + case ULIDVariantBase62: + // Base62 format validation 
+ if len(id) < 1 || len(id) > 22 { + return "", fmt.Errorf("invalid base62 ULID: wrong length") + } + return id, nil + + default: + return "", fmt.Errorf("unknown ULID variant: %s", variant) + } +} + +// isHexChar checks if a character is a hex digit +func isHexChar(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + default: + return false + } +} + +// GenerateULID generates a new ULID with the specified variant +func GenerateULID(variant ULIDVariant) (string, error) { + gen := NewULIDGenerator() + u := gen.GenerateWithType(ULIDTypeNode, variant) + return u.String(), nil +} + +// MustGenerateULID generates a new ULID or panics +func MustGenerateULID(variant ULIDVariant) string { + id, err := GenerateULID(variant) + if err != nil { + panic(err) + } + return id +} + +// GenerateNodeULID generates a node ULID +func GenerateNodeULID() string { + gen := NewULIDGenerator() + return gen.GenerateNodeID() +} + +// GenerateTaskULID generates a task ULID +func GenerateTaskULID() string { + gen := NewULIDGenerator() + return gen.GenerateTaskID() +} + +// GenerateSessionULID generates a session ULID +func GenerateSessionULID() string { + gen := NewULIDGenerator() + return gen.GenerateSessionID() +} + +// GenerateMemoryULID generates a memory ULID +func GenerateMemoryULID() string { + gen := NewULIDGenerator() + return gen.GenerateMemoryID() +} + +// GenerateWorkflowULID generates a workflow ULID +func GenerateWorkflowULID() string { + gen := NewULIDGenerator() + u := gen.GenerateWithType(ULIDTypeWorkflow, ULIDVariantUUID) + return u.String() +} + +// GenerateCustomULID generates a custom ULID with specified prefix and variant +func GenerateCustomULID(prefix string, variant ULIDVariant) (string, error) { + gen := NewULIDGenerator().WithPrefix(prefix) + u := gen.GenerateWithType(ULIDTypeCustom, variant) + return u.String(), nil +} + +// ParseAndValidateULID parses a ULID 
string and validates it +func ParseAndValidateULID(id string) (*ULID, error) { + u, err := ParseULID(id) + if err != nil { + return nil, err + } + + // Additional validation based on variant + switch u.Variant { + case ULIDVariantUUID: + if _, err := uuid.Parse(u.Value); err != nil { + return nil, fmt.Errorf("invalid UUID: %w", err) + } + + case ULIDVariantULIDv7, ULIDVariantULID: + if _, err := ulid.Parse(u.Value); err != nil { + return nil, fmt.Errorf("invalid ULID: %w", err) + } + + case ULIDVariantNano: + if len(u.Value) != 21 { + return nil, fmt.Errorf("invalid nano ULID: wrong length %d", len(u.Value)) + } + } + + return u, nil +} + +// GenerateID generates a standard ULID (UUID v4 format) +func GenerateID() string { + id := uuid.New() + return id.String() +} + +// GenerateIDWithPrefix generates a ULID with a custom prefix +func GenerateIDWithPrefix(prefix string) string { + id := uuid.New() + return fmt.Sprintf("%s-%s", prefix, id.String()) +} + +// GenerateShortID generates an 8-character ULID (first 8 chars of UUID) +func GenerateShortID() string { + id := uuid.New() + return strings.ReplaceAll(id.String(), "-", "")[:8] +} + +// GenerateULIDv7 generates a ULID v7 format ULID +func GenerateULIDv7() string { + id := ulid.Make() + return id.String() +} + +// GenerateNanoULID generates a nano-format ULID (21 chars) +func GenerateNanoULID(ulidType ULIDType) string { + gen := NewULIDGenerator() + u := gen.GenerateWithType(ulidType, ULIDVariantNano) + return u.String() +} + +// GenerateBase62ULID generates a base62 encoded ULID +func GenerateBase62ULID() string { + id := ulid.Make() + return base62Encode(id.Time()) +} + +// ValidateULID validates a ULID string +func ValidateULID(id string) error { + _, err := ParseAndValidateULID(id) + return err +} + +// ULIDVersion represents the version of ULID +type ULIDVersion string + +const ( + ULIDVersion4 ULIDVersion = "v4" // UUID v4 + ULIDVersion7 ULIDVersion = "v7" // ULID v7 +) + +// Version returns the ULID version 
+func (u *ULID) Version() ULIDVersion { + if u.Variant == ULIDVariantULIDv7 || u.Variant == ULIDVariantULID { + return ULIDVersion7 + } + return ULIDVersion4 +} + +// GetTimestamp extracts the timestamp from the ULID if available +func (u *ULID) GetTimestamp() (time.Time, error) { + switch u.Variant { + case ULIDVariantULIDv7, ULIDVariantULID: + id, err := ulid.Parse(u.Value) + if err != nil { + return time.Time{}, err + } + return ulid.Time(id.Time()), nil + + case ULIDVariantUUID: + // UUID v4 doesn't have embedded timestamp + // Use creation time from metadata if available + if !u.Time.IsZero() { + return u.Time, nil + } + return time.Time{}, fmt.Errorf("UUID v4 does not contain timestamp") + + default: + return u.Time, nil + } +} + +// Compare compares two ULIDs by timestamp (for sorting) +func Compare(a, b *ULID) int { + timeA, _ := a.GetTimestamp() + timeB, _ := b.GetTimestamp() + + if timeA.Before(timeB) { + return -1 + } else if timeA.After(timeB) { + return 1 + } + return 0 +} + +// ULIDSet manages a set of ULIDs with efficient lookup +type ULIDSet struct { + ulids map[string]*ULID + mu sync.RWMutex +} + +// NewULIDSet creates a new ULID set +func NewULIDSet() *ULIDSet { + return &ULIDSet{ + ulids: make(map[string]*ULID), + } +} + +// Add adds a ULID to the set +func (s *ULIDSet) Add(ulid *ULID) error { + if ulid == nil { + return fmt.Errorf("cannot add nil ULID") + } + + // Validate first + if err := ValidateULID(ulid.String()); err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + key := ulid.String() + s.ulids[key] = ulid + return nil +} + +// Get retrieves a ULID by its string value +func (s *ULIDSet) Get(id string) (*ULID, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + ulid, ok := s.ulids[id] + return ulid, ok +} + +// Remove removes a ULID from the set +func (s *ULIDSet) Remove(id string) { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.ulids, id) +} + +// List returns all ULIDs in the set +func (s *ULIDSet) List() []*ULID { + 
s.mu.RLock() + defer s.mu.RUnlock() + + ulids := make([]*ULID, 0, len(s.ulids)) + for _, ulid := range s.ulids { + ulids = append(ulids, ulid) + } + + // Sort by timestamp (newest first) + for i := 0; i < len(ulids); i++ { + for j := i + 1; j < len(ulids); j++ { + if Compare(ulids[j], ulids[i]) < 0 { + ulids[i], ulids[j] = ulids[j], ulids[i] + } + } + } + + return ulids +} + +// FilterByType returns ULIDs of a specific type +func (s *ULIDSet) FilterByType(ulidType ULIDType) []*ULID { + s.mu.RLock() + defer s.mu.RUnlock() + + result := make([]*ULID, 0) + for _, ulid := range s.ulids { + if ulid.Type == ulidType { + result = append(result, ulid) + } + } + + return result +} + +// FilterByVariant returns ULIDs of a specific variant +func (s *ULIDSet) FilterByVariant(variant ULIDVariant) []*ULID { + s.mu.RLock() + defer s.mu.RUnlock() + + result := make([]*ULID, 0) + for _, ulid := range s.ulids { + if ulid.Variant == variant { + result = append(result, ulid) + } + } + + return result +} + +// Count returns the number of ULIDs in the set +func (s *ULIDSet) Count() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.ulids) +} + +// Len returns the number of ULIDs in the set (alias for Count) +func (s *ULIDSet) Len() int { + return s.Count() +} + +// Clear removes all ULIDs from the set +func (s *ULIDSet) Clear() { + s.mu.Lock() + defer s.mu.Unlock() + s.ulids = make(map[string]*ULID) +} + +// ToSlice converts the set to a slice of ULID strings +func (s *ULIDSet) ToSlice() []string { + s.mu.RLock() + defer s.mu.RUnlock() + + ids := make([]string, 0, len(s.ulids)) + for id := range s.ulids { + ids = append(ids, id) + } + return ids +} + +// Contains checks if a ULID is in the set +func (s *ULIDSet) Contains(ulid *ULID) bool { + s.mu.RLock() + defer s.mu.RUnlock() + _, ok := s.ulids[ulid.String()] + return ok +} + +// Merge merges another ULIDSet into this one and returns a new set +func (s *ULIDSet) Merge(other *ULIDSet) *ULIDSet { + result := NewULIDSet() + + 
s.mu.RLock() + for k, v := range s.ulids { + result.ulids[k] = v + } + s.mu.RUnlock() + + other.mu.RLock() + for k, v := range other.ulids { + result.ulids[k] = v + } + other.mu.RUnlock() + + return result +} + +// Difference returns a new ULIDSet with elements in s but not in other +func (s *ULIDSet) Difference(other *ULIDSet) *ULIDSet { + result := NewULIDSet() + + s.mu.RLock() + defer s.mu.RUnlock() + + other.mu.RLock() + defer other.mu.RUnlock() + + for k, v := range s.ulids { + if _, exists := other.ulids[k]; !exists { + result.ulids[k] = v + } + } + + return result +} + +// ULIDPool manages a pool of pre-generated ULIDs +type ULIDPool struct { + gen *ULIDGenerator + ulids chan *ULID + mu sync.Mutex + closed bool +} + +// NewULIDPool creates a new ULID pool +func NewULIDPool(size int, gen *ULIDGenerator) *ULIDPool { + if gen == nil { + gen = NewULIDGenerator() + } + + p := &ULIDPool{ + gen: gen, + ulids: make(chan *ULID, size), + } + + // Pre-fill the pool + go p.fill(size) + + return p +} + +// fill fills the pool with ULIDs +func (p *ULIDPool) fill(count int) { + for i := 0; i < count; i++ { + p.mu.Lock() + closed := p.closed + p.mu.Unlock() + + if closed { + return + } + + select { + case p.ulids <- p.gen.Generate(): + default: + // Pool full, stop filling + return + } + } +} + +// Get gets a ULID from the pool or generates a new one +func (p *ULIDPool) Get() *ULID { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return p.gen.Generate() + } + + select { + case ulid := <-p.ulids: + // Async refill + go func() { + p.fill(1) + }() + return ulid + default: + return p.gen.Generate() + } +} + +// GetWithType gets a ULID of specific type from the pool +func (p *ULIDPool) GetWithType(ulidType ULIDType) *ULID { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return p.gen.GenerateWithType(ulidType, ULIDVariantUUID) + } + + // Try to get from pool + for { + select { + case ulid := <-p.ulids: + if ulid.Type == ulidType { + // Async refill + go func() { 
+ p.fill(1) + }() + return ulid + } + // Wrong type, put back and try next + p.ulids <- ulid + + default: + // Pool empty, generate new + return p.gen.GenerateWithType(ulidType, ULIDVariantUUID) + } + } +} + +// Close closes the pool +func (p *ULIDPool) Close() { + p.mu.Lock() + defer p.mu.Unlock() + + p.closed = true + close(p.ulids) + + // Drain remaining ULIDs + for ulid := range p.ulids { + _ = ulid + } +} + +// Size returns the current pool size +func (p *ULIDPool) Size() int { + return len(p.ulids) +} + +// IsClosed returns true if the pool is closed +func (p *ULIDPool) IsClosed() bool { + return p.closed +} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 22f66829f..54de66bf9 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -41,12 +41,12 @@ type Logger struct { } type LogEntry struct { - Level string `json:"level"` - Timestamp string `json:"timestamp"` - Component string `json:"component,omitempty"` - Message string `json:"message"` - Fields map[string]interface{} `json:"fields,omitempty"` - Caller string `json:"caller,omitempty"` + Level string `json:"level"` + Timestamp string `json:"timestamp"` + Component string `json:"component,omitempty"` + Message string `json:"message"` + Fields map[string]any `json:"fields,omitempty"` + Caller string `json:"caller,omitempty"` } func init() { @@ -71,7 +71,7 @@ func EnableFileLogging(filePath string) error { mu.Lock() defer mu.Unlock() - file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) if err != nil { return fmt.Errorf("failed to open log file: %w", err) } @@ -96,7 +96,7 @@ func DisableFileLogging() { } } -func logMessage(level LogLevel, component string, message string, fields map[string]interface{}) { +func logMessage(level LogLevel, component string, message string, fields map[string]any) { if level < currentLevel { return } @@ -150,7 +150,7 @@ func formatComponent(component 
string) string { return fmt.Sprintf(" %s:", component) } -func formatFields(fields map[string]interface{}) string { +func formatFields(fields map[string]any) string { var parts []string for k, v := range fields { parts = append(parts, fmt.Sprintf("%s=%v", k, v)) @@ -166,11 +166,11 @@ func DebugC(component string, message string) { logMessage(DEBUG, component, message, nil) } -func DebugF(message string, fields map[string]interface{}) { +func DebugF(message string, fields map[string]any) { logMessage(DEBUG, "", message, fields) } -func DebugCF(component string, message string, fields map[string]interface{}) { +func DebugCF(component string, message string, fields map[string]any) { logMessage(DEBUG, component, message, fields) } @@ -182,11 +182,11 @@ func InfoC(component string, message string) { logMessage(INFO, component, message, nil) } -func InfoF(message string, fields map[string]interface{}) { +func InfoF(message string, fields map[string]any) { logMessage(INFO, "", message, fields) } -func InfoCF(component string, message string, fields map[string]interface{}) { +func InfoCF(component string, message string, fields map[string]any) { logMessage(INFO, component, message, fields) } @@ -198,11 +198,11 @@ func WarnC(component string, message string) { logMessage(WARN, component, message, nil) } -func WarnF(message string, fields map[string]interface{}) { +func WarnF(message string, fields map[string]any) { logMessage(WARN, "", message, fields) } -func WarnCF(component string, message string, fields map[string]interface{}) { +func WarnCF(component string, message string, fields map[string]any) { logMessage(WARN, component, message, fields) } @@ -214,11 +214,11 @@ func ErrorC(component string, message string) { logMessage(ERROR, component, message, nil) } -func ErrorF(message string, fields map[string]interface{}) { +func ErrorF(message string, fields map[string]any) { logMessage(ERROR, "", message, fields) } -func ErrorCF(component string, message string, fields 
map[string]interface{}) { +func ErrorCF(component string, message string, fields map[string]any) { logMessage(ERROR, component, message, fields) } @@ -230,10 +230,10 @@ func FatalC(component string, message string) { logMessage(FATAL, component, message, nil) } -func FatalF(message string, fields map[string]interface{}) { +func FatalF(message string, fields map[string]any) { logMessage(FATAL, "", message, fields) } -func FatalCF(component string, message string, fields map[string]interface{}) { +func FatalCF(component string, message string, fields map[string]any) { logMessage(FATAL, component, message, fields) } diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go index 9b9c96820..6e6f8dfa8 100644 --- a/pkg/logger/logger_test.go +++ b/pkg/logger/logger_test.go @@ -54,11 +54,11 @@ func TestLoggerWithComponent(t *testing.T) { name string component string message string - fields map[string]interface{} + fields map[string]any }{ {"Simple message", "test", "Hello, world!", nil}, {"Message with component", "discord", "Discord message", nil}, - {"Message with fields", "telegram", "Telegram message", map[string]interface{}{ + {"Message with fields", "telegram", "Telegram message", map[string]any{ "user_id": "12345", "count": 42, }}, @@ -128,12 +128,12 @@ func TestLoggerHelperFunctions(t *testing.T) { Error("This should log") InfoC("test", "Component message") - InfoF("Fields message", map[string]interface{}{"key": "value"}) + InfoF("Fields message", map[string]any{"key": "value"}) WarnC("test", "Warning with component") - ErrorF("Error with fields", map[string]interface{}{"error": "test"}) + ErrorF("Error with fields", map[string]any{"error": "test"}) SetLevel(DEBUG) DebugC("test", "Debug with component") - WarnF("Warning with fields", map[string]interface{}{"key": "value"}) + WarnF("Warning with fields", map[string]any{"key": "value"}) } diff --git a/pkg/memory/acl.go b/pkg/memory/acl.go new file mode 100644 index 000000000..dcec5f786 --- /dev/null +++ 
b/pkg/memory/acl.go @@ -0,0 +1,403 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package memory + +import ( + "strings" + "sync" +) + +// ACLType represents the type of ACL entry +type ACLType int + +const ( + // ACLAllow is an allow-list entry + ACLAllow ACLType = iota + + // ACLDeny is a deny-list entry + ACLDeny +) + +// String returns the string representation of the ACL type +func (t ACLType) String() string { + switch t { + case ACLAllow: + return "allow" + case ACLDeny: + return "deny" + default: + return "unknown" + } +} + +// ACLEntry represents a single ACL rule +type ACLEntry struct { + // Type is either allow or deny + Type ACLType `json:"type"` + + // HID is the H-id this entry applies to (empty = all) + HID string `json:"hid,omitempty"` + + // SID is the S-id this entry applies to (empty = all in H-id) + SID string `json:"sid,omitempty"` + + // Permission specifies which permission this applies to (empty = all) + Permission Permission `json:"permission,omitempty"` + + // Reason is an optional reason for this entry + Reason string `json:"reason,omitempty"` +} + +// Matches checks if an ACL entry matches the given requester and permission +func (e *ACLEntry) Matches(hid, sid string, perm Permission) bool { + // Check permission match + // PermWildcard matches any permission, otherwise exact match is required + if e.Permission != PermWildcard && e.Permission != perm { + return false + } + + // Check HID match + if e.HID != "" && e.HID != hid { + return false + } + + // Check SID match (only if HID also matches or is empty) + if e.SID != "" && e.SID != sid { + // If SID is specified, HID must also match + if e.HID == "" || e.HID == hid { + return false + } + } + + return true +} + +// String returns a string representation of the ACL entry +func (e *ACLEntry) String() string { + var parts []string + + parts = append(parts, e.Type.String()) + + if e.HID == "" { + parts = 
append(parts, "*") + } else { + if e.SID != "" { + parts = append(parts, e.HID+"/"+e.SID) + } else { + parts = append(parts, e.HID) + parts = append(parts, "*") + } + } + + if e.Permission == PermWildcard { + parts = append(parts, "*") + } else { + parts = append(parts, e.Permission.String()) + } + + return strings.Join(parts, " ") +} + +// ACL manages access control lists for memory items +type ACL struct { + entries []ACLEntry + mu sync.RWMutex +} + +// NewACL creates a new empty ACL +func NewACL() *ACL { + return &ACL{ + entries: make([]ACLEntry, 0), + } +} + +// Add adds a new ACL entry +func (a *ACL) Add(entry ACLEntry) { + a.mu.Lock() + defer a.mu.Unlock() + a.entries = append(a.entries, entry) +} + +// Allow adds an allow entry for the given H-id/S-id and permission +func (a *ACL) Allow(hid, sid string, perm Permission) *ACL { + a.Add(ACLEntry{ + Type: ACLAllow, + HID: hid, + SID: sid, + Permission: perm, + }) + return a +} + +// Deny adds a deny entry for the given H-id/S-id and permission +func (a *ACL) Deny(hid, sid string, perm Permission) *ACL { + a.Add(ACLEntry{ + Type: ACLDeny, + HID: hid, + SID: sid, + Permission: perm, + }) + return a +} + +// AllowAll adds an allow-all entry +func (a *ACL) AllowAll() *ACL { + a.Add(ACLEntry{ + Type: ACLAllow, + Permission: PermWildcard, + }) + return a +} + +// DenyAll adds a deny-all entry +func (a *ACL) DenyAll() *ACL { + a.Add(ACLEntry{ + Type: ACLDeny, + Permission: PermWildcard, + }) + return a +} + +// Remove removes ACL entries that match the given criteria +func (a *ACL) Remove(hid, sid string, perm Permission) { + a.mu.Lock() + defer a.mu.Unlock() + + newEntries := make([]ACLEntry, 0, len(a.entries)) + for _, e := range a.entries { + if !e.Matches(hid, sid, perm) { + newEntries = append(newEntries, e) + } + } + a.entries = newEntries +} + +// Clear removes all ACL entries +func (a *ACL) Clear() { + a.mu.Lock() + defer a.mu.Unlock() + a.entries = make([]ACLEntry, 0) +} + +// Check checks if the given 
requester and permission is allowed +// Returns: (allowed, explicitMatch) +func (a *ACL) Check(hid, sid string, perm Permission) (bool, bool) { + a.mu.RLock() + defer a.mu.RUnlock() + + explicitMatch := false + + for _, e := range a.entries { + if !e.Matches(hid, sid, perm) { + continue + } + + explicitMatch = true + + switch e.Type { + case ACLAllow: + return true, true + case ACLDeny: + return false, true + } + } + + // No explicit match - default deny + return false, explicitMatch +} + +// IsAllowed is a convenience method that returns true if allowed +func (a *ACL) IsAllowed(hid, sid string, perm Permission) bool { + allowed, _ := a.Check(hid, sid, perm) + return allowed +} + +// Entries returns a copy of all ACL entries +func (a *ACL) Entries() []ACLEntry { + a.mu.RLock() + defer a.mu.RUnlock() + + result := make([]ACLEntry, len(a.entries)) + copy(result, a.entries) + return result +} + +// Len returns the number of ACL entries +func (a *ACL) Len() int { + a.mu.RLock() + defer a.mu.RUnlock() + return len(a.entries) +} + +// Merge merges another ACL into this one +func (a *ACL) Merge(other *ACL) { + if other == nil { + return + } + + entries := other.Entries() + a.mu.Lock() + defer a.mu.Unlock() + a.entries = append(a.entries, entries...) 
+} + +// Clone creates a deep copy of the ACL +func (a *ACL) Clone() *ACL { + a.mu.RLock() + defer a.mu.RUnlock() + + clone := &ACL{ + entries: make([]ACLEntry, len(a.entries)), + } + copy(clone.entries, a.entries) + return clone +} + +// ACLChecker combines permission checking with ACL support +type ACLChecker struct { + *Checker + acl *ACL +} + +// NewACLChecker creates a new ACL-enabled permission checker +func NewACLChecker() *ACLChecker { + return &ACLChecker{ + Checker: NewChecker(), + acl: NewACL(), + } +} + +// NewACLCheckerWithACL creates a new ACL checker with the given ACL +func NewACLCheckerWithACL(acl *ACL) *ACLChecker { + return &ACLChecker{ + Checker: NewChecker(), + acl: acl, + } +} + +// SetACL sets the ACL for this checker +func (c *ACLChecker) SetACL(acl *ACL) { + c.acl = acl +} + +// GetACL returns the ACL for this checker +func (c *ACLChecker) GetACL() *ACL { + return c.acl +} + +// Check checks permission using both scope-based and ACL-based rules +func (c *ACLChecker) Check(req *AccessRequest) *AccessResult { + if req == nil || req.Item == nil { + return &AccessResult{ + Allowed: false, + Reason: "invalid request", + } + } + + // First check ACL if present + if c.acl != nil && c.acl.Len() > 0 { + allowed, explicitMatch := c.acl.Check(req.RequesterHID, req.RequesterSID, req.Permission) + + if explicitMatch { + if allowed { + return &AccessResult{ + Allowed: true, + Reason: "allowed by ACL", + } + } + return &AccessResult{ + Allowed: false, + Reason: "denied by ACL", + } + } + } + + // Fall back to scope-based checking + return c.Checker.Check(req) +} + +// AddACLEntry adds an ACL entry to this checker +func (c *ACLChecker) AddACLEntry(entry ACLEntry) { + if c.acl == nil { + c.acl = NewACL() + } + c.acl.Add(entry) +} + +// Allow adds an allow ACL entry +func (c *ACLChecker) Allow(hid, sid string, perm Permission) { + if c.acl == nil { + c.acl = NewACL() + } + c.acl.Allow(hid, sid, perm) +} + +// Deny adds a deny ACL entry +func (c *ACLChecker) 
Deny(hid, sid string, perm Permission) { + if c.acl == nil { + c.acl = NewACL() + } + c.acl.Deny(hid, sid, perm) +} + +// MemoryItemWithACL extends MemoryItem with ACL support +type MemoryItemWithACL struct { + *MemoryItem + acl *ACL +} + +// NewMemoryItemWithACL creates a new memory item with ACL +func NewMemoryItemWithACL(item *MemoryItem) *MemoryItemWithACL { + return &MemoryItemWithACL{ + MemoryItem: item, + acl: NewACL(), + } +} + +// GetACL returns the ACL for this memory item +func (m *MemoryItemWithACL) GetACL() *ACL { + return m.acl +} + +// SetACL sets the ACL for this memory item +func (m *MemoryItemWithACL) SetACL(acl *ACL) { + m.acl = acl +} + +// Allow adds an allow entry to this memory item's ACL +func (m *MemoryItemWithACL) Allow(hid, sid string, perm Permission) { + if m.acl == nil { + m.acl = NewACL() + } + m.acl.Allow(hid, sid, perm) +} + +// Deny adds a deny entry to this memory item's ACL +func (m *MemoryItemWithACL) Deny(hid, sid string, perm Permission) { + if m.acl == nil { + m.acl = NewACL() + } + m.acl.Deny(hid, sid, perm) +} + +// CheckAccess checks access using both scope and ACL +func (m *MemoryItemWithACL) CheckAccess(checker *ACLChecker, requesterHID, requesterSID string, perm Permission) *AccessResult { + req := &AccessRequest{ + RequesterHID: requesterHID, + RequesterSID: requesterSID, + Permission: perm, + Item: m.MemoryItem, + } + + // Use the item's ACL if checker doesn't have one or has an empty one + if m.acl != nil && m.acl.Len() > 0 { + if checker.acl == nil || checker.acl.Len() == 0 { + checker.SetACL(m.acl) + } + } + + return checker.Check(req) +} diff --git a/pkg/memory/item.go b/pkg/memory/item.go new file mode 100644 index 000000000..303cf46b4 --- /dev/null +++ b/pkg/memory/item.go @@ -0,0 +1,332 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package memory + +import ( + "time" + + "github.com/google/uuid" +) + +// MemoryItem represents a single 
memory entry in the system +type MemoryItem struct { + // ID is a unique identifier for this memory item + ID string `json:"id"` + + // OwnerHID is the H-id (tenant) that owns this memory + OwnerHID string `json:"owner_hid"` + + // OwnerSID is the S-id (instance) that created this memory + OwnerSID string `json:"owner_sid"` + + // Scope defines the visibility scope of this memory + Scope MemoryScope `json:"scope"` + + // Type defines the type of memory content + Type MemoryType `json:"type"` + + // Key is an optional user-defined key for this memory + Key string `json:"key,omitempty"` + + // Content is the actual memory content + Content string `json:"content"` + + // Embedding is an optional vector embedding for semantic search + Embedding []float32 `json:"embedding,omitempty"` + + // Metadata contains additional information about this memory + Metadata map[string]string `json:"metadata,omitempty"` + + // Tags are optional tags for categorization and filtering + Tags []string `json:"tags,omitempty"` + + // ExpiresAt is an optional expiration time (0 = never expires) + ExpiresAt int64 `json:"expires_at,omitempty"` + + // CreatedAt is the timestamp when this memory was created + CreatedAt int64 `json:"created_at"` + + // UpdatedAt is the timestamp when this memory was last updated + UpdatedAt int64 `json:"updated_at"` + + // AccessedAt is the timestamp when this memory was last accessed + AccessedAt int64 `json:"accessed_at"` + + // AccessCount tracks how many times this memory has been accessed + AccessCount int64 `json:"access_count"` + + // Size is the approximate size in bytes + Size int64 `json:"size"` +} + +// NewMemoryItem creates a new memory item with the given parameters +func NewMemoryItem(ownerHID, ownerSID string, scope MemoryScope, memType MemoryType, content string) *MemoryItem { + now := time.Now().UnixMilli() + return &MemoryItem{ + ID: generateMemoryID(), + OwnerHID: ownerHID, + OwnerSID: ownerSID, + Scope: scope, + Type: memType, + Content: 
content, + Metadata: make(map[string]string), + Tags: make([]string, 0), + CreatedAt: now, + UpdatedAt: now, + AccessedAt: now, + Size: int64(len(content)), + } +} + +// NewPrivateMemory creates a new private memory item +func NewPrivateMemory(ownerHID, ownerSID string, memType MemoryType, content string) *MemoryItem { + return NewMemoryItem(ownerHID, ownerSID, ScopePrivate, memType, content) +} + +// NewSharedMemory creates a new shared memory item (same H-id) +func NewSharedMemory(ownerHID, ownerSID string, memType MemoryType, content string) *MemoryItem { + return NewMemoryItem(ownerHID, ownerSID, ScopeShared, memType, content) +} + +// NewPublicMemory creates a new public memory item (any H-id) +func NewPublicMemory(ownerHID, ownerSID string, memType MemoryType, content string) *MemoryItem { + return NewMemoryItem(ownerHID, ownerSID, ScopePublic, memType, content) +} + +// IsExpired returns true if the memory item has expired +func (m *MemoryItem) IsExpired() bool { + if m.ExpiresAt == 0 { + return false + } + return time.Now().UnixMilli() > m.ExpiresAt +} + +// Touch updates the accessed timestamp and increments access count +func (m *MemoryItem) Touch() { + m.AccessedAt = time.Now().UnixMilli() + m.AccessCount++ +} + +// UpdateContent updates the content and recalculates size +func (m *MemoryItem) UpdateContent(content string) { + m.Content = content + m.UpdatedAt = time.Now().UnixMilli() + m.Size = int64(len(content)) +} + +// SetKey sets the key for this memory item +func (m *MemoryItem) SetKey(key string) { + m.Key = key + m.UpdatedAt = time.Now().UnixMilli() +} + +// SetMetadata sets a metadata key-value pair +func (m *MemoryItem) SetMetadata(key, value string) { + if m.Metadata == nil { + m.Metadata = make(map[string]string) + } + m.Metadata[key] = value + m.UpdatedAt = time.Now().UnixMilli() +} + +// GetMetadata gets a metadata value by key +func (m *MemoryItem) GetMetadata(key string) (string, bool) { + if m.Metadata == nil { + return "", false + } + 
v, ok := m.Metadata[key] + return v, ok +} + +// AddTag adds a tag to the memory item +func (m *MemoryItem) AddTag(tag string) { + for _, t := range m.Tags { + if t == tag { + return // Already exists + } + } + m.Tags = append(m.Tags, tag) + m.UpdatedAt = time.Now().UnixMilli() +} + +// RemoveTag removes a tag from the memory item +func (m *MemoryItem) RemoveTag(tag string) { + for i, t := range m.Tags { + if t == tag { + m.Tags = append(m.Tags[:i], m.Tags[i+1:]...) + m.UpdatedAt = time.Now().UnixMilli() + return + } + } +} + +// HasTag checks if the memory item has a specific tag +func (m *MemoryItem) HasTag(tag string) bool { + for _, t := range m.Tags { + if t == tag { + return true + } + } + return false +} + +// SetExpiration sets the expiration time for this memory item +func (m *MemoryItem) SetExpiration(duration time.Duration) { + m.ExpiresAt = time.Now().Add(duration).UnixMilli() +} + +// SetExpirationAt sets a specific expiration timestamp +func (m *MemoryItem) SetExpirationAt(timestamp int64) { + m.ExpiresAt = timestamp +} + +// ClearExpiration clears the expiration time (makes it never expire) +func (m *MemoryItem) ClearExpiration() { + m.ExpiresAt = 0 +} + +// GetFullID returns the full ID with H-id and S-id prefix +func (m *MemoryItem) GetFullID() string { + return m.OwnerHID + "/" + m.OwnerSID + "/" + m.ID +} + +// Clone creates a deep copy of the memory item +func (m *MemoryItem) Clone() *MemoryItem { + clone := &MemoryItem{ + ID: m.ID, + OwnerHID: m.OwnerHID, + OwnerSID: m.OwnerSID, + Scope: m.Scope, + Type: m.Type, + Key: m.Key, + Content: m.Content, + ExpiresAt: m.ExpiresAt, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + AccessedAt: m.AccessedAt, + AccessCount: m.AccessCount, + Size: m.Size, + Metadata: make(map[string]string, len(m.Metadata)), + Tags: make([]string, len(m.Tags)), + } + + // Copy metadata + for k, v := range m.Metadata { + clone.Metadata[k] = v + } + + // Copy tags + copy(clone.Tags, m.Tags) + + // Copy embedding if 
present + if m.Embedding != nil { + clone.Embedding = make([]float32, len(m.Embedding)) + copy(clone.Embedding, m.Embedding) + } + + return clone +} + +// IsValid returns true if the memory item is valid +func (m *MemoryItem) IsValid() bool { + return m.ID != "" && m.OwnerHID != "" && m.OwnerSID != "" && m.Content != "" +} + +// MemoryFilter is used to filter memory items when querying +type MemoryFilter struct { + // OwnerHID filters by owner H-id (empty = all) + OwnerHID string + // OwnerSID filters by owner S-id (empty = all) + OwnerSID string + // Scope filters by scope (zero = wildcard unless ScopeSet is true) + Scope MemoryScope + ScopeSet bool // true if Scope was explicitly set + // Type filters by type (zero = wildcard unless TypeSet is true) + Type MemoryType + TypeSet bool // true if Type was explicitly set + // Tags filters by tags (any match) + Tags []string + // Key filters by key (empty = all) + Key string + // MinCreatedAt filters by minimum creation time + MinCreatedAt int64 + // MaxCreatedAt filters by maximum creation time + MaxCreatedAt int64 + // IncludeExpired includes expired items if true + IncludeExpired bool + // Limit limits the number of results + Limit int + // Offset skips the first N results + Offset int +} + +// Matches checks if a memory item matches the filter +func (f *MemoryFilter) Matches(item *MemoryItem) bool { + if f.OwnerHID != "" && item.OwnerHID != f.OwnerHID { + return false + } + if f.OwnerSID != "" && item.OwnerSID != f.OwnerSID { + return false + } + // Scope check: if ScopeSet is true, do exact match; otherwise wildcard + if f.ScopeSet && item.Scope != f.Scope { + return false + } + // Type check: if TypeSet is true, do exact match; otherwise wildcard + if f.TypeSet && item.Type != f.Type { + return false + } + if f.Key != "" && item.Key != f.Key { + return false + } + if !f.IncludeExpired && item.IsExpired() { + return false + } + if f.MinCreatedAt > 0 && item.CreatedAt < f.MinCreatedAt { + return false + } + if 
f.MaxCreatedAt > 0 && item.CreatedAt > f.MaxCreatedAt { + return false + } + if len(f.Tags) > 0 { + hasAnyTag := false + for _, tag := range f.Tags { + if item.HasTag(tag) { + hasAnyTag = true + break + } + } + if !hasAnyTag { + return false + } + } + return true +} + +// generateMemoryID generates a unique memory ID +func generateMemoryID() string { + return "mem-" + uuid.New().String()[:8] +} + +// MemoryQuery is used for complex queries with sorting and pagination +type MemoryQuery struct { + *MemoryFilter + SortBy string // "created_at", "updated_at", "accessed_at", "size" + SortOrder string // "asc", "desc" + Limit int + Offset int +} + +// NewMemoryQuery creates a new memory query +func NewMemoryQuery() *MemoryQuery { + return &MemoryQuery{ + MemoryFilter: &MemoryFilter{}, + SortBy: "created_at", + SortOrder: "desc", + Limit: 100, + Offset: 0, + } +} diff --git a/pkg/memory/jetstream/consumer.go b/pkg/memory/jetstream/consumer.go new file mode 100644 index 000000000..2cad7b8b9 --- /dev/null +++ b/pkg/memory/jetstream/consumer.go @@ -0,0 +1,428 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package jetstream + +import ( + "context" + "fmt" + "time" + + "github.com/nats-io/nats.go" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// Consumer manages a JetStream consumer for memory operations +type Consumer struct { + js nats.JetStreamContext + streamName string + name string + sub *nats.Subscription + ctx context.Context + cancel context.CancelFunc +} + +// NewConsumer creates a new consumer +func NewConsumer(js nats.JetStreamContext, streamName, consumerName string) *Consumer { + ctx, cancel := context.WithCancel(context.Background()) + + return &Consumer{ + js: js, + streamName: streamName, + name: consumerName, + ctx: ctx, + cancel: cancel, + } +} + +// Subscribe creates a subscription for this consumer +func (c *Consumer) Subscribe(subject string, handler func(*nats.Msg), opts 
...nats.SubOpt) error { + // Create consumer config if it doesn't exist + consumerCfg := &nats.ConsumerConfig{ + Durable: c.name, + AckPolicy: nats.AckExplicitPolicy, + AckWait: 30 * time.Second, + MaxDeliver: 3, + FilterSubject: subject, + DeliverPolicy: nats.DeliverAllPolicy, + } + + // Try to get existing consumer info + _, err := c.js.ConsumerInfo(c.streamName, c.name) + if err != nil { + // Consumer doesn't exist, create it + _, err = c.js.AddConsumer(c.streamName, consumerCfg) + if err != nil && err != nats.ErrConsumerNameAlreadyInUse { + return fmt.Errorf("failed to create consumer: %w", err) + } + } + + // Create pull subscription + sub, err := c.js.PullSubscribe(subject, c.name, opts...) + if err != nil { + return fmt.Errorf("failed to create subscription: %w", err) + } + + c.sub = sub + + // Start message processing loop + go c.processMessages(handler) + + logger.InfoCF("memory", "Consumer subscribed", map[string]interface{}{ + "stream": c.streamName, + "consumer": c.name, + "subject": subject, + }) + + return nil +} + +// processMessages processes messages from the subscription +func (c *Consumer) processMessages(handler func(*nats.Msg)) { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + if c.sub == nil { + continue + } + + msgs, err := c.sub.Fetch(10, nats.MaxWait(500*time.Millisecond)) + if err == nats.ErrTimeout { + continue + } + if err != nil { + logger.ErrorCF("memory", "Consumer fetch error", map[string]interface{}{ + "error": err.Error(), + "consumer": c.name, + }) + continue + } + + for _, msg := range msgs { + handler(msg) + } + } + } +} + +// Ack acknowledges a message +func (c *Consumer) Ack(msg *nats.Msg) error { + return msg.Ack() +} + +// Nak negatively acknowledges a message +func (c *Consumer) Nak(msg *nats.Msg) error { + return msg.Nak() +} + +// AckSync acknowledges a message synchronously +func (c *Consumer) AckSync(msg *nats.Msg) error { + 
return msg.AckSync() +} + +// Term terminates a message with no retry +func (c *Consumer) Term(msg *nats.Msg) error { + return msg.Term() +} + +// InProgress tells the server that work is ongoing +func (c *Consumer) InProgress(msg *nats.Msg) error { + return msg.InProgress() +} + +// Close closes the consumer +func (c *Consumer) Close() error { + c.cancel() + + if c.sub != nil { + if err := c.sub.Unsubscribe(); err != nil { + return fmt.Errorf("failed to unsubscribe: %w", err) + } + } + + logger.InfoCF("memory", "Consumer closed", map[string]interface{}{ + "consumer": c.name, + }) + + return nil +} + +// GetInfo returns consumer information +func (c *Consumer) GetInfo() (*nats.ConsumerInfo, error) { + return c.js.ConsumerInfo(c.streamName, c.name) +} + +// Pause pauses message delivery +func (c *Consumer) Pause() error { + if c.sub == nil { + return fmt.Errorf("no active subscription") + } + + return c.sub.SetPendingLimits(-1, -1) +} + +// Resume resumes message delivery +func (c *Consumer) Resume() error { + if c.sub == nil { + return fmt.Errorf("no active subscription") + } + + return c.sub.SetPendingLimits(1000, 100*1024*1024) +} + +// ConsumerSet manages multiple consumers +type ConsumerSet struct { + js nats.JetStreamContext + streamName string + consumers map[string]*Consumer +} + +// NewConsumerSet creates a new consumer set +func NewConsumerSet(js nats.JetStreamContext, streamName string) *ConsumerSet { + return &ConsumerSet{ + js: js, + streamName: streamName, + consumers: make(map[string]*Consumer), + } +} + +// Add adds a consumer to the set +func (cs *ConsumerSet) Add(name string, subject string, handler func(*nats.Msg)) error { + consumer := NewConsumer(cs.js, cs.streamName, name) + + if err := consumer.Subscribe(subject, handler); err != nil { + return err + } + + cs.consumers[name] = consumer + return nil +} + +// Get gets a consumer by name +func (cs *ConsumerSet) Get(name string) *Consumer { + return cs.consumers[name] +} + +// Remove removes a 
consumer from the set +func (cs *ConsumerSet) Remove(name string) error { + consumer, ok := cs.consumers[name] + if !ok { + return fmt.Errorf("consumer not found: %s", name) + } + + if err := consumer.Close(); err != nil { + return err + } + + delete(cs.consumers, name) + return nil +} + +// Close closes all consumers +func (cs *ConsumerSet) Close() error { + for name, consumer := range cs.consumers { + if err := consumer.Close(); err != nil { + logger.ErrorCF("memory", "Error closing consumer", map[string]interface{}{ + "consumer": name, + "error": err.Error(), + }) + } + } + + cs.consumers = make(map[string]*Consumer) + return nil +} + +// List returns all consumer names +func (cs *ConsumerSet) List() []string { + names := make([]string, 0, len(cs.consumers)) + for name := range cs.consumers { + names = append(names, name) + } + return names +} + +// MessageHandler is a function that handles memory messages +type MessageHandler func(msg *MemoryMessage) error + +// MemoryMessage wraps a NATS message with memory metadata +type MemoryMessage struct { + *nats.Msg + ID string + OwnerHID string + OwnerSID string + Type string + Timestamp int64 +} + +// MemoryConsumer is a specialized consumer for memory operations +type MemoryConsumer struct { + *Consumer + handler MessageHandler +} + +// NewMemoryConsumer creates a new memory consumer +func NewMemoryConsumer(js nats.JetStreamContext, streamName, consumerName string, handler MessageHandler) *MemoryConsumer { + base := NewConsumer(js, streamName, consumerName) + + return &MemoryConsumer{ + Consumer: base, + handler: handler, + } +} + +// Subscribe subscribes to memory messages +func (mc *MemoryConsumer) Subscribe(filterSubject string) error { + return mc.Consumer.Subscribe(filterSubject, func(msg *nats.Msg) { + // Parse memory metadata + memMsg := &MemoryMessage{ + Msg: msg, + } + + // Extract metadata from headers + if msg.Header != nil { + memMsg.ID = msg.Header.Get("X-Memory-ID") + memMsg.OwnerHID = 
msg.Header.Get("X-Owner-HID") + memMsg.OwnerSID = msg.Header.Get("X-Owner-SID") + memMsg.Type = msg.Header.Get("X-Memory-Type") + + // Parse timestamp + if ts := msg.Header.Get("X-Timestamp"); ts != "" { + fmt.Sscanf(ts, "%d", &memMsg.Timestamp) + } + } + + // Call handler + if err := mc.handler(memMsg); err != nil { + logger.ErrorCF("memory", "Handler error", map[string]interface{}{ + "error": err.Error(), + "msg_id": memMsg.ID, + "consumer": mc.name, + }) + msg.Nak() + } else { + msg.Ack() + } + }) +} + +// BatchConsumer handles messages in batches +type BatchConsumer struct { + *Consumer + handler func([]*MemoryMessage) error + batchSize int + batchTimeout time.Duration +} + +// NewBatchConsumer creates a new batch consumer +func NewBatchConsumer(js nats.JetStreamContext, streamName, consumerName string, handler func([]*MemoryMessage) error, batchSize int, batchTimeout time.Duration) *BatchConsumer { + base := NewConsumer(js, streamName, consumerName) + + return &BatchConsumer{ + Consumer: base, + handler: handler, + batchSize: batchSize, + batchTimeout: batchTimeout, + } +} + +// Subscribe subscribes to messages in batch mode +func (bc *BatchConsumer) Subscribe(filterSubject string) error { + batch := make([]*MemoryMessage, 0, bc.batchSize) + ticker := time.NewTicker(bc.batchTimeout) + defer ticker.Stop() + + flushBatch := func() error { + if len(batch) == 0 { + return nil + } + + messages := make([]*MemoryMessage, len(batch)) + copy(messages, batch) + + err := bc.handler(messages) + + // Ack all messages + for _, msg := range batch { + if err != nil { + msg.Msg.Nak() + } else { + msg.Msg.Ack() + } + } + + batch = batch[:0] + return err + } + + return bc.Consumer.Subscribe(filterSubject, func(msg *nats.Msg) { + memMsg := &MemoryMessage{ + Msg: msg, + ID: msg.Header.Get("X-Memory-ID"), + } + + batch = append(batch, memMsg) + + if len(batch) >= bc.batchSize { + if err := flushBatch(); err != nil { + logger.ErrorCF("memory", "Batch handler error", 
map[string]interface{}{ + "error": err.Error(), + }) + } + } + }) +} + +// GetConsumerInfo retrieves detailed information about a consumer +func GetConsumerInfo(js nats.JetStreamContext, streamName, consumerName string) (*nats.ConsumerInfo, error) { + return js.ConsumerInfo(streamName, consumerName) +} + +// ResetConsumer resets a consumer to start from the beginning +func ResetConsumer(js nats.JetStreamContext, streamName, consumerName string) error { + info, err := js.ConsumerInfo(streamName, consumerName) + if err != nil { + return err + } + + cfg := info.Config + cfg.DeliverPolicy = nats.DeliverAllPolicy + + _, err = js.UpdateConsumer(streamName, &cfg) + return err +} + +// SetAckWait sets the acknowledgment timeout for a consumer +func SetAckWait(js nats.JetStreamContext, streamName, consumerName string, ackWait time.Duration) error { + info, err := js.ConsumerInfo(streamName, consumerName) + if err != nil { + return err + } + + cfg := info.Config + cfg.AckWait = ackWait + + _, err = js.UpdateConsumer(streamName, &cfg) + return err +} + +// SetMaxDeliveries sets the maximum delivery count for a consumer +func SetMaxDeliveries(js nats.JetStreamContext, streamName, consumerName string, maxDeliveries int) error { + info, err := js.ConsumerInfo(streamName, consumerName) + if err != nil { + return err + } + + cfg := info.Config + cfg.MaxDeliver = maxDeliveries + + _, err = js.UpdateConsumer(streamName, &cfg) + return err +} diff --git a/pkg/memory/jetstream/store.go b/pkg/memory/jetstream/store.go new file mode 100644 index 000000000..48289b4ad --- /dev/null +++ b/pkg/memory/jetstream/store.go @@ -0,0 +1,447 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package jetstream + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/nats-io/nats.go" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// Store is a NATS JetStream-backed memory store +type Store struct { + 
conn *nats.Conn + js nats.JetStreamContext + stream string + bucket string // For KV store + + mu sync.RWMutex +} + +// Config holds configuration for the JetStream store +type Config struct { + // StreamName is the name of the JetStream stream + StreamName string + + // BucketName is the name of the KV bucket (optional, for CRUD operations) + BucketName string + + // Subjects is the list of subjects to include in the stream + Subjects []string + + // MaxAge is the maximum age of messages in the stream + MaxAge time.Duration + + // MaxBytes is the maximum size of the stream + MaxBytes int64 + + // Replicas is the number of stream replicas + Replicas int +} + +// DefaultConfig returns the default configuration +func DefaultConfig() *Config { + return &Config{ + StreamName: "PICOCLAW_MEMORY", + BucketName: "PICOCLAW_KV", + Subjects: []string{"picoclaw.memory.>"}, + MaxAge: 24 * time.Hour * 30, // 30 days + MaxBytes: 1024 * 1024 * 1024, // 1GB + Replicas: 1, + } +} + +// NewStore creates a new JetStream memory store +func NewStore(conn *nats.Conn) (*Store, error) { + js, err := conn.JetStream() + if err != nil { + return nil, fmt.Errorf("failed to get JetStream context: %w", err) + } + + return &Store{ + conn: conn, + js: js, + }, nil +} + +// Initialize sets up the JetStream stream and KV bucket +func (s *Store) Initialize(ctx context.Context, cfg *Config) error { + if cfg == nil { + cfg = DefaultConfig() + } + + s.mu.Lock() + s.stream = cfg.StreamName + s.bucket = cfg.BucketName + s.mu.Unlock() + + // Create stream + if err := s.createStream(ctx, cfg); err != nil { + return fmt.Errorf("failed to create stream: %w", err) + } + + // Create KV bucket + if cfg.BucketName != "" { + if err := s.createBucket(ctx, cfg); err != nil { + return fmt.Errorf("failed to create bucket: %w", err) + } + } + + logger.InfoCF("memory", "JetStream memory store initialized", map[string]interface{}{ + "stream": cfg.StreamName, + "bucket": cfg.BucketName, + }) + + return nil +} + +// 
createStream creates the JetStream stream +func (s *Store) createStream(ctx context.Context, cfg *Config) error { + stream, err := s.js.StreamInfo(cfg.StreamName) + if err != nil { + // Stream doesn't exist, create it + streamCfg := &nats.StreamConfig{ + Name: cfg.StreamName, + Subjects: cfg.Subjects, + MaxAge: cfg.MaxAge, + MaxBytes: cfg.MaxBytes, + Replicas: cfg.Replicas, + Retention: nats.LimitsPolicy, + Discard: nats.DiscardOld, + Storage: nats.FileStorage, + } + + _, err = s.js.AddStream(streamCfg) + if err != nil { + return fmt.Errorf("failed to add stream: %w", err) + } + + logger.InfoCF("memory", "Created JetStream stream", map[string]interface{}{ + "name": cfg.StreamName, + "subjects": fmt.Sprintf("%v", cfg.Subjects), + }) + } else { + logger.DebugCF("memory", "JetStream stream already exists", map[string]interface{}{ + "name": stream.Config.Name, + }) + } + + return nil +} + +// createBucket creates the KV bucket +func (s *Store) createBucket(ctx context.Context, cfg *Config) error { + _, err := s.js.KeyValue(cfg.BucketName) + if err != nil { + // Bucket doesn't exist, create it + _, err = s.js.CreateKeyValue(&nats.KeyValueConfig{ + Bucket: cfg.BucketName, + Description: "PicoClaw Memory KV Store", + MaxBytes: cfg.MaxBytes / 10, // 10% of stream size + TTL: cfg.MaxAge, + Storage: nats.FileStorage, + Replicas: cfg.Replicas, + }) + if err != nil { + return fmt.Errorf("failed to create KV bucket: %w", err) + } + + logger.InfoCF("memory", "Created KV bucket", map[string]interface{}{ + "name": cfg.BucketName, + }) + } else { + logger.DebugCF("memory", "KV bucket already exists", map[string]interface{}{ + "name": cfg.BucketName, + }) + } + + return nil +} + +// Publish publishes a memory item to the stream +func (s *Store) Publish(ctx context.Context, subject string, data []byte) error { + _, err := s.js.Publish(subject, data) + if err != nil { + return fmt.Errorf("failed to publish: %w", err) + } + + return nil +} + +// Subscribe subscribes to memory updates 
+func (s *Store) Subscribe(ctx context.Context, subject string, handler func(*nats.Msg)) (*nats.Subscription, error) { + // Create consumer subscription + sub, err := s.js.PullSubscribe(subject, "", nats.AckExplicit(), nats.DeliverAll()) + if err != nil { + return nil, fmt.Errorf("failed to subscribe: %w", err) + } + + // Start consumer loop + go func() { + for { + select { + case <-ctx.Done(): + sub.Unsubscribe() + return + default: + msgs, err := sub.Fetch(10, nats.MaxWait(2*time.Second)) + if err == nats.ErrTimeout { + continue + } + if err != nil { + logger.ErrorCF("memory", "Fetch error", map[string]interface{}{ + "error": err.Error(), + }) + continue + } + + for _, msg := range msgs { + handler(msg) + msg.Ack() + } + } + } + }() + + return sub, nil +} + +// Get retrieves a value from KV store +func (s *Store) Get(key string) ([]byte, error) { + s.mu.RLock() + bucket := s.bucket + s.mu.RUnlock() + + if bucket == "" { + return nil, fmt.Errorf("no KV bucket configured") + } + + kv, err := s.js.KeyValue(bucket) + if err != nil { + return nil, fmt.Errorf("failed to get KV bucket: %w", err) + } + + entry, err := kv.Get(key) + if err != nil { + if err == nats.ErrKeyNotFound { + return nil, ErrKeyNotFound{key} + } + return nil, fmt.Errorf("failed to get key: %w", err) + } + + return entry.Value(), nil +} + +// Put stores a value in KV store +func (s *Store) Put(key string, value []byte) error { + s.mu.RLock() + bucket := s.bucket + s.mu.RUnlock() + + if bucket == "" { + return fmt.Errorf("no KV bucket configured") + } + + kv, err := s.js.KeyValue(bucket) + if err != nil { + return fmt.Errorf("failed to get KV bucket: %w", err) + } + + _, err = kv.Put(key, value) + if err != nil { + return fmt.Errorf("failed to put key: %w", err) + } + + return nil +} + +// Delete removes a key from KV store +func (s *Store) Delete(key string) error { + s.mu.RLock() + bucket := s.bucket + s.mu.RUnlock() + + if bucket == "" { + return fmt.Errorf("no KV bucket configured") + } + + kv, 
err := s.js.KeyValue(bucket) + if err != nil { + return fmt.Errorf("failed to get KV bucket: %w", err) + } + + err = kv.Delete(key) + if err != nil { + return fmt.Errorf("failed to delete key: %w", err) + } + + return nil +} + +// Query performs a wildcard query on the KV store +func (s *Store) Query(prefix string) (map[string][]byte, error) { + s.mu.RLock() + bucket := s.bucket + s.mu.RUnlock() + + if bucket == "" { + return nil, fmt.Errorf("no KV bucket configured") + } + + kv, err := s.js.KeyValue(bucket) + if err != nil { + return nil, fmt.Errorf("failed to get KV bucket: %w", err) + } + + // Get all keys and values + keys, err := kv.Keys() + if err != nil { + return nil, fmt.Errorf("failed to get keys: %w", err) + } + + results := make(map[string][]byte) + + for _, key := range keys { + if matchesPrefix(key, prefix) { + entry, err := kv.Get(key) + if err != nil { + if err == nats.ErrKeyNotFound { + continue + } + return nil, fmt.Errorf("failed to get key %s: %w", key, err) + } + results[key] = entry.Value() + } + } + + return results, nil +} + +// Watch creates a watcher for key changes +func (s *Store) Watch(prefix string, callback func(key string, value []byte)) error { + s.mu.RLock() + bucket := s.bucket + s.mu.RUnlock() + + if bucket == "" { + return fmt.Errorf("no KV bucket configured") + } + + kv, err := s.js.KeyValue(bucket) + if err != nil { + return fmt.Errorf("failed to get KV bucket: %w", err) + } + + watcher, err := kv.WatchAll() + if err != nil { + return fmt.Errorf("failed to create watcher: %w", err) + } + + go func() { + for entry := range watcher.Updates() { + if entry != nil { + callback(entry.Key(), entry.Value()) + } + } + }() + + return nil +} + +// GetStreamInfo returns information about the stream +func (s *Store) GetStreamInfo() (*nats.StreamInfo, error) { + s.mu.RLock() + stream := s.stream + s.mu.RUnlock() + + return s.js.StreamInfo(stream) +} + +// GetBucketStatus returns status of the KV bucket +func (s *Store) GetBucketStatus() 
(*nats.KeyValueStatus, error) { + s.mu.RLock() + bucket := s.bucket + s.mu.RUnlock() + + if bucket == "" { + return nil, fmt.Errorf("no KV bucket configured") + } + + kv, err := s.js.KeyValue(bucket) + if err != nil { + return nil, fmt.Errorf("failed to get KV bucket: %w", err) + } + + status, err := kv.Status() + if err != nil { + return nil, fmt.Errorf("failed to get status: %w", err) + } + + return &status, nil +} + +// Close closes the store connection +func (s *Store) Close() error { + // NATS connection is managed externally, just clear references + s.mu.Lock() + s.stream = "" + s.bucket = "" + s.mu.Unlock() + return nil +} + +// ErrKeyNotFound is returned when a key is not found +type ErrKeyNotFound struct { + Key string +} + +func (e ErrKeyNotFound) Error() string { + return fmt.Sprintf("key not found: %s", e.Key) +} + +// matchesPrefix checks if a key matches a prefix +func matchesPrefix(key, prefix string) bool { + if prefix == "" || prefix == "*" { + return true + } + if prefix[len(prefix)-1] == '>' || prefix[len(prefix)-1] == '*' { + // Wildcard prefix + return len(key) >= len(prefix)-1 && key[:len(prefix)-1] == prefix[:len(prefix)-1] + } + return key == prefix +} + +// JSONPut is a convenience method to store JSON values +func (s *Store) JSONPut(key string, value interface{}) error { + data, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("failed to marshal JSON: %w", err) + } + return s.Put(key, data) +} + +// JSONGet is a convenience method to retrieve JSON values +func (s *Store) JSONGet(key string, dest interface{}) error { + data, err := s.Get(key) + if err != nil { + return err + } + return json.Unmarshal(data, dest) +} + +// PublishJSON publishes a JSON message to the stream +func (s *Store) PublishJSON(ctx context.Context, subject string, value interface{}) error { + data, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("failed to marshal JSON: %w", err) + } + return s.Publish(ctx, subject, data) +} diff --git 
a/pkg/memory/jetstream/stream.go b/pkg/memory/jetstream/stream.go new file mode 100644 index 000000000..01fd7635f --- /dev/null +++ b/pkg/memory/jetstream/stream.go @@ -0,0 +1,350 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package jetstream + +import ( + "context" + "fmt" + "time" + + "github.com/nats-io/nats.go" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// StreamManager manages JetStream streams for memory storage +type StreamManager struct { + store *Store +} + +// NewStreamManager creates a new stream manager +func NewStreamManager(store *Store) *StreamManager { + return &StreamManager{store: store} +} + +// CreateMemoryStream creates a stream for a specific H-id +func (sm *StreamManager) CreateMemoryStream(ctx context.Context, hid string, cfg *StreamConfig) error { + streamName := fmt.Sprintf("PICOCLAW_MEMORY_%s", hid) + + if cfg == nil { + cfg = DefaultStreamConfig() + } + cfg.Name = streamName + + streamCfg := &nats.StreamConfig{ + Name: cfg.Name, + Subjects: cfg.Subjects, + MaxAge: cfg.MaxAge, + MaxBytes: cfg.MaxBytes, + Replicas: cfg.Replicas, + Retention: nats.LimitsPolicy, + Discard: nats.DiscardOld, + Storage: nats.FileStorage, + } + + if cfg.Compression { + streamCfg.Compression = nats.S2Compression + } + + js := sm.store.js + _, err := js.AddStream(streamCfg) + if err != nil && err != nats.ErrStreamNameAlreadyInUse { + return fmt.Errorf("failed to create memory stream: %w", err) + } + + logger.InfoCF("memory", "Created memory stream", map[string]interface{}{ + "name": streamName, + "subjects": fmt.Sprintf("%v", cfg.Subjects), + "max_age": cfg.MaxAge.String(), + }) + + return nil +} + +// DeleteMemoryStream deletes a stream for an H-id +func (sm *StreamManager) DeleteMemoryStream(ctx context.Context, hid string) error { + streamName := fmt.Sprintf("PICOCLAW_MEMORY_%s", hid) + + js := sm.store.js + err := js.DeleteStream(streamName) + if err != nil { + return 
fmt.Errorf("failed to delete stream: %w", err) +} + +// GetStreamInfo returns information about a memory stream +func (sm *StreamManager) GetStreamInfo(hid string) (*nats.StreamInfo, error) { + streamName := fmt.Sprintf("PICOCLAW_MEMORY_%s", hid) + + js := sm.store.js + return js.StreamInfo(streamName) +} + +// ListStreams lists all memory streams +func (sm *StreamManager) ListStreams() ([]*nats.StreamInfo, error) { + js := sm.store.js + + streams := make([]*nats.StreamInfo, 0) + streamChan := js.Streams() + + for info := range streamChan { + // Filter for memory streams + if len(info.Config.Name) > 16 && info.Config.Name[:16] == "PICOCLAW_MEMORY_" { + streams = append(streams, info) + } + } + + return streams, nil +} + +// PurgeStream purges all messages from a stream +func (sm *StreamManager) PurgeStream(hid string) error { + streamName := fmt.Sprintf("PICOCLAW_MEMORY_%s", hid) + + js := sm.store.js + stream, err := js.StreamInfo(streamName) + if err != nil { + return fmt.Errorf("failed to get stream info: %w", err) + } + + if err := js.PurgeStream(streamName); err != nil { + return fmt.Errorf("failed to purge stream: %w", err) + } + + logger.InfoCF("memory", "Purged memory stream", map[string]interface{}{ + "name": streamName, + "messages_before": stream.State.Msgs, + }) + + return nil +} + +// StreamConfig holds configuration for a memory stream +type StreamConfig struct { + Name string + Subjects []string + MaxAge time.Duration + MaxBytes int64 + Replicas int + Compression bool +} + +// DefaultStreamConfig returns default stream configuration +func DefaultStreamConfig() *StreamConfig { + return &StreamConfig{ + Subjects: []string{"picoclaw.memory.>"}, + MaxAge: 30 * 24 * time.Hour, // 30 days + MaxBytes: 256 * 1024 * 1024, // 256MB + Replicas: 1, + Compression: false, + } +} + +// StreamStats holds statistics about a stream +type 
StreamStats struct { + Name string + Messages uint64 + Bytes uint64 + FirstSequence uint64 + LastSequence uint64 + CreateTime time.Time +} + +// GetStreamStats returns statistics for a stream +func (sm *StreamManager) GetStreamStats(hid string) (*StreamStats, error) { + info, err := sm.GetStreamInfo(hid) + if err != nil { + return nil, err + } + + return &StreamStats{ + Name: info.Config.Name, + Messages: info.State.Msgs, + Bytes: info.State.Bytes, + FirstSequence: info.State.FirstSeq, + LastSequence: info.State.LastSeq, + CreateTime: info.Created, + }, nil +} + +// GetAllStats returns statistics for all memory streams +func (sm *StreamManager) GetAllStats() (map[string]*StreamStats, error) { + streams, err := sm.ListStreams() + if err != nil { + return nil, err + } + + stats := make(map[string]*StreamStats) + + for _, stream := range streams { + // Extract H-id from stream name + if len(stream.Config.Name) > 16 { + hid := stream.Config.Name[16:] + stats[hid] = &StreamStats{ + Name: stream.Config.Name, + Messages: stream.State.Msgs, + Bytes: stream.State.Bytes, + FirstSequence: stream.State.FirstSeq, + LastSequence: stream.State.LastSeq, + CreateTime: stream.Created, + } + } + } + + return stats, nil +} + +// CompactStream compacts a stream by removing deleted messages +func (sm *StreamManager) CompactStream(hid string) error { + streamName := fmt.Sprintf("PICOCLAW_MEMORY_%s", hid) + + js := sm.store.js + + // Get current state + info, err := js.StreamInfo(streamName) + if err != nil { + return fmt.Errorf("failed to get stream info: %w", err) + } + + messagesBefore := info.State.Msgs + + // Compact the stream + if err := js.DeleteStream(streamName); err != nil { + return fmt.Errorf("failed to delete stream for compaction: %w", err) + } + + // Recreate with same config + streamCfg := info.Config + _, err = js.AddStream(&streamCfg) + if err != nil { + return fmt.Errorf("failed to recreate stream: %w", err) + } + + logger.InfoCF("memory", "Compacted memory stream", 
map[string]interface{}{ + "name": streamName, + "messages_before": messagesBefore, + }) + + return nil +} + +// ScaleStream adjusts the replication factor of a stream +func (sm *StreamManager) ScaleStream(hid string, replicas int) error { + streamName := fmt.Sprintf("PICOCLAW_MEMORY_%s", hid) + + js := sm.store.js + + info, err := js.StreamInfo(streamName) + if err != nil { + return fmt.Errorf("failed to get stream info: %w", err) + } + + // Update config + cfg := info.Config + cfg.Replicas = replicas + + _, err = js.UpdateStream(&cfg) + if err != nil { + return fmt.Errorf("failed to update stream: %w", err) + } + + logger.InfoCF("memory", "Scaled memory stream", map[string]interface{}{ + "name": streamName, + "replicas": replicas, + }) + + return nil +} + +// CreateConsumer creates a consumer for a stream +func (sm *StreamManager) CreateConsumer(hid string, consumerName string, cfg *ConsumerConfig) error { + streamName := fmt.Sprintf("PICOCLAW_MEMORY_%s", hid) + + if cfg == nil { + cfg = DefaultConsumerConfig() + } + + js := sm.store.js + + consumerCfg := &nats.ConsumerConfig{ + Durable: consumerName, + Description: consumerName, + AckPolicy: nats.AckExplicitPolicy, + AckWait: cfg.AckWait, + MaxDeliver: cfg.MaxDeliveries, + FilterSubject: cfg.FilterSubject, + ReplayPolicy: nats.ReplayInstantPolicy, + } + + _, err := js.AddConsumer(streamName, consumerCfg) + if err != nil && err != nats.ErrConsumerNameAlreadyInUse { + return fmt.Errorf("failed to create consumer: %w", err) + } + + logger.InfoCF("memory", "Created consumer", map[string]interface{}{ + "stream": streamName, + "consumer": consumerName, + "filter": cfg.FilterSubject, + }) + + return nil +} + +// DeleteConsumer deletes a consumer +func (sm *StreamManager) DeleteConsumer(hid, consumerName string) error { + streamName := fmt.Sprintf("PICOCLAW_MEMORY_%s", hid) + + js := sm.store.js + err := js.DeleteConsumer(streamName, consumerName) + if err != nil { + return fmt.Errorf("failed to delete consumer: %w", 
err) + } + + logger.InfoCF("memory", "Deleted consumer", map[string]interface{}{ + "stream": streamName, + "consumer": consumerName, + }) + + return nil +} + +// ListConsumers lists all consumers for a stream +func (sm *StreamManager) ListConsumers(hid string) ([]*nats.ConsumerInfo, error) { + streamName := fmt.Sprintf("PICOCLAW_MEMORY_%s", hid) + + js := sm.store.js + + // Get consumer info channel from the stream + consumerChan := js.Consumers(streamName) + + consumers := make([]*nats.ConsumerInfo, 0) + for cInfo := range consumerChan { + consumers = append(consumers, cInfo) + } + + return consumers, nil +} + +// ConsumerConfig holds configuration for a consumer +type ConsumerConfig struct { + FilterSubject string + AckWait time.Duration + MaxDeliveries int +} + +// DefaultConsumerConfig returns default consumer configuration +func DefaultConsumerConfig() *ConsumerConfig { + return &ConsumerConfig{ + FilterSubject: ">", + AckWait: 30 * time.Second, + MaxDeliveries: 3, + } +} diff --git a/pkg/memory/memory_test.go b/pkg/memory/memory_test.go new file mode 100644 index 000000000..af43a2d3f --- /dev/null +++ b/pkg/memory/memory_test.go @@ -0,0 +1,715 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package memory + +import ( + "testing" + "time" +) + +func TestMemoryScope_String(t *testing.T) { + tests := []struct { + scope MemoryScope + want string + }{ + {ScopePrivate, "private"}, + {ScopeShared, "shared"}, + {ScopePublic, "public"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + if got := tt.scope.String(); got != tt.want { + t.Errorf("MemoryScope.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestParseMemoryScope(t *testing.T) { + tests := []struct { + name string + s string + want MemoryScope + wantErr bool + }{ + {"private", "private", ScopePrivate, false}, + {"shared", "shared", ScopeShared, false}, + {"public", "public", ScopePublic, false}, 
+ {"invalid", "invalid", ScopePrivate, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseMemoryScope(tt.s) + if (err != nil) != tt.wantErr { + t.Errorf("ParseMemoryScope() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.want { + t.Errorf("ParseMemoryScope() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMemoryType_String(t *testing.T) { + tests := []struct { + memType MemoryType + want string + }{ + {TypeConversation, "conversation"}, + {TypeKnowledge, "knowledge"}, + {TypeContext, "context"}, + {TypeToolResult, "tool_result"}, + {TypeUserPreference, "user_preference"}, + {TypeSystem, "system"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + if got := tt.memType.String(); got != tt.want { + t.Errorf("MemoryType.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMemoryScope_Comparison(t *testing.T) { + tests := []struct { + name string + scope MemoryScope + other MemoryScope + more bool + less bool + level int + }{ + {"private vs shared", ScopePrivate, ScopeShared, false, true, 0}, + {"shared vs public", ScopeShared, ScopePublic, false, true, 1}, + {"public vs private", ScopePublic, ScopePrivate, true, false, 2}, + {"same", ScopeShared, ScopeShared, false, false, 1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.scope.IsMorePermissive(tt.other); got != tt.more { + t.Errorf("IsMorePermissive() = %v, want %v", got, tt.more) + } + if got := tt.scope.IsLessPermissive(tt.other); got != tt.less { + t.Errorf("IsLessPermissive() = %v, want %v", got, tt.less) + } + if got := tt.scope.ScopeLevel(); got != tt.level { + t.Errorf("ScopeLevel() = %v, want %v", got, tt.level) + } + }) + } +} + +func TestNewMemoryItem(t *testing.T) { + hid := "user-alice" + sid := "node-01" + content := "test content" + + item := NewMemoryItem(hid, sid, ScopePrivate, TypeContext, content) + + if item.OwnerHID != hid { 
+ t.Errorf("OwnerHID = %v, want %v", item.OwnerHID, hid) + } + if item.OwnerSID != sid { + t.Errorf("OwnerSID = %v, want %v", item.OwnerSID, sid) + } + if item.Scope != ScopePrivate { + t.Errorf("Scope = %v, want %v", item.Scope, ScopePrivate) + } + if item.Type != TypeContext { + t.Errorf("Type = %v, want %v", item.Type, TypeContext) + } + if item.Content != content { + t.Errorf("Content = %v, want %v", item.Content, content) + } + if item.ID == "" { + t.Errorf("ID should not be empty") + } + if item.CreatedAt == 0 { + t.Errorf("CreatedAt should be set") + } + if !item.IsValid() { + t.Errorf("Item should be valid") + } +} + +func TestMemoryItem_Expiration(t *testing.T) { + item := NewMemoryItem("user-alice", "node-01", ScopePrivate, TypeContext, "test") + + // No expiration by default + if item.IsExpired() { + t.Errorf("New item should not be expired") + } + + // Set expiration + item.SetExpiration(1 * time.Hour) + if item.IsExpired() { + t.Errorf("Item with future expiration should not be expired") + } + + // Set past expiration + item.SetExpirationAt(time.Now().UnixMilli() - 1000) + if !item.IsExpired() { + t.Errorf("Item with past expiration should be expired") + } + + // Clear expiration + item.ClearExpiration() + if item.IsExpired() { + t.Errorf("Item with cleared expiration should not be expired") + } +} + +func TestMemoryItem_Metadata(t *testing.T) { + item := NewMemoryItem("user-alice", "node-01", ScopePrivate, TypeContext, "test") + + item.SetMetadata("key1", "value1") + item.SetMetadata("key2", "value2") + + v, ok := item.GetMetadata("key1") + if !ok || v != "value1" { + t.Errorf("Metadata not set correctly") + } + + _, ok = item.GetMetadata("missing") + if ok { + t.Errorf("Expected false for missing key") + } +} + +func TestMemoryItem_Tags(t *testing.T) { + item := NewMemoryItem("user-alice", "node-01", ScopePrivate, TypeContext, "test") + + item.AddTag("tag1") + item.AddTag("tag2") + item.AddTag("tag1") // Duplicate + + if !item.HasTag("tag1") { + 
t.Errorf("Expected tag1") + } + if !item.HasTag("tag2") { + t.Errorf("Expected tag2") + } + if len(item.Tags) != 2 { + t.Errorf("Expected 2 tags, got %d", len(item.Tags)) + } + + item.RemoveTag("tag1") + if item.HasTag("tag1") { + t.Errorf("tag1 should be removed") + } + if len(item.Tags) != 1 { + t.Errorf("Expected 1 tag, got %d", len(item.Tags)) + } +} + +func TestMemoryItem_Touch(t *testing.T) { + item := NewMemoryItem("user-alice", "node-01", ScopePrivate, TypeContext, "test") + + firstAccess := item.AccessedAt + firstCount := item.AccessCount + + time.Sleep(10 * time.Millisecond) + item.Touch() + + if item.AccessedAt <= firstAccess { + t.Errorf("AccessedAt should be updated") + } + if item.AccessCount != firstCount+1 { + t.Errorf("AccessCount should be incremented") + } +} + +func TestMemoryItem_Clone(t *testing.T) { + original := NewMemoryItem("user-alice", "node-01", ScopePrivate, TypeContext, "test") + original.SetMetadata("key1", "value1") + original.AddTag("tag1") + original.SetKey("test-key") + + clone := original.Clone() + + // Verify all fields match + if clone.ID != original.ID { + t.Errorf("Clone.ID mismatch") + } + if clone.OwnerHID != original.OwnerHID { + t.Errorf("Clone.OwnerHID mismatch") + } + if clone.Content != original.Content { + t.Errorf("Clone.Content mismatch") + } + if !clone.HasTag("tag1") { + t.Errorf("Clone should have tag1") + } + if clone.Key != original.Key { + t.Errorf("Clone.Key mismatch") + } + + // Modify clone and ensure original is unchanged + clone.SetMetadata("key2", "value2") + _, ok := original.GetMetadata("key2") + if ok { + t.Errorf("Modifying clone affected original") + } +} + +func TestMemoryFilter_Matches(t *testing.T) { + item := NewMemoryItem("user-alice", "node-01", ScopeShared, TypeContext, "test") + item.SetKey("test-key") + item.AddTag("tag1") + + tests := []struct { + name string + filter *MemoryFilter + want bool + }{ + { + name: "no filter", + filter: &MemoryFilter{}, + want: true, + }, + { + name: 
"matching HID", + filter: &MemoryFilter{OwnerHID: "user-alice"}, + want: true, + }, + { + name: "non-matching HID", + filter: &MemoryFilter{OwnerHID: "user-bob"}, + want: false, + }, + { + name: "matching SID", + filter: &MemoryFilter{OwnerSID: "node-01"}, + want: true, + }, + { + name: "matching Scope", + filter: &MemoryFilter{Scope: ScopeShared, ScopeSet: true}, + want: true, + }, + { + name: "non-matching Scope", + filter: &MemoryFilter{Scope: ScopePrivate, ScopeSet: true}, + want: false, + }, + { + name: "matching Type", + filter: &MemoryFilter{Type: TypeContext, TypeSet: true}, + want: true, + }, + { + name: "matching Key", + filter: &MemoryFilter{Key: "test-key"}, + want: true, + }, + { + name: "matching Tag", + filter: &MemoryFilter{Tags: []string{"tag1"}}, + want: true, + }, + { + name: "non-matching Tag", + filter: &MemoryFilter{Tags: []string{"tag2"}}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.filter.Matches(item); got != tt.want { + t.Errorf("MemoryFilter.Matches() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestChecker_PrivateAccess(t *testing.T) { + checker := NewChecker() + item := NewPrivateMemory("user-alice", "node-01", TypeContext, "test") + + tests := []struct { + name string + requesterHID string + requesterSID string + allowed bool + }{ + {"owner", "user-alice", "node-01", true}, + {"same tenant different node", "user-alice", "node-02", false}, + {"different tenant", "user-bob", "node-01", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &AccessRequest{ + RequesterHID: tt.requesterHID, + RequesterSID: tt.requesterSID, + Permission: PermRead, + Item: item, + } + result := checker.Check(req) + if result.Allowed != tt.allowed { + t.Errorf("Checker.Check() = %v, want %v (reason: %s)", result.Allowed, tt.allowed, result.Reason) + } + }) + } +} + +func TestChecker_SharedAccess(t *testing.T) { + checker := NewChecker() + item := 
NewSharedMemory("user-alice", "node-01", TypeContext, "test")

	tests := []struct {
		name         string
		requesterHID string
		requesterSID string
		perm         Permission
		allowed      bool
	}{
		{"owner read", "user-alice", "node-01", PermRead, true},
		{"owner write", "user-alice", "node-01", PermWrite, true},
		{"same tenant read", "user-alice", "node-02", PermRead, true},
		{"same tenant write", "user-alice", "node-02", PermWrite, false},
		{"different tenant", "user-bob", "node-01", PermRead, false},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			req := &AccessRequest{
				RequesterHID: tc.requesterHID,
				RequesterSID: tc.requesterSID,
				Permission:   tc.perm,
				Item:         item,
			}
			result := checker.Check(req)
			if result.Allowed != tc.allowed {
				t.Errorf("Checker.Check() = %v, want %v", result.Allowed, tc.allowed)
			}
		})
	}
}

// TestChecker_PublicAccess verifies public-scope semantics: anyone may read,
// but write/delete stay restricted to the exact owner H-id/S-id pair.
func TestChecker_PublicAccess(t *testing.T) {
	checker := NewChecker()
	item := NewPublicMemory("user-alice", "node-01", TypeKnowledge, "test")

	tests := []struct {
		name         string
		requesterHID string
		requesterSID string
		perm         Permission
		allowed      bool
	}{
		{"anyone read", "user-bob", "node-99", PermRead, true},
		{"owner write", "user-alice", "node-01", PermWrite, true},
		{"other write", "user-bob", "node-99", PermWrite, false},
		{"owner delete", "user-alice", "node-01", PermDelete, true},
		{"other delete", "user-bob", "node-99", PermDelete, false},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			req := &AccessRequest{
				RequesterHID: tc.requesterHID,
				RequesterSID: tc.requesterSID,
				Permission:   tc.perm,
				Item:         item,
			}
			result := checker.Check(req)
			if result.Allowed != tc.allowed {
				t.Errorf("Checker.Check() = %v, want %v", result.Allowed, tc.allowed)
			}
		})
	}
}

// TestChecker_ConvenienceMethods checks that the Can* shortcuts all grant the
// owner full access to a shared item.
func TestChecker_ConvenienceMethods(t *testing.T) {
	checker := NewChecker()
	item := NewSharedMemory("user-alice", "node-01", TypeContext, "test")

	if !checker.CanRead(item, "user-alice", "node-01") {
		t.Errorf("CanRead should return true for owner")
	}
	if !checker.CanWrite(item, "user-alice", "node-01") {
		t.Errorf("CanWrite should return true for owner")
	}
	if !checker.CanDelete(item, "user-alice", "node-01") {
		t.Errorf("CanDelete should return true for owner")
	}
	if !checker.CanShare(item, "user-alice", "node-01") {
		t.Errorf("CanShare should return true for owner")
	}
}

// TestChecker_FilterByPermission verifies that filtering returns only the
// items the requester may read, across all three scopes and two tenants.
func TestChecker_FilterByPermission(t *testing.T) {
	checker := NewChecker()

	items := []*MemoryItem{
		NewPrivateMemory("user-alice", "node-01", TypeContext, "private"),
		NewSharedMemory("user-alice", "node-01", TypeContext, "shared"),
		NewPublicMemory("user-alice", "node-01", TypeKnowledge, "public"),
		NewPrivateMemory("user-bob", "node-01", TypeContext, "bob-private"),
	}

	// Filter for user-alice/node-01 (can access all of alice's items)
	filtered := checker.FilterByPermission(items, "user-alice", "node-01", PermRead)
	if len(filtered) != 3 {
		t.Errorf("Expected 3 items for owner, got %d", len(filtered))
	}

	// Filter for user-alice/node-02 (can access shared and public)
	filtered = checker.FilterByPermission(items, "user-alice", "node-02", PermRead)
	if len(filtered) != 2 {
		t.Errorf("Expected 2 items for same tenant, got %d", len(filtered))
	}

	// Filter for user-bob/node-01 (can access own private and alice's public)
	filtered = checker.FilterByPermission(items, "user-bob", "node-01", PermRead)
	if len(filtered) != 2 {
		t.Errorf("Expected 2 items for different tenant, got %d", len(filtered))
	}
}

// TestACL_AddCheck exercises basic allow/deny rule insertion and lookup.
func TestACL_AddCheck(t *testing.T) {
	acl := NewACL()

	// Add allow rule
	acl.Allow("user-bob", "", PermRead)

	// Check access
	allowed, _ := acl.Check("user-bob", "node-01", PermRead)
	if !allowed {
		t.Errorf("Expected allow for user-bob")
	}

	allowed, _ = acl.Check("user-bob", "node-01", PermWrite)
	if allowed {
		t.Errorf("Expected deny for write permission")
	}

	// Add deny rule
	acl.Deny("user-charlie", "", PermRead)

	allowed, _ = acl.Check("user-charlie", "node-01", PermRead)
	if allowed {
		t.Errorf("Expected deny for user-charlie")
	}
}

// TestACL_EntryMatching table-tests ACLEntry.Matches against a fixed
// allow-read entry for user-alice/node-01.
// NOTE(review): the "wildcard permission" case below duplicates "exact match"
// (it requests PermRead, not PermWildcard) — the intended wildcard behavior
// appears untested; confirm against ACLEntry.Matches in acl.go.
func TestACL_EntryMatching(t *testing.T) {
	entry := ACLEntry{
		Type:       ACLAllow,
		HID:        "user-alice",
		SID:        "node-01",
		Permission: PermRead,
	}

	tests := []struct {
		name string
		hid  string
		sid  string
		perm Permission
		want bool
	}{
		{"exact match", "user-alice", "node-01", PermRead, true},
		{"different sid", "user-alice", "node-02", PermRead, false},
		{"different hid", "user-bob", "node-01", PermRead, false},
		{"different perm", "user-alice", "node-01", PermWrite, false},
		{"wildcard permission", "user-alice", "node-01", PermRead, true},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			if got := entry.Matches(tc.hid, tc.sid, tc.perm); got != tc.want {
				t.Errorf("ACLEntry.Matches() = %v, want %v", got, tc.want)
			}
		})
	}
}

// TestACL_AllowDenyAll verifies the default-deny baseline and the blanket
// AllowAll / DenyAll switches (with a Clear between them).
func TestACL_AllowDenyAll(t *testing.T) {
	acl := NewACL()

	// Default deny
	allowed, _ := acl.Check("anyone", "anywhere", PermRead)
	if allowed {
		t.Errorf("Expected default deny")
	}

	// Allow all
	acl.AllowAll()
	allowed, _ = acl.Check("anyone", "anywhere", PermRead)
	if !allowed {
		t.Errorf("Expected allow after AllowAll")
	}

	// Deny all
	acl.Clear()
	acl.DenyAll()
	allowed, _ = acl.Check("anyone", "anywhere", PermRead)
	if allowed {
		t.Errorf("Expected deny after DenyAll")
	}
}

// TestACL_Remove checks that removing a rule both shrinks Entries() and
// revokes the corresponding access.
func TestACL_Remove(t *testing.T) {
	acl := NewACL()
	acl.Allow("user-alice", "", PermRead)
	acl.Allow("user-bob", "", PermRead)

	if len(acl.Entries()) != 2 {
		t.Errorf("Expected 2 entries")
	}

	// Remove user-bob
	acl.Remove("user-bob", "", PermRead)

	if len(acl.Entries()) != 1 {
		t.Errorf("Expected 1 entry after removal")
	}

	allowed, _ := acl.Check("user-bob", "node-01", PermRead)
	if allowed {
		t.Errorf("Expected deny after removal")
	}
}

// TestACLChecker_Integration verifies that ACL rules layered on top of the
// scope checker can both revoke an owner's default access and grant access
// to a non-owner.
func TestACLChecker_Integration(t *testing.T) {
	checker := NewACLChecker()
t.Errorf("Public should be at least Private") + } + if !ScopeIsAtLeast(ScopeShared, ScopeShared) { + t.Errorf("Shared should be at least Shared") + } + if ScopeIsAtLeast(ScopePrivate, ScopeShared) { + t.Errorf("Private should not be at least Shared") + } +} + +func TestNewMemoryQuery(t *testing.T) { + query := NewMemoryQuery() + + if query.SortBy != "created_at" { + t.Errorf("Default SortBy should be created_at") + } + if query.SortOrder != "desc" { + t.Errorf("Default SortOrder should be desc") + } + if query.Limit != 100 { + t.Errorf("Default Limit should be 100") + } +} + +func TestMemoryItemWithACL(t *testing.T) { + item := NewMemoryItem("user-alice", "node-01", ScopePrivate, TypeContext, "test") + itemWithACL := NewMemoryItemWithACL(item) + + // Add ACL to allow user-bob + itemWithACL.Allow("user-bob", "", PermRead) + + checker := NewACLChecker() + + // Check that user-bob can read + result := itemWithACL.CheckAccess(checker, "user-bob", "node-99", PermRead) + if !result.Allowed { + t.Errorf("user-bob should be allowed by ACL") + } + + // Check that user-bob cannot write + result = itemWithACL.CheckAccess(checker, "user-bob", "node-99", PermWrite) + if result.Allowed { + t.Errorf("user-bob should not have write permission") + } +} diff --git a/pkg/memory/permission.go b/pkg/memory/permission.go new file mode 100644 index 000000000..df74f0649 --- /dev/null +++ b/pkg/memory/permission.go @@ -0,0 +1,336 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package memory + +import ( + "fmt" +) + +// Permission represents the type of permission being checked +type Permission int + +const ( + // PermWildcard matches any permission (used for ACL entries) + PermWildcard Permission = -1 + + // PermRead allows reading memory content + PermRead Permission = iota + + // PermWrite allows writing/updating memory content + PermWrite + + // PermDelete allows deleting memory + PermDelete + + // PermShare 
allows changing memory scope + PermShare +) + +// String returns the string representation of the permission +func (p Permission) String() string { + switch p { + case PermWildcard: + return "*" + case PermRead: + return "read" + case PermWrite: + return "write" + case PermDelete: + return "delete" + case PermShare: + return "share" + default: + return "unknown" + } +} + +// AccessRequest represents a request to access a memory item +type AccessRequest struct { + // RequesterHID is the H-id of the requester + RequesterHID string + + // RequesterSID is the S-id of the requester + RequesterSID string + + // Permission is the type of permission being requested + Permission Permission + + // Item is the memory item being accessed + Item *MemoryItem +} + +// AccessResult represents the result of a permission check +type AccessResult struct { + // Allowed is true if access is granted + Allowed bool + + // Reason is a human-readable explanation of the decision + Reason string +} + +// Checker handles permission checking for memory access +type Checker struct { + // DefaultDeny causes all checks to deny by default unless explicitly allowed + DefaultDeny bool +} + +// NewChecker creates a new permission checker +func NewChecker() *Checker { + return &Checker{ + DefaultDeny: false, + } +} + +// Check checks if the given access request should be allowed +func (c *Checker) Check(req *AccessRequest) *AccessResult { + if req == nil || req.Item == nil { + return &AccessResult{ + Allowed: false, + Reason: "invalid request", + } + } + + item := req.Item + + // Check expiration first + if item.IsExpired() { + return &AccessResult{ + Allowed: false, + Reason: "memory item has expired", + } + } + + // Get the effective scope for the requester + effectiveScope := c.getEffectiveScope(item, req.RequesterHID, req.RequesterSID) + + // Check permission based on scope + switch effectiveScope { + case ScopePrivate: + return c.checkPrivateAccess(req) + case ScopeShared: + return 
c.checkSharedAccess(req) + case ScopePublic: + return c.checkPublicAccess(req) + default: + return &AccessResult{ + Allowed: false, + Reason: "invalid scope", + } + } +} + +// getEffectiveScope determines the effective scope for a requester +func (c *Checker) getEffectiveScope(item *MemoryItem, requesterHID, requesterSID string) MemoryScope { + // Owner always has the item's actual scope + if item.OwnerHID == requesterHID && item.OwnerSID == requesterSID { + return item.Scope + } + + // Same H-id but different S-id: treat as shared access + if item.OwnerHID == requesterHID { + if item.Scope == ScopePrivate { + return ScopePrivate // No access + } + return ScopeShared + } + + // Different H-id + return item.Scope +} + +// checkPrivateAccess checks access for private-scoped memory +func (c *Checker) checkPrivateAccess(req *AccessRequest) *AccessResult { + item := req.Item + + // Only the owner S-id can access private memory + if item.OwnerHID == req.RequesterHID && item.OwnerSID == req.RequesterSID { + return c.checkOwnerPermissions(req) + } + + return &AccessResult{ + Allowed: false, + Reason: fmt.Sprintf("private memory owned by %s/%s", item.OwnerHID, item.OwnerSID), + } +} + +// checkSharedAccess checks access for shared-scoped memory +func (c *Checker) checkSharedAccess(req *AccessRequest) *AccessResult { + item := req.Item + + // Same H-id: full access + if item.OwnerHID == req.RequesterHID { + if item.OwnerSID == req.RequesterSID { + return c.checkOwnerPermissions(req) + } + return c.checkTenantPermissions(req) + } + + // Different H-id: no access for shared memory + return &AccessResult{ + Allowed: false, + Reason: fmt.Sprintf("shared memory restricted to H-id %s", item.OwnerHID), + } +} + +// checkPublicAccess checks access for public-scoped memory +func (c *Checker) checkPublicAccess(req *AccessRequest) *AccessResult { + item := req.Item + + // All H-ids can read public memory + if req.Permission == PermRead { + return &AccessResult{ + Allowed: true, + 
Reason: "public memory is readable", + } + } + + // Write/Delete/Share requires ownership + if item.OwnerHID == req.RequesterHID && item.OwnerSID == req.RequesterSID { + return c.checkOwnerPermissions(req) + } + + return &AccessResult{ + Allowed: false, + Reason: "only owner can modify public memory", + } +} + +// checkOwnerPermissions checks permissions for the owner +func (c *Checker) checkOwnerPermissions(req *AccessRequest) *AccessResult { + // Owners have full permissions + return &AccessResult{ + Allowed: true, + Reason: "owner has full permissions", + } +} + +// checkTenantPermissions checks permissions for same-tenant non-owners +func (c *Checker) checkTenantPermissions(req *AccessRequest) *AccessResult { + // Same-tenant S-ids can read shared memory + if req.Permission == PermRead { + return &AccessResult{ + Allowed: true, + Reason: "shared memory is readable within tenant", + } + } + + return &AccessResult{ + Allowed: false, + Reason: "only owner can modify shared memory", + } +} + +// CanRead is a convenience method to check read permission +func (c *Checker) CanRead(item *MemoryItem, requesterHID, requesterSID string) bool { + req := &AccessRequest{ + RequesterHID: requesterHID, + RequesterSID: requesterSID, + Permission: PermRead, + Item: item, + } + return c.Check(req).Allowed +} + +// CanWrite is a convenience method to check write permission +func (c *Checker) CanWrite(item *MemoryItem, requesterHID, requesterSID string) bool { + req := &AccessRequest{ + RequesterHID: requesterHID, + RequesterSID: requesterSID, + Permission: PermWrite, + Item: item, + } + return c.Check(req).Allowed +} + +// CanDelete is a convenience method to check delete permission +func (c *Checker) CanDelete(item *MemoryItem, requesterHID, requesterSID string) bool { + req := &AccessRequest{ + RequesterHID: requesterHID, + RequesterSID: requesterSID, + Permission: PermDelete, + Item: item, + } + return c.Check(req).Allowed +} + +// CanShare is a convenience method to check 
share permission +func (c *Checker) CanShare(item *MemoryItem, requesterHID, requesterSID string) bool { + req := &AccessRequest{ + RequesterHID: requesterHID, + RequesterSID: requesterSID, + Permission: PermShare, + Item: item, + } + return c.Check(req).Allowed +} + +// FilterByPermission filters a list of memory items to those accessible by the requester +func (c *Checker) FilterByPermission(items []*MemoryItem, requesterHID, requesterSID string, perm Permission) []*MemoryItem { + result := make([]*MemoryItem, 0, len(items)) + req := &AccessRequest{ + RequesterHID: requesterHID, + RequesterSID: requesterSID, + Permission: perm, + } + + for _, item := range items { + req.Item = item + if c.Check(req).Allowed { + result = append(result, item) + } + } + + return result +} + +// ScopeHierarchy defines the hierarchy of scopes for permission inheritance +var ScopeHierarchy = []MemoryScope{ + ScopePrivate, + ScopeShared, + ScopePublic, +} + +// ScopeIsAtLeast returns true if the given scope is at least as permissive as the minimum +func ScopeIsAtLeast(scope, minimum MemoryScope) bool { + return scope.ScopeLevel() >= minimum.ScopeLevel() +} + +// ScopeIsAtMost returns true if the given scope is at most as permissive as the maximum +func ScopeIsAtMost(scope, maximum MemoryScope) bool { + return scope.ScopeLevel() <= maximum.ScopeLevel() +} + +// ScopeIsBetween returns true if the given scope is within the inclusive range +func ScopeIsBetween(scope, min, max MemoryScope) bool { + level := scope.ScopeLevel() + return level >= min.ScopeLevel() && level <= max.ScopeLevel() +} + +// DefaultScopeForType returns the recommended default scope for a given memory type +func DefaultScopeForType(memType MemoryType) MemoryScope { + switch memType { + case TypeConversation: + return ScopeShared + case TypeKnowledge: + return ScopeShared + case TypeContext: + return ScopePrivate + case TypeToolResult: + return ScopePrivate + case TypeUserPreference: + return ScopeShared + case 
TypeSystem: + return ScopeShared + default: + return ScopePrivate + } +} + +// NewMemoryWithDefaultScope creates a memory item with the default scope for its type +func NewMemoryWithDefaultScope(ownerHID, ownerSID string, memType MemoryType, content string) *MemoryItem { + scope := DefaultScopeForType(memType) + return NewMemoryItem(ownerHID, ownerSID, scope, memType, content) +} diff --git a/pkg/memory/types.go b/pkg/memory/types.go new file mode 100644 index 000000000..63020e836 --- /dev/null +++ b/pkg/memory/types.go @@ -0,0 +1,216 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package memory + +// MemoryScope defines the visibility scope of a memory item +type MemoryScope int + +const ( + // ScopePrivate indicates memory is only accessible by the owning S-id + ScopePrivate MemoryScope = iota + + // ScopeShared indicates memory is accessible by all S-ids within the same H-id + ScopeShared + + // ScopePublic indicates memory is accessible by any H-id (with authorization) + ScopePublic +) + +// String returns the string representation of the scope +func (s MemoryScope) String() string { + switch s { + case ScopePrivate: + return "private" + case ScopeShared: + return "shared" + case ScopePublic: + return "public" + default: + return "unknown" + } +} + +// ParseMemoryScope parses a string into a MemoryScope +func ParseMemoryScope(s string) (MemoryScope, error) { + switch s { + case "private": + return ScopePrivate, nil + case "shared": + return ScopeShared, nil + case "public": + return ScopePublic, nil + default: + return ScopePrivate, ErrInvalidScope{s} + } +} + +// MarshalJSON implements json.Marshaler +func (s MemoryScope) MarshalJSON() ([]byte, error) { + return []byte(`"` + s.String() + `"`), nil +} + +// UnmarshalJSON implements json.Unmarshaler +func (s *MemoryScope) UnmarshalJSON(data []byte) error { + str, err := unquoteString(data) + if err != nil { + return err + } + parsed, err 
:= ParseMemoryScope(str) + if err != nil { + return err + } + *s = parsed + return nil +} + +// MemoryType defines the type of memory content +type MemoryType int + +const ( + // TypeConversation stores conversation history + TypeConversation MemoryType = iota + + // TypeKnowledge stores facts and knowledge + TypeKnowledge + + // TypeContext stores temporary context information + TypeContext + + // TypeToolResult stores tool execution results + TypeToolResult + + // TypeUserPreference stores user preferences + TypeUserPreference + + // TypeSystem stores system-level information + TypeSystem +) + +// String returns the string representation of the memory type +func (t MemoryType) String() string { + switch t { + case TypeConversation: + return "conversation" + case TypeKnowledge: + return "knowledge" + case TypeContext: + return "context" + case TypeToolResult: + return "tool_result" + case TypeUserPreference: + return "user_preference" + case TypeSystem: + return "system" + default: + return "unknown" + } +} + +// ParseMemoryType parses a string into a MemoryType +func ParseMemoryType(s string) (MemoryType, error) { + switch s { + case "conversation": + return TypeConversation, nil + case "knowledge": + return TypeKnowledge, nil + case "context": + return TypeContext, nil + case "tool_result", "tool-result": + return TypeToolResult, nil + case "user_preference", "user-preference": + return TypeUserPreference, nil + case "system": + return TypeSystem, nil + default: + return TypeContext, ErrInvalidType{s} + } +} + +// MarshalJSON implements json.Marshaler +func (t MemoryType) MarshalJSON() ([]byte, error) { + return []byte(`"` + t.String() + `"`), nil +} + +// UnmarshalJSON implements json.Unmarshaler +func (t *MemoryType) UnmarshalJSON(data []byte) error { + str, err := unquoteString(data) + if err != nil { + return err + } + parsed, err := ParseMemoryType(str) + if err != nil { + return err + } + *t = parsed + return nil +} + +// MemoryItemType is an alias for 
backward compatibility +type MemoryItemType = MemoryType + +// PermissionError represents a permission check failure +type PermissionError struct { + Operation string + Scope MemoryScope + Reason string +} + +func (e PermissionError) Error() string { + if e.Reason != "" { + return e.Reason + } + return "permission denied: " + e.Operation + " on " + e.Scope.String() +} + +// ErrInvalidScope is returned when an invalid scope is provided +type ErrInvalidScope struct { + Scope string +} + +func (e ErrInvalidScope) Error() string { + return "invalid memory scope: " + e.Scope +} + +// ErrInvalidType is returned when an invalid memory type is provided +type ErrInvalidType struct { + Type string +} + +func (e ErrInvalidType) Error() string { + return "invalid memory type: " + e.Type +} + +// helper function to unquote JSON strings +func unquoteString(data []byte) (string, error) { + if len(data) >= 2 && data[0] == '"' && data[len(data)-1] == '"' { + return string(data[1 : len(data)-1]), nil + } + return "", ErrInvalidType{"not a quoted string"} +} + +// ScopeLevel returns the numeric level of a scope for comparison +func (s MemoryScope) ScopeLevel() int { + switch s { + case ScopePrivate: + return 0 + case ScopeShared: + return 1 + case ScopePublic: + return 2 + default: + return 0 + } +} + +// IsMorePermissive returns true if this scope is more permissive than the other +func (s MemoryScope) IsMorePermissive(other MemoryScope) bool { + return s.ScopeLevel() > other.ScopeLevel() +} + +// IsLessPermissive returns true if this scope is less permissive than the other +func (s MemoryScope) IsLessPermissive(other MemoryScope) bool { + return s.ScopeLevel() < other.ScopeLevel() +} diff --git a/pkg/migrate/config.go b/pkg/migrate/config.go index d7fa63305..2237a1429 100644 --- a/pkg/migrate/config.go +++ b/pkg/migrate/config.go @@ -12,13 +12,16 @@ import ( ) var supportedProviders = map[string]bool{ - "anthropic": true, - "openai": true, - "openrouter": true, - "groq": true, - 
"zhipu": true, - "vllm": true, - "gemini": true, + "anthropic": true, + "openai": true, + "openrouter": true, + "groq": true, + "zhipu": true, + "vllm": true, + "gemini": true, + "qwen": true, + "deepseek": true, + "github_copilot": true, } var supportedChannels = map[string]bool{ @@ -27,7 +30,7 @@ var supportedChannels = map[string]bool{ "whatsapp": true, "feishu": true, "qq": true, - "dingtalk": true, + "dingtalk": true, "maixcam": true, } @@ -44,26 +47,26 @@ func findOpenClawConfig(openclawHome string) (string, error) { return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", openclawHome) } -func LoadOpenClawConfig(configPath string) (map[string]interface{}, error) { +func LoadOpenClawConfig(configPath string) (map[string]any, error) { data, err := os.ReadFile(configPath) if err != nil { return nil, fmt.Errorf("reading OpenClaw config: %w", err) } - var raw map[string]interface{} + var raw map[string]any if err := json.Unmarshal(data, &raw); err != nil { return nil, fmt.Errorf("parsing OpenClaw config: %w", err) } converted := convertKeysToSnake(raw) - result, ok := converted.(map[string]interface{}) + result, ok := converted.(map[string]any) if !ok { return nil, fmt.Errorf("unexpected config format") } return result, nil } -func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error) { +func ConvertConfig(data map[string]any) (*config.Config, []string, error) { cfg := config.DefaultConfig() var warnings []string @@ -76,7 +79,7 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error cfg.Agents.Defaults.MaxTokens = int(v) } if v, ok := getFloat(defaults, "temperature"); ok { - cfg.Agents.Defaults.Temperature = v + cfg.Agents.Defaults.Temperature = &v } if v, ok := getFloat(defaults, "max_tool_iterations"); ok { cfg.Agents.Defaults.MaxToolIterations = int(v) @@ -89,7 +92,7 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error if providers, ok := 
getMap(data, "providers"); ok { for name, val := range providers { - pMap, ok := val.(map[string]interface{}) + pMap, ok := val.(map[string]any) if !ok { continue } @@ -108,7 +111,10 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error case "anthropic": cfg.Providers.Anthropic = pc case "openai": - cfg.Providers.OpenAI = pc + cfg.Providers.OpenAI = config.OpenAIProviderConfig{ + ProviderConfig: pc, + WebSearch: getBoolOrDefault(pMap, "web_search", true), + } case "openrouter": cfg.Providers.OpenRouter = pc case "groq": @@ -125,7 +131,7 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error if channels, ok := getMap(data, "channels"); ok { for name, val := range channels { - cMap, ok := val.(map[string]interface{}) + cMap, ok := val.(map[string]any) if !ok { continue } @@ -212,12 +218,17 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error if tools, ok := getMap(data, "tools"); ok { if web, ok := getMap(tools, "web"); ok { + // Migrate old "search" config to "brave" if api_key is present if search, ok := getMap(web, "search"); ok { if v, ok := getString(search, "api_key"); ok { - cfg.Tools.Web.Search.APIKey = v + cfg.Tools.Web.Brave.APIKey = v + if v != "" { + cfg.Tools.Web.Brave.Enabled = true + } } if v, ok := getFloat(search, "max_results"); ok { - cfg.Tools.Web.Search.MaxResults = int(v) + cfg.Tools.Web.Brave.MaxResults = int(v) + cfg.Tools.Web.DuckDuckGo.MaxResults = int(v) } } } @@ -248,6 +259,15 @@ func MergeConfig(existing, incoming *config.Config) *config.Config { if existing.Providers.Gemini.APIKey == "" { existing.Providers.Gemini = incoming.Providers.Gemini } + if existing.Providers.DeepSeek.APIKey == "" { + existing.Providers.DeepSeek = incoming.Providers.DeepSeek + } + if existing.Providers.GitHubCopilot.APIBase == "" { + existing.Providers.GitHubCopilot = incoming.Providers.GitHubCopilot + } + if existing.Providers.Qwen.APIKey == "" { + existing.Providers.Qwen = 
incoming.Providers.Qwen + } if !existing.Channels.Telegram.Enabled && incoming.Channels.Telegram.Enabled { existing.Channels.Telegram = incoming.Channels.Telegram @@ -271,8 +291,8 @@ func MergeConfig(existing, incoming *config.Config) *config.Config { existing.Channels.MaixCam = incoming.Channels.MaixCam } - if existing.Tools.Web.Search.APIKey == "" { - existing.Tools.Web.Search = incoming.Tools.Web.Search + if existing.Tools.Web.Brave.APIKey == "" { + existing.Tools.Web.Brave = incoming.Tools.Web.Brave } return existing @@ -298,16 +318,16 @@ func camelToSnake(s string) string { return result.String() } -func convertKeysToSnake(data interface{}) interface{} { +func convertKeysToSnake(data any) any { switch v := data.(type) { - case map[string]interface{}: - result := make(map[string]interface{}, len(v)) + case map[string]any: + result := make(map[string]any, len(v)) for key, val := range v { result[camelToSnake(key)] = convertKeysToSnake(val) } return result - case []interface{}: - result := make([]interface{}, len(v)) + case []any: + result := make([]any, len(v)) for i, val := range v { result[i] = convertKeysToSnake(val) } @@ -322,16 +342,16 @@ func rewriteWorkspacePath(path string) string { return path } -func getMap(data map[string]interface{}, key string) (map[string]interface{}, bool) { +func getMap(data map[string]any, key string) (map[string]any, bool) { v, ok := data[key] if !ok { return nil, false } - m, ok := v.(map[string]interface{}) + m, ok := v.(map[string]any) return m, ok } -func getString(data map[string]interface{}, key string) (string, bool) { +func getString(data map[string]any, key string) (string, bool) { v, ok := data[key] if !ok { return "", false @@ -340,7 +360,7 @@ func getString(data map[string]interface{}, key string) (string, bool) { return s, ok } -func getFloat(data map[string]interface{}, key string) (float64, bool) { +func getFloat(data map[string]any, key string) (float64, bool) { v, ok := data[key] if !ok { return 0, false @@ 
-349,7 +369,7 @@ func getFloat(data map[string]interface{}, key string) (float64, bool) { return f, ok } -func getBool(data map[string]interface{}, key string) (bool, bool) { +func getBool(data map[string]any, key string) (bool, bool) { v, ok := data[key] if !ok { return false, false @@ -358,12 +378,19 @@ func getBool(data map[string]interface{}, key string) (bool, bool) { return b, ok } -func getStringSlice(data map[string]interface{}, key string) []string { +func getBoolOrDefault(data map[string]any, key string, defaultVal bool) bool { + if v, ok := getBool(data, key); ok { + return v + } + return defaultVal +} + +func getStringSlice(data map[string]any, key string) []string { v, ok := data[key] if !ok { return []string{} } - arr, ok := v.([]interface{}) + arr, ok := v.([]any) if !ok { return []string{} } diff --git a/pkg/migrate/migrate.go b/pkg/migrate/migrate.go index 921f821cb..cfa82b7d7 100644 --- a/pkg/migrate/migrate.go +++ b/pkg/migrate/migrate.go @@ -67,7 +67,7 @@ func Run(opts Options) (*Result, error) { return nil, err } - if _, err := os.Stat(openclawHome); os.IsNotExist(err) { + if _, err = os.Stat(openclawHome); os.IsNotExist(err) { return nil, fmt.Errorf("OpenClaw installation not found at %s", openclawHome) } @@ -161,7 +161,7 @@ func Execute(actions []Action, openclawHome, picoClawHome string) *Result { fmt.Printf(" ✓ Converted config: %s\n", action.Destination) } case ActionCreateDir: - if err := os.MkdirAll(action.Destination, 0755); err != nil { + if err := os.MkdirAll(action.Destination, 0o755); err != nil { result.Errors = append(result.Errors, err) } else { result.DirsCreated++ @@ -174,9 +174,13 @@ func Execute(actions []Action, openclawHome, picoClawHome string) *Result { continue } result.BackupsCreated++ - fmt.Printf(" ✓ Backed up %s -> %s.bak\n", filepath.Base(action.Destination), filepath.Base(action.Destination)) + fmt.Printf( + " ✓ Backed up %s -> %s.bak\n", + filepath.Base(action.Destination), + filepath.Base(action.Destination), + ) 
- if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(action.Destination), 0o755); err != nil { result.Errors = append(result.Errors, err) continue } @@ -188,7 +192,7 @@ func Execute(actions []Action, openclawHome, picoClawHome string) *Result { fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome)) } case ActionCopy: - if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(action.Destination), 0o755); err != nil { result.Errors = append(result.Errors, err) continue } @@ -226,7 +230,7 @@ func executeConfigMigration(srcConfigPath, dstConfigPath, picoClawHome string) e incoming = MergeConfig(existing, incoming) } - if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0o755); err != nil { return err } return config.SaveConfig(dstConfigPath, incoming) diff --git a/pkg/migrate/migrate_test.go b/pkg/migrate/migrate_test.go index d93ea28fc..b6b3d70aa 100644 --- a/pkg/migrate/migrate_test.go +++ b/pkg/migrate/migrate_test.go @@ -40,43 +40,43 @@ func TestCamelToSnake(t *testing.T) { } func TestConvertKeysToSnake(t *testing.T) { - input := map[string]interface{}{ + input := map[string]any{ "apiKey": "test-key", "apiBase": "https://example.com", - "nested": map[string]interface{}{ - "maxTokens": float64(8192), - "allowFrom": []interface{}{"user1", "user2"}, - "deeperLevel": map[string]interface{}{ + "nested": map[string]any{ + "maxTokens": float64(8192), + "allowFrom": []any{"user1", "user2"}, + "deeperLevel": map[string]any{ "clientId": "abc", }, }, } result := convertKeysToSnake(input) - m, ok := result.(map[string]interface{}) + m, ok := result.(map[string]any) if !ok { t.Fatal("expected map[string]interface{}") } - if _, ok := m["api_key"]; !ok { + if _, ok = m["api_key"]; !ok { t.Error("expected key 'api_key' after conversion") } - if _, ok := m["api_base"]; !ok { + if _, 
ok = m["api_base"]; !ok { t.Error("expected key 'api_base' after conversion") } - nested, ok := m["nested"].(map[string]interface{}) + nested, ok := m["nested"].(map[string]any) if !ok { t.Fatal("expected nested map") } - if _, ok := nested["max_tokens"]; !ok { + if _, ok = nested["max_tokens"]; !ok { t.Error("expected key 'max_tokens' in nested map") } - if _, ok := nested["allow_from"]; !ok { + if _, ok = nested["allow_from"]; !ok { t.Error("expected key 'allow_from' in nested map") } - deeper, ok := nested["deeper_level"].(map[string]interface{}) + deeper, ok := nested["deeper_level"].(map[string]any) if !ok { t.Fatal("expected deeper_level map") } @@ -89,15 +89,15 @@ func TestLoadOpenClawConfig(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "openclaw.json") - openclawConfig := map[string]interface{}{ - "providers": map[string]interface{}{ - "anthropic": map[string]interface{}{ + openclawConfig := map[string]any{ + "providers": map[string]any{ + "anthropic": map[string]any{ "apiKey": "sk-ant-test123", "apiBase": "https://api.anthropic.com", }, }, - "agents": map[string]interface{}{ - "defaults": map[string]interface{}{ + "agents": map[string]any{ + "defaults": map[string]any{ "maxTokens": float64(4096), "model": "claude-3-opus", }, @@ -108,7 +108,7 @@ func TestLoadOpenClawConfig(t *testing.T) { if err != nil { t.Fatal(err) } - if err := os.WriteFile(configPath, data, 0644); err != nil { + if err = os.WriteFile(configPath, data, 0o644); err != nil { t.Fatal(err) } @@ -117,11 +117,11 @@ func TestLoadOpenClawConfig(t *testing.T) { t.Fatalf("LoadOpenClawConfig: %v", err) } - providers, ok := result["providers"].(map[string]interface{}) + providers, ok := result["providers"].(map[string]any) if !ok { t.Fatal("expected providers map") } - anthropic, ok := providers["anthropic"].(map[string]interface{}) + anthropic, ok := providers["anthropic"].(map[string]any) if !ok { t.Fatal("expected anthropic map") } @@ -129,11 +129,11 @@ func 
TestLoadOpenClawConfig(t *testing.T) { t.Errorf("api_key = %v, want sk-ant-test123", anthropic["api_key"]) } - agents, ok := result["agents"].(map[string]interface{}) + agents, ok := result["agents"].(map[string]any) if !ok { t.Fatal("expected agents map") } - defaults, ok := agents["defaults"].(map[string]interface{}) + defaults, ok := agents["defaults"].(map[string]any) if !ok { t.Fatal("expected defaults map") } @@ -144,16 +144,16 @@ func TestLoadOpenClawConfig(t *testing.T) { func TestConvertConfig(t *testing.T) { t.Run("providers mapping", func(t *testing.T) { - data := map[string]interface{}{ - "providers": map[string]interface{}{ - "anthropic": map[string]interface{}{ + data := map[string]any{ + "providers": map[string]any{ + "anthropic": map[string]any{ "api_key": "sk-ant-test", "api_base": "https://api.anthropic.com", }, - "openrouter": map[string]interface{}{ + "openrouter": map[string]any{ "api_key": "sk-or-test", }, - "groq": map[string]interface{}{ + "groq": map[string]any{ "api_key": "gsk-test", }, }, @@ -178,10 +178,10 @@ func TestConvertConfig(t *testing.T) { }) t.Run("unsupported provider warning", func(t *testing.T) { - data := map[string]interface{}{ - "providers": map[string]interface{}{ - "deepseek": map[string]interface{}{ - "api_key": "sk-deep-test", + data := map[string]any{ + "providers": map[string]any{ + "unknown_provider": map[string]any{ + "api_key": "sk-test", }, }, } @@ -193,20 +193,20 @@ func TestConvertConfig(t *testing.T) { if len(warnings) != 1 { t.Fatalf("expected 1 warning, got %d", len(warnings)) } - if warnings[0] != "Provider 'deepseek' not supported in PicoClaw, skipping" { + if warnings[0] != "Provider 'unknown_provider' not supported in PicoClaw, skipping" { t.Errorf("unexpected warning: %s", warnings[0]) } }) t.Run("channels mapping", func(t *testing.T) { - data := map[string]interface{}{ - "channels": map[string]interface{}{ - "telegram": map[string]interface{}{ + data := map[string]any{ + "channels": map[string]any{ + 
"telegram": map[string]any{ "enabled": true, "token": "tg-token-123", - "allow_from": []interface{}{"user1"}, + "allow_from": []any{"user1"}, }, - "discord": map[string]interface{}{ + "discord": map[string]any{ "enabled": true, "token": "disc-token-456", }, @@ -232,9 +232,9 @@ func TestConvertConfig(t *testing.T) { }) t.Run("unsupported channel warning", func(t *testing.T) { - data := map[string]interface{}{ - "channels": map[string]interface{}{ - "email": map[string]interface{}{ + data := map[string]any{ + "channels": map[string]any{ + "email": map[string]any{ "enabled": true, }, }, @@ -253,14 +253,14 @@ func TestConvertConfig(t *testing.T) { }) t.Run("agent defaults", func(t *testing.T) { - data := map[string]interface{}{ - "agents": map[string]interface{}{ - "defaults": map[string]interface{}{ - "model": "claude-3-opus", - "max_tokens": float64(4096), - "temperature": 0.5, - "max_tool_iterations": float64(10), - "workspace": "~/.openclaw/workspace", + data := map[string]any{ + "agents": map[string]any{ + "defaults": map[string]any{ + "model": "claude-3-opus", + "max_tokens": float64(4096), + "temperature": 0.5, + "max_tool_iterations": float64(10), + "workspace": "~/.openclaw/workspace", }, }, } @@ -275,8 +275,11 @@ func TestConvertConfig(t *testing.T) { if cfg.Agents.Defaults.MaxTokens != 4096 { t.Errorf("MaxTokens = %d, want %d", cfg.Agents.Defaults.MaxTokens, 4096) } - if cfg.Agents.Defaults.Temperature != 0.5 { - t.Errorf("Temperature = %f, want %f", cfg.Agents.Defaults.Temperature, 0.5) + if cfg.Agents.Defaults.Temperature == nil { + t.Fatalf("Temperature is nil, want %f", 0.5) + } + if *cfg.Agents.Defaults.Temperature != 0.5 { + t.Errorf("Temperature = %f, want %f", *cfg.Agents.Defaults.Temperature, 0.5) } if cfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" { t.Errorf("Workspace = %q, want %q", cfg.Agents.Defaults.Workspace, "~/.picoclaw/workspace") @@ -284,7 +287,7 @@ func TestConvertConfig(t *testing.T) { }) t.Run("empty config", func(t 
*testing.T) { - data := map[string]interface{}{} + data := map[string]any{} cfg, warnings, err := ConvertConfig(data) if err != nil { @@ -299,6 +302,24 @@ func TestConvertConfig(t *testing.T) { }) } +func TestSupportedProvidersCompatibility(t *testing.T) { + expected := []string{ + "anthropic", + "openai", + "openrouter", + "groq", + "zhipu", + "vllm", + "gemini", + } + + for _, provider := range expected { + if !supportedProviders[provider] { + t.Fatalf("supportedProviders missing expected key %q", provider) + } + } +} + func TestMergeConfig(t *testing.T) { t.Run("fills empty fields", func(t *testing.T) { existing := config.DefaultConfig() @@ -368,9 +389,9 @@ func TestPlanWorkspaceMigration(t *testing.T) { srcDir := t.TempDir() dstDir := t.TempDir() - os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644) - os.WriteFile(filepath.Join(srcDir, "SOUL.md"), []byte("# Soul"), 0644) - os.WriteFile(filepath.Join(srcDir, "USER.md"), []byte("# User"), 0644) + os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0o644) + os.WriteFile(filepath.Join(srcDir, "SOUL.md"), []byte("# Soul"), 0o644) + os.WriteFile(filepath.Join(srcDir, "USER.md"), []byte("# User"), 0o644) actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) if err != nil { @@ -399,8 +420,8 @@ func TestPlanWorkspaceMigration(t *testing.T) { srcDir := t.TempDir() dstDir := t.TempDir() - os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644) - os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing Agents"), 0644) + os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0o644) + os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing Agents"), 0o644) actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) if err != nil { @@ -422,8 +443,8 @@ func TestPlanWorkspaceMigration(t *testing.T) { srcDir := t.TempDir() dstDir := t.TempDir() - os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), 
[]byte("# Agents"), 0644) - os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing"), 0644) + os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0o644) + os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing"), 0o644) actions, err := PlanWorkspaceMigration(srcDir, dstDir, true) if err != nil { @@ -442,8 +463,8 @@ func TestPlanWorkspaceMigration(t *testing.T) { dstDir := t.TempDir() memDir := filepath.Join(srcDir, "memory") - os.MkdirAll(memDir, 0755) - os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory"), 0644) + os.MkdirAll(memDir, 0o755) + os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory"), 0o644) actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) if err != nil { @@ -473,8 +494,8 @@ func TestPlanWorkspaceMigration(t *testing.T) { dstDir := t.TempDir() skillDir := filepath.Join(srcDir, "skills", "weather") - os.MkdirAll(skillDir, 0755) - os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Weather"), 0644) + os.MkdirAll(skillDir, 0o755) + os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Weather"), 0o644) actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) if err != nil { @@ -497,7 +518,7 @@ func TestFindOpenClawConfig(t *testing.T) { t.Run("finds openclaw.json", func(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "openclaw.json") - os.WriteFile(configPath, []byte("{}"), 0644) + os.WriteFile(configPath, []byte("{}"), 0o644) found, err := findOpenClawConfig(tmpDir) if err != nil { @@ -511,7 +532,7 @@ func TestFindOpenClawConfig(t *testing.T) { t.Run("falls back to config.json", func(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "config.json") - os.WriteFile(configPath, []byte("{}"), 0644) + os.WriteFile(configPath, []byte("{}"), 0o644) found, err := findOpenClawConfig(tmpDir) if err != nil { @@ -525,8 +546,8 @@ func TestFindOpenClawConfig(t *testing.T) { t.Run("prefers openclaw.json over 
config.json", func(t *testing.T) { tmpDir := t.TempDir() openclawPath := filepath.Join(tmpDir, "openclaw.json") - os.WriteFile(openclawPath, []byte("{}"), 0644) - os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0644) + os.WriteFile(openclawPath, []byte("{}"), 0o644) + os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0o644) found, err := findOpenClawConfig(tmpDir) if err != nil { @@ -572,19 +593,19 @@ func TestRunDryRun(t *testing.T) { picoClawHome := t.TempDir() wsDir := filepath.Join(openclawHome, "workspace") - os.MkdirAll(wsDir, 0755) - os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644) - os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents"), 0644) + os.MkdirAll(wsDir, 0o755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0o644) + os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents"), 0o644) - configData := map[string]interface{}{ - "providers": map[string]interface{}{ - "anthropic": map[string]interface{}{ + configData := map[string]any{ + "providers": map[string]any{ + "anthropic": map[string]any{ "apiKey": "test-key", }, }, } data, _ := json.Marshal(configData) - os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0o644) opts := Options{ DryRun: true, @@ -613,33 +634,33 @@ func TestRunFullMigration(t *testing.T) { picoClawHome := t.TempDir() wsDir := filepath.Join(openclawHome, "workspace") - os.MkdirAll(wsDir, 0755) - os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul from OpenClaw"), 0644) - os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644) - os.WriteFile(filepath.Join(wsDir, "USER.md"), []byte("# User from OpenClaw"), 0644) + os.MkdirAll(wsDir, 0o755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul from OpenClaw"), 0o644) + os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0o644) + 
os.WriteFile(filepath.Join(wsDir, "USER.md"), []byte("# User from OpenClaw"), 0o644) memDir := filepath.Join(wsDir, "memory") - os.MkdirAll(memDir, 0755) - os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory notes"), 0644) + os.MkdirAll(memDir, 0o755) + os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory notes"), 0o644) - configData := map[string]interface{}{ - "providers": map[string]interface{}{ - "anthropic": map[string]interface{}{ + configData := map[string]any{ + "providers": map[string]any{ + "anthropic": map[string]any{ "apiKey": "sk-ant-migrate-test", }, - "openrouter": map[string]interface{}{ + "openrouter": map[string]any{ "apiKey": "sk-or-migrate-test", }, }, - "channels": map[string]interface{}{ - "telegram": map[string]interface{}{ + "channels": map[string]any{ + "telegram": map[string]any{ "enabled": true, "token": "tg-migrate-test", }, }, } data, _ := json.Marshal(configData) - os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0o644) opts := Options{ Force: true, @@ -733,7 +754,7 @@ func TestRunMutuallyExclusiveFlags(t *testing.T) { func TestBackupFile(t *testing.T) { tmpDir := t.TempDir() filePath := filepath.Join(tmpDir, "test.md") - os.WriteFile(filePath, []byte("original content"), 0644) + os.WriteFile(filePath, []byte("original content"), 0o644) if err := backupFile(filePath); err != nil { t.Fatalf("backupFile: %v", err) @@ -754,7 +775,7 @@ func TestCopyFile(t *testing.T) { srcPath := filepath.Join(tmpDir, "src.md") dstPath := filepath.Join(tmpDir, "dst.md") - os.WriteFile(srcPath, []byte("file content"), 0644) + os.WriteFile(srcPath, []byte("file content"), 0o644) if err := copyFile(srcPath, dstPath); err != nil { t.Fatalf("copyFile: %v", err) @@ -774,18 +795,18 @@ func TestRunConfigOnly(t *testing.T) { picoClawHome := t.TempDir() wsDir := filepath.Join(openclawHome, "workspace") - os.MkdirAll(wsDir, 0755) - 
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644) + os.MkdirAll(wsDir, 0o755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0o644) - configData := map[string]interface{}{ - "providers": map[string]interface{}{ - "anthropic": map[string]interface{}{ + configData := map[string]any{ + "providers": map[string]any{ + "anthropic": map[string]any{ "apiKey": "sk-config-only", }, }, } data, _ := json.Marshal(configData) - os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0o644) opts := Options{ Force: true, @@ -814,18 +835,18 @@ func TestRunWorkspaceOnly(t *testing.T) { picoClawHome := t.TempDir() wsDir := filepath.Join(openclawHome, "workspace") - os.MkdirAll(wsDir, 0755) - os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644) + os.MkdirAll(wsDir, 0o755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0o644) - configData := map[string]interface{}{ - "providers": map[string]interface{}{ - "anthropic": map[string]interface{}{ + configData := map[string]any{ + "providers": map[string]any{ + "anthropic": map[string]any{ "apiKey": "sk-ws-only", }, }, } data, _ := json.Marshal(configData) - os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0o644) opts := Options{ Force: true, diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go new file mode 100644 index 000000000..35f6b8f62 --- /dev/null +++ b/pkg/providers/anthropic/provider.go @@ -0,0 +1,262 @@ +package anthropicprovider + +import ( + "context" + "encoding/json" + "fmt" + "log" + "strings" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +type ( + ToolCall = protocoltypes.ToolCall + FunctionCall = protocoltypes.FunctionCall + LLMResponse 
= protocoltypes.LLMResponse + UsageInfo = protocoltypes.UsageInfo + Message = protocoltypes.Message + ToolDefinition = protocoltypes.ToolDefinition + ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition +) + +const defaultBaseURL = "https://api.anthropic.com" + +type Provider struct { + client *anthropic.Client + tokenSource func() (string, error) + baseURL string +} + +func NewProvider(token string) *Provider { + return NewProviderWithBaseURL(token, "") +} + +func NewProviderWithBaseURL(token, apiBase string) *Provider { + baseURL := normalizeBaseURL(apiBase) + client := anthropic.NewClient( + option.WithAuthToken(token), + option.WithBaseURL(baseURL), + ) + return &Provider{ + client: &client, + baseURL: baseURL, + } +} + +func NewProviderWithClient(client *anthropic.Client) *Provider { + return &Provider{ + client: client, + baseURL: defaultBaseURL, + } +} + +func NewProviderWithTokenSource(token string, tokenSource func() (string, error)) *Provider { + return NewProviderWithTokenSourceAndBaseURL(token, tokenSource, "") +} + +func NewProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *Provider { + p := NewProviderWithBaseURL(token, apiBase) + p.tokenSource = tokenSource + return p +} + +func (p *Provider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAuthToken(tok)) + } + + params, err := buildParams(messages, tools, model, options) + if err != nil { + return nil, err + } + + resp, err := p.client.Messages.New(ctx, params, opts...) 
+ if err != nil { + return nil, fmt.Errorf("claude API call: %w", err) + } + + return parseResponse(resp), nil +} + +func (p *Provider) GetDefaultModel() string { + return "claude-sonnet-4.6" +} + +func (p *Provider) BaseURL() string { + return p.baseURL +} + +func buildParams( + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (anthropic.MessageNewParams, error) { + var system []anthropic.TextBlockParam + var anthropicMessages []anthropic.MessageParam + + for _, msg := range messages { + switch msg.Role { + case "system": + system = append(system, anthropic.TextBlockParam{Text: msg.Content}) + case "user": + if msg.ToolCallID != "" { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + var blocks []anthropic.ContentBlockParamUnion + if msg.Content != "" { + blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) + } + for _, tc := range msg.ToolCalls { + blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) + } + anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "tool": + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } + } + + maxTokens := int64(4096) + if mt, ok := options["max_tokens"].(int); ok { + maxTokens = int64(mt) + } + + params := anthropic.MessageNewParams{ + Model: anthropic.Model(model), + Messages: anthropicMessages, + MaxTokens: maxTokens, + } + + if len(system) > 0 { + params.System = system + } + + if temp, ok := 
options["temperature"].(float64); ok { + params.Temperature = anthropic.Float(temp) + } + + if len(tools) > 0 { + params.Tools = translateTools(tools) + } + + return params, nil +} + +func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam { + result := make([]anthropic.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + tool := anthropic.ToolParam{ + Name: t.Function.Name, + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: t.Function.Parameters["properties"], + }, + } + if desc := t.Function.Description; desc != "" { + tool.Description = anthropic.String(desc) + } + if req, ok := t.Function.Parameters["required"].([]any); ok { + required := make([]string, 0, len(req)) + for _, r := range req { + if s, ok := r.(string); ok { + required = append(required, s) + } + } + tool.InputSchema.Required = required + } + result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) + } + return result +} + +func parseResponse(resp *anthropic.Message) *LLMResponse { + var content string + var toolCalls []ToolCall + + for _, block := range resp.Content { + switch block.Type { + case "text": + tb := block.AsText() + content += tb.Text + case "tool_use": + tu := block.AsToolUse() + var args map[string]any + if err := json.Unmarshal(tu.Input, &args); err != nil { + log.Printf("anthropic: failed to decode tool call input for %q: %v", tu.Name, err) + args = map[string]any{"raw": string(tu.Input)} + } + toolCalls = append(toolCalls, ToolCall{ + ID: tu.ID, + Name: tu.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + switch resp.StopReason { + case anthropic.StopReasonToolUse: + finishReason = "tool_calls" + case anthropic.StopReasonMaxTokens: + finishReason = "length" + case anthropic.StopReasonEndTurn: + finishReason = "stop" + } + + return &LLMResponse{ + Content: content, + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: 
int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens), + }, + } +} + +func normalizeBaseURL(apiBase string) string { + base := strings.TrimSpace(apiBase) + if base == "" { + return defaultBaseURL + } + + base = strings.TrimRight(base, "/") + if strings.HasSuffix(base, "/v1") { + base = strings.TrimSuffix(base, "/v1") + } + if base == "" { + return defaultBaseURL + } + + return base +} diff --git a/pkg/providers/anthropic/provider_test.go b/pkg/providers/anthropic/provider_test.go new file mode 100644 index 000000000..3d21c1d0b --- /dev/null +++ b/pkg/providers/anthropic/provider_test.go @@ -0,0 +1,271 @@ +package anthropicprovider + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" +) + +func TestBuildParams_BasicMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4.6", map[string]any{ + "max_tokens": 1024, + }) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if string(params.Model) != "claude-sonnet-4.6" { + t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4.6") + } + if params.MaxTokens != 1024 { + t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildParams_SystemMessage(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4.6", map[string]any{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.System) != 1 { + t.Fatalf("len(System) = %d, want 1", len(params.System)) + } + if params.System[0].Text != "You are helpful" { + t.Errorf("System[0].Text = %q, want 
%q", params.System[0].Text, "You are helpful") + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildParams_ToolCallMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + ID: "call_1", + Name: "get_weather", + Arguments: map[string]any{"city": "SF"}, + }, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4.6", map[string]any{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.Messages) != 3 { + t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) + } +} + +func TestBuildParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather for a city", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "city": map[string]any{"type": "string"}, + }, + "required": []any{"city"}, + }, + }, + }, + } + params, err := buildParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4.6", map[string]any{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } +} + +func TestParseResponse_TextOnly(t *testing.T) { + resp := &anthropic.Message{ + Content: []anthropic.ContentBlockUnion{}, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 20, + }, + } + result := parseResponse(resp) + if result.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) + } + if result.Usage.CompletionTokens != 20 { + t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens) + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, 
"stop") + } +} + +func TestParseResponse_StopReasons(t *testing.T) { + tests := []struct { + stopReason anthropic.StopReason + want string + }{ + {anthropic.StopReasonEndTurn, "stop"}, + {anthropic.StopReasonMaxTokens, "length"}, + {anthropic.StopReasonToolUse, "tool_calls"}, + } + for _, tt := range tests { + resp := &anthropic.Message{ + StopReason: tt.stopReason, + } + result := parseResponse(resp) + if result.FinishReason != tt.want { + t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) + } + } +} + +func TestProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody map[string]any + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]any{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": []map[string]any{ + {"type": "text", "text": "Hello! How can I help you?"}, + }, + "usage": map[string]any{ + "input_tokens": 15, + "output_tokens": 8, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token")) + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4.6", map[string]any{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hello! How can I help you?" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello! 
How can I help you?") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.PromptTokens != 15 { + t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens) + } +} + +func TestProvider_GetDefaultModel(t *testing.T) { + p := NewProvider("test-token") + if got := p.GetDefaultModel(); got != "claude-sonnet-4.6" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4.6") + } +} + +func TestProvider_NewProviderWithBaseURL_NormalizesV1Suffix(t *testing.T) { + p := NewProviderWithBaseURL("token", "https://api.anthropic.com/v1/") + if got := p.BaseURL(); got != "https://api.anthropic.com" { + t.Fatalf("BaseURL() = %q, want %q", got, "https://api.anthropic.com") + } +} + +func TestProvider_ChatUsesTokenSource(t *testing.T) { + var requests int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + atomic.AddInt32(&requests, 1) + + if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody map[string]any + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]any{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": []map[string]any{ + {"type": "text", "text": "ok"}, + }, + "usage": map[string]any{ + "input_tokens": 1, + "output_tokens": 1, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) { + return "refreshed-token", nil + }, server.URL) + + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hello"}}, + nil, + "claude-sonnet-4.6", + map[string]any{}, + ) + if err != nil { + 
t.Fatalf("Chat() error: %v", err) + } + if got := atomic.LoadInt32(&requests); got != 1 { + t.Fatalf("requests = %d, want 1", got) + } +} + +func createAnthropicTestClient(baseURL, token string) *anthropic.Client { + c := anthropic.NewClient( + anthropicoption.WithAuthToken(token), + anthropicoption.WithBaseURL(baseURL), + ) + return &c +} diff --git a/pkg/providers/antigravity_provider.go b/pkg/providers/antigravity_provider.go new file mode 100644 index 000000000..cff67c88c --- /dev/null +++ b/pkg/providers/antigravity_provider.go @@ -0,0 +1,855 @@ +package providers + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "math/rand" + "net/http" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/logger" +) + +const ( + antigravityBaseURL = "https://cloudcode-pa.googleapis.com" + antigravityDefaultModel = "gemini-3-flash" + antigravityUserAgent = "antigravity" + antigravityXGoogClient = "google-cloud-sdk vscode_cloudshelleditor/0.1" + antigravityVersion = "1.15.8" +) + +// AntigravityProvider implements LLMProvider using Google's Cloud Code Assist (Antigravity) API. +// This provider authenticates via Google OAuth and provides access to models like Claude and Gemini +// through Google's infrastructure. +type AntigravityProvider struct { + tokenSource func() (string, string, error) // Returns (accessToken, projectID, error) + httpClient *http.Client +} + +// NewAntigravityProvider creates a new Antigravity provider using stored auth credentials. +func NewAntigravityProvider() *AntigravityProvider { + return &AntigravityProvider{ + tokenSource: createAntigravityTokenSource(), + httpClient: &http.Client{ + Timeout: 120 * time.Second, + }, + } +} + +// Chat implements LLMProvider.Chat using the Cloud Code Assist v1internal API. +// The v1internal endpoint wraps the standard Gemini request in an envelope with +// project, model, request, requestType, userAgent, and requestId fields. 
+func (p *AntigravityProvider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (*LLMResponse, error) { + accessToken, projectID, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("antigravity auth: %w", err) + } + + if model == "" || model == "antigravity" || model == "google-antigravity" { + model = antigravityDefaultModel + } + // Strip provider prefixes if present + model = strings.TrimPrefix(model, "google-antigravity/") + model = strings.TrimPrefix(model, "antigravity/") + + logger.DebugCF("provider.antigravity", "Starting chat", map[string]any{ + "model": model, + "project": projectID, + "requestId": fmt.Sprintf("agent-%d-%s", time.Now().UnixMilli(), randomString(9)), + }) + + // Build the inner Gemini-format request + innerRequest := p.buildRequest(messages, tools, model, options) + + // Wrap in v1internal envelope (matches pi-ai SDK format) + envelope := map[string]any{ + "project": projectID, + "model": model, + "request": innerRequest, + "requestType": "agent", + "userAgent": antigravityUserAgent, + "requestId": fmt.Sprintf("agent-%d-%s", time.Now().UnixMilli(), randomString(9)), + } + + bodyBytes, err := json.Marshal(envelope) + if err != nil { + return nil, fmt.Errorf("marshaling request: %w", err) + } + + // Build API URL — uses Cloud Code Assist v1internal streaming endpoint + apiURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", antigravityBaseURL) + + req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + + // Headers matching the pi-ai SDK antigravity format + clientMetadata, _ := json.Marshal(map[string]string{ + "ideType": "IDE_UNSPECIFIED", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + }) + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + 
req.Header.Set("Accept", "text/event-stream") + req.Header.Set("User-Agent", fmt.Sprintf("antigravity/%s linux/amd64", antigravityVersion)) + req.Header.Set("X-Goog-Api-Client", antigravityXGoogClient) + req.Header.Set("Client-Metadata", string(clientMetadata)) + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("antigravity API call: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + logger.ErrorCF("provider.antigravity", "API call failed", map[string]any{ + "status_code": resp.StatusCode, + "response": string(respBody), + "model": model, + }) + + return nil, p.parseAntigravityError(resp.StatusCode, respBody) + } + + // Response is always SSE from streamGenerateContent — each line is "data: {...}" + // with a "response" wrapper containing the standard Gemini response + llmResp, err := p.parseSSEResponse(string(respBody)) + if err != nil { + return nil, err + } + + // Check for empty response (some models might return valid success but empty text) + if llmResp.Content == "" && len(llmResp.ToolCalls) == 0 { + return nil, fmt.Errorf( + "antigravity: model returned an empty response (this model might be invalid or restricted)", + ) + } + + return llmResp, nil +} + +// GetDefaultModel returns the default model identifier. 
+func (p *AntigravityProvider) GetDefaultModel() string { + return antigravityDefaultModel +} + +// --- Request building --- + +type antigravityRequest struct { + Contents []antigravityContent `json:"contents"` + Tools []antigravityTool `json:"tools,omitempty"` + SystemPrompt *antigravitySystemPrompt `json:"systemInstruction,omitempty"` + Config *antigravityGenConfig `json:"generationConfig,omitempty"` +} + +type antigravityContent struct { + Role string `json:"role"` + Parts []antigravityPart `json:"parts"` +} + +type antigravityPart struct { + Text string `json:"text,omitempty"` + ThoughtSignature string `json:"thoughtSignature,omitempty"` + ThoughtSignatureSnake string `json:"thought_signature,omitempty"` + FunctionCall *antigravityFunctionCall `json:"functionCall,omitempty"` + FunctionResponse *antigravityFunctionResponse `json:"functionResponse,omitempty"` +} + +type antigravityFunctionCall struct { + Name string `json:"name"` + Args map[string]any `json:"args"` +} + +type antigravityFunctionResponse struct { + Name string `json:"name"` + Response map[string]any `json:"response"` +} + +type antigravityTool struct { + FunctionDeclarations []antigravityFuncDecl `json:"functionDeclarations"` +} + +type antigravityFuncDecl struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters any `json:"parameters,omitempty"` +} + +type antigravitySystemPrompt struct { + Parts []antigravityPart `json:"parts"` +} + +type antigravityGenConfig struct { + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` +} + +func (p *AntigravityProvider) buildRequest( + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) antigravityRequest { + req := antigravityRequest{} + toolCallNames := make(map[string]string) + + // Build contents from messages + for _, msg := range messages { + switch msg.Role { + case "system": + req.SystemPrompt = 
&antigravitySystemPrompt{ + Parts: []antigravityPart{{Text: msg.Content}}, + } + case "user": + if msg.ToolCallID != "" { + toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames) + // Tool result + req.Contents = append(req.Contents, antigravityContent{ + Role: "user", + Parts: []antigravityPart{{ + FunctionResponse: &antigravityFunctionResponse{ + Name: toolName, + Response: map[string]any{ + "result": msg.Content, + }, + }, + }}, + }) + } else { + req.Contents = append(req.Contents, antigravityContent{ + Role: "user", + Parts: []antigravityPart{{Text: msg.Content}}, + }) + } + case "assistant": + content := antigravityContent{ + Role: "model", + } + if msg.Content != "" { + content.Parts = append(content.Parts, antigravityPart{Text: msg.Content}) + } + for _, tc := range msg.ToolCalls { + toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc) + if toolName == "" { + logger.WarnCF( + "provider.antigravity", + "Skipping tool call with empty name in history", + map[string]any{ + "tool_call_id": tc.ID, + }, + ) + continue + } + if tc.ID != "" { + toolCallNames[tc.ID] = toolName + } + content.Parts = append(content.Parts, antigravityPart{ + ThoughtSignature: thoughtSignature, + ThoughtSignatureSnake: thoughtSignature, + FunctionCall: &antigravityFunctionCall{ + Name: toolName, + Args: toolArgs, + }, + }) + } + if len(content.Parts) > 0 { + req.Contents = append(req.Contents, content) + } + case "tool": + toolName := resolveToolResponseName(msg.ToolCallID, toolCallNames) + req.Contents = append(req.Contents, antigravityContent{ + Role: "user", + Parts: []antigravityPart{{ + FunctionResponse: &antigravityFunctionResponse{ + Name: toolName, + Response: map[string]any{ + "result": msg.Content, + }, + }, + }}, + }) + } + } + + // Build tools (sanitize schemas for Gemini compatibility) + if len(tools) > 0 { + var funcDecls []antigravityFuncDecl + for _, t := range tools { + if t.Type != "function" { + continue + } + params := 
sanitizeSchemaForGemini(t.Function.Parameters) + funcDecls = append(funcDecls, antigravityFuncDecl{ + Name: t.Function.Name, + Description: t.Function.Description, + Parameters: params, + }) + } + if len(funcDecls) > 0 { + req.Tools = []antigravityTool{{FunctionDeclarations: funcDecls}} + } + } + + // Generation config + config := &antigravityGenConfig{} + if val, ok := options["max_tokens"]; ok { + if maxTokens, ok := val.(int); ok && maxTokens > 0 { + config.MaxOutputTokens = maxTokens + } else if maxTokens, ok := val.(float64); ok && maxTokens > 0 { + config.MaxOutputTokens = int(maxTokens) + } + } + if temp, ok := options["temperature"].(float64); ok { + config.Temperature = temp + } + if config.MaxOutputTokens > 0 || config.Temperature > 0 { + req.Config = config + } + + return req +} + +func normalizeStoredToolCall(tc ToolCall) (string, map[string]any, string) { + name := tc.Name + args := tc.Arguments + thoughtSignature := "" + + if name == "" && tc.Function != nil { + name = tc.Function.Name + thoughtSignature = tc.Function.ThoughtSignature + } else if tc.Function != nil { + thoughtSignature = tc.Function.ThoughtSignature + } + + if args == nil { + args = map[string]any{} + } + + if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" { + var parsed map[string]any + if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil { + args = parsed + } + } + + return name, args, thoughtSignature +} + +func resolveToolResponseName(toolCallID string, toolCallNames map[string]string) string { + if toolCallID == "" { + return "" + } + + if name, ok := toolCallNames[toolCallID]; ok && name != "" { + return name + } + + return inferToolNameFromCallID(toolCallID) +} + +func inferToolNameFromCallID(toolCallID string) string { + if !strings.HasPrefix(toolCallID, "call_") { + return toolCallID + } + + rest := strings.TrimPrefix(toolCallID, "call_") + if idx := strings.LastIndex(rest, "_"); idx > 0 { + candidate := rest[:idx] 
+ if candidate != "" { + return candidate + } + } + + return toolCallID +} + +// --- Response parsing --- + +type antigravityJSONResponse struct { + Candidates []struct { + Content struct { + Parts []struct { + Text string `json:"text,omitempty"` + ThoughtSignature string `json:"thoughtSignature,omitempty"` + ThoughtSignatureSnake string `json:"thought_signature,omitempty"` + FunctionCall *antigravityFunctionCall `json:"functionCall,omitempty"` + } `json:"parts"` + Role string `json:"role"` + } `json:"content"` + FinishReason string `json:"finishReason"` + } `json:"candidates"` + UsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + } `json:"usageMetadata"` +} + +func (p *AntigravityProvider) parseJSONResponse(body []byte) (*LLMResponse, error) { + var resp antigravityJSONResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("parsing antigravity response: %w", err) + } + + if len(resp.Candidates) == 0 { + return nil, fmt.Errorf("antigravity: no candidates in response") + } + + candidate := resp.Candidates[0] + var contentParts []string + var toolCalls []ToolCall + + for _, part := range candidate.Content.Parts { + if part.Text != "" { + contentParts = append(contentParts, part.Text) + } + if part.FunctionCall != nil { + argumentsJSON, _ := json.Marshal(part.FunctionCall.Args) + toolCalls = append(toolCalls, ToolCall{ + ID: fmt.Sprintf("call_%s_%d", part.FunctionCall.Name, time.Now().UnixNano()), + Name: part.FunctionCall.Name, + Arguments: part.FunctionCall.Args, + Function: &FunctionCall{ + Name: part.FunctionCall.Name, + Arguments: string(argumentsJSON), + ThoughtSignature: extractPartThoughtSignature(part.ThoughtSignature, part.ThoughtSignatureSnake), + }, + }) + } + } + + finishReason := "stop" + if len(toolCalls) > 0 { + finishReason = "tool_calls" + } + if candidate.FinishReason == "MAX_TOKENS" { + 
finishReason = "length" + } + + var usage *UsageInfo + if resp.UsageMetadata.TotalTokenCount > 0 { + usage = &UsageInfo{ + PromptTokens: resp.UsageMetadata.PromptTokenCount, + CompletionTokens: resp.UsageMetadata.CandidatesTokenCount, + TotalTokens: resp.UsageMetadata.TotalTokenCount, + } + } + + return &LLMResponse{ + Content: strings.Join(contentParts, ""), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + }, nil +} + +func (p *AntigravityProvider) parseSSEResponse(body string) (*LLMResponse, error) { + var contentParts []string + var toolCalls []ToolCall + var usage *UsageInfo + var finishReason string + + scanner := bufio.NewScanner(strings.NewReader(body)) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + break + } + + // v1internal SSE wraps the Gemini response in a "response" field + var sseChunk struct { + Response antigravityJSONResponse `json:"response"` + } + if err := json.Unmarshal([]byte(data), &sseChunk); err != nil { + continue + } + resp := sseChunk.Response + + for _, candidate := range resp.Candidates { + for _, part := range candidate.Content.Parts { + if part.Text != "" { + contentParts = append(contentParts, part.Text) + } + if part.FunctionCall != nil { + argumentsJSON, _ := json.Marshal(part.FunctionCall.Args) + toolCalls = append(toolCalls, ToolCall{ + ID: fmt.Sprintf("call_%s_%d", part.FunctionCall.Name, time.Now().UnixNano()), + Name: part.FunctionCall.Name, + Arguments: part.FunctionCall.Args, + Function: &FunctionCall{ + Name: part.FunctionCall.Name, + Arguments: string(argumentsJSON), + ThoughtSignature: extractPartThoughtSignature( + part.ThoughtSignature, + part.ThoughtSignatureSnake, + ), + }, + }) + } + } + if candidate.FinishReason != "" { + finishReason = candidate.FinishReason + } + } + + if resp.UsageMetadata.TotalTokenCount > 0 { + usage = &UsageInfo{ + PromptTokens: 
resp.UsageMetadata.PromptTokenCount, + CompletionTokens: resp.UsageMetadata.CandidatesTokenCount, + TotalTokens: resp.UsageMetadata.TotalTokenCount, + } + } + } + + mappedFinish := "stop" + if len(toolCalls) > 0 { + mappedFinish = "tool_calls" + } + if finishReason == "MAX_TOKENS" { + mappedFinish = "length" + } + + return &LLMResponse{ + Content: strings.Join(contentParts, ""), + ToolCalls: toolCalls, + FinishReason: mappedFinish, + Usage: usage, + }, nil +} + +func extractPartThoughtSignature(thoughtSignature string, thoughtSignatureSnake string) string { + if thoughtSignature != "" { + return thoughtSignature + } + if thoughtSignatureSnake != "" { + return thoughtSignatureSnake + } + return "" +} + +// --- Schema sanitization --- + +// Google/Gemini doesn't support many JSON Schema keywords that other providers accept. +var geminiUnsupportedKeywords = map[string]bool{ + "patternProperties": true, + "additionalProperties": true, + "$schema": true, + "$id": true, + "$ref": true, + "$defs": true, + "definitions": true, + "examples": true, + "minLength": true, + "maxLength": true, + "minimum": true, + "maximum": true, + "multipleOf": true, + "pattern": true, + "format": true, + "minItems": true, + "maxItems": true, + "uniqueItems": true, + "minProperties": true, + "maxProperties": true, +} + +func sanitizeSchemaForGemini(schema map[string]any) map[string]any { + if schema == nil { + return nil + } + + result := make(map[string]any) + for k, v := range schema { + if geminiUnsupportedKeywords[k] { + continue + } + // Recursively sanitize nested objects + switch val := v.(type) { + case map[string]any: + result[k] = sanitizeSchemaForGemini(val) + case []any: + sanitized := make([]any, len(val)) + for i, item := range val { + if m, ok := item.(map[string]any); ok { + sanitized[i] = sanitizeSchemaForGemini(m) + } else { + sanitized[i] = item + } + } + result[k] = sanitized + default: + result[k] = v + } + } + + // Ensure top-level has type: "object" if properties are 
present + if _, hasProps := result["properties"]; hasProps { + if _, hasType := result["type"]; !hasType { + result["type"] = "object" + } + } + + return result +} + +// --- Token source --- + +func createAntigravityTokenSource() func() (string, string, error) { + return func() (string, string, error) { + cred, err := auth.GetCredential("google-antigravity") + if err != nil { + return "", "", fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return "", "", fmt.Errorf( + "no credentials for google-antigravity. Run: picoclaw auth login --provider google-antigravity", + ) + } + + // Refresh if needed + if cred.NeedsRefresh() && cred.RefreshToken != "" { + oauthCfg := auth.GoogleAntigravityOAuthConfig() + refreshed, err := auth.RefreshAccessToken(cred, oauthCfg) + if err != nil { + return "", "", fmt.Errorf("refreshing token: %w", err) + } + refreshed.Email = cred.Email + if refreshed.ProjectID == "" { + refreshed.ProjectID = cred.ProjectID + } + if err := auth.SetCredential("google-antigravity", refreshed); err != nil { + return "", "", fmt.Errorf("saving refreshed token: %w", err) + } + cred = refreshed + } + + if cred.IsExpired() { + return "", "", fmt.Errorf( + "antigravity credentials expired. Run: picoclaw auth login --provider google-antigravity", + ) + } + + projectID := cred.ProjectID + if projectID == "" { + // Try to fetch project ID from API + fetchedID, err := FetchAntigravityProjectID(cred.AccessToken) + if err != nil { + logger.WarnCF("provider.antigravity", "Could not fetch project ID, using fallback", map[string]any{ + "error": err.Error(), + }) + projectID = "rising-fact-p41fc" // Default fallback (same as OpenCode) + } else { + projectID = fetchedID + cred.ProjectID = projectID + _ = auth.SetCredential("google-antigravity", cred) + } + } + + return cred.AccessToken, projectID, nil + } +} + +// FetchAntigravityProjectID retrieves the Google Cloud project ID from the loadCodeAssist endpoint. 
+func FetchAntigravityProjectID(accessToken string) (string, error) { + reqBody, _ := json.Marshal(map[string]any{ + "metadata": map[string]any{ + "ideType": "IDE_UNSPECIFIED", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + }, + }) + + req, err := http.NewRequest("POST", antigravityBaseURL+"/v1internal:loadCodeAssist", bytes.NewReader(reqBody)) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", antigravityUserAgent) + req.Header.Set("X-Goog-Api-Client", antigravityXGoogClient) + + client := &http.Client{Timeout: 15 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("loadCodeAssist failed: %s", string(body)) + } + + var result struct { + CloudAICompanionProject string `json:"cloudaicompanionProject"` + } + if err := json.Unmarshal(body, &result); err != nil { + return "", err + } + + if result.CloudAICompanionProject == "" { + return "", fmt.Errorf("no project ID in loadCodeAssist response") + } + + return result.CloudAICompanionProject, nil +} + +// FetchAntigravityModels fetches available models from the Cloud Code Assist API. 
+func FetchAntigravityModels(accessToken, projectID string) ([]AntigravityModelInfo, error) { + reqBody, _ := json.Marshal(map[string]any{ + "project": projectID, + }) + + req, err := http.NewRequest("POST", antigravityBaseURL+"/v1internal:fetchAvailableModels", bytes.NewReader(reqBody)) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", antigravityUserAgent) + req.Header.Set("X-Goog-Api-Client", antigravityXGoogClient) + + client := &http.Client{Timeout: 15 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf( + "fetchAvailableModels failed (HTTP %d): %s", + resp.StatusCode, + truncateString(string(body), 200), + ) + } + + var result struct { + Models map[string]struct { + DisplayName string `json:"displayName"` + QuotaInfo struct { + RemainingFraction any `json:"remainingFraction"` + ResetTime string `json:"resetTime"` + IsExhausted bool `json:"isExhausted"` + } `json:"quotaInfo"` + } `json:"models"` + } + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("parsing models response: %w", err) + } + + var models []AntigravityModelInfo + for id, info := range result.Models { + models = append(models, AntigravityModelInfo{ + ID: id, + DisplayName: info.DisplayName, + IsExhausted: info.QuotaInfo.IsExhausted, + }) + } + + // Ensure gemini-3-flash-preview and gemini-3-flash are in the list if they aren't already + hasFlashPreview := false + hasFlash := false + for _, m := range models { + if m.ID == "gemini-3-flash-preview" { + hasFlashPreview = true + } + if m.ID == "gemini-3-flash" { + hasFlash = true + } + } + if !hasFlashPreview { + models = append(models, AntigravityModelInfo{ + ID: "gemini-3-flash-preview", + DisplayName: "Gemini 3 Flash (Preview)", + 
}) + } + if !hasFlash { + models = append(models, AntigravityModelInfo{ + ID: "gemini-3-flash", + DisplayName: "Gemini 3 Flash", + }) + } + + return models, nil +} + +type AntigravityModelInfo struct { + ID string `json:"id"` + DisplayName string `json:"display_name"` + IsExhausted bool `json:"is_exhausted"` +} + +// --- Helpers --- + +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} + +func randomString(n int) string { + const letters = "abcdefghijklmnopqrstuvwxyz0123456789" + b := make([]byte, n) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +} + +func (p *AntigravityProvider) parseAntigravityError(statusCode int, body []byte) error { + var errResp struct { + Error struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` + Details []map[string]any `json:"details"` + } `json:"error"` + } + + if err := json.Unmarshal(body, &errResp); err != nil { + return fmt.Errorf("antigravity API error (HTTP %d): %s", statusCode, truncateString(string(body), 500)) + } + + msg := errResp.Error.Message + if statusCode == 429 { + // Try to extract quota reset info + for _, detail := range errResp.Error.Details { + if typeVal, ok := detail["@type"].(string); ok && strings.HasSuffix(typeVal, "ErrorInfo") { + if metadata, ok := detail["metadata"].(map[string]any); ok { + if delay, ok := metadata["quotaResetDelay"].(string); ok { + return fmt.Errorf("antigravity rate limit exceeded: %s (reset in %s)", msg, delay) + } + } + } + } + return fmt.Errorf("antigravity rate limit exceeded: %s", msg) + } + + return fmt.Errorf("antigravity API error (%s): %s", errResp.Error.Status, msg) +} diff --git a/pkg/providers/antigravity_provider_test.go b/pkg/providers/antigravity_provider_test.go new file mode 100644 index 000000000..238765321 --- /dev/null +++ b/pkg/providers/antigravity_provider_test.go @@ -0,0 +1,56 @@ +package providers + +import 
"testing" + +func TestBuildRequestUsesFunctionFieldsWhenToolCallNameMissing(t *testing.T) { + p := &AntigravityProvider{} + + messages := []Message{ + { + Role: "assistant", + ToolCalls: []ToolCall{{ + ID: "call_read_file_123", + Function: &FunctionCall{ + Name: "read_file", + Arguments: `{"path":"README.md"}`, + }, + }}, + }, + { + Role: "tool", + ToolCallID: "call_read_file_123", + Content: "ok", + }, + } + + req := p.buildRequest(messages, nil, "", nil) + if len(req.Contents) != 2 { + t.Fatalf("expected 2 contents, got %d", len(req.Contents)) + } + + modelPart := req.Contents[0].Parts[0] + if modelPart.FunctionCall == nil { + t.Fatal("expected functionCall in assistant message") + } + if modelPart.FunctionCall.Name != "read_file" { + t.Fatalf("expected functionCall name read_file, got %q", modelPart.FunctionCall.Name) + } + if got := modelPart.FunctionCall.Args["path"]; got != "README.md" { + t.Fatalf("expected functionCall args[path] to be README.md, got %v", got) + } + + toolPart := req.Contents[1].Parts[0] + if toolPart.FunctionResponse == nil { + t.Fatal("expected functionResponse in tool message") + } + if toolPart.FunctionResponse.Name != "read_file" { + t.Fatalf("expected functionResponse name read_file, got %q", toolPart.FunctionResponse.Name) + } +} + +func TestResolveToolResponseNameInfersNameFromGeneratedCallID(t *testing.T) { + got := resolveToolResponseName("call_search_docs_999", map[string]string{}) + if got != "search_docs" { + t.Fatalf("expected inferred tool name search_docs, got %q", got) + } +} diff --git a/pkg/providers/claude_cli_provider.go b/pkg/providers/claude_cli_provider.go index 242126aa2..74ec33b98 100644 --- a/pkg/providers/claude_cli_provider.go +++ b/pkg/providers/claude_cli_provider.go @@ -24,7 +24,9 @@ func NewClaudeCliProvider(workspace string) *ClaudeCliProvider { } // Chat implements LLMProvider.Chat by executing the claude CLI. 
-func (p *ClaudeCliProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { +func (p *ClaudeCliProvider) Chat( + ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any, +) (*LLMResponse, error) { systemPrompt := p.buildSystemPrompt(messages, tools) prompt := p.messagesToPrompt(messages) @@ -111,7 +113,9 @@ func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string { sb.WriteString("## Available Tools\n\n") sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n") sb.WriteString("```json\n") - sb.WriteString(`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`) + sb.WriteString( + `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`, + ) sb.WriteString("\n```\n\n") sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n") sb.WriteString("### Tool Definitions:\n\n") @@ -171,68 +175,14 @@ func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse, }, nil } -// extractToolCalls parses tool call JSON from the response text. +// extractToolCalls delegates to the shared extractToolCallsFromText function. 
func (p *ClaudeCliProvider) extractToolCalls(text string) []ToolCall { - start := strings.Index(text, `{"tool_calls"`) - if start == -1 { - return nil - } - - end := findMatchingBrace(text, start) - if end == start { - return nil - } - - jsonStr := text[start:end] - - var wrapper struct { - ToolCalls []struct { - ID string `json:"id"` - Type string `json:"type"` - Function struct { - Name string `json:"name"` - Arguments string `json:"arguments"` - } `json:"function"` - } `json:"tool_calls"` - } - - if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil { - return nil - } - - var result []ToolCall - for _, tc := range wrapper.ToolCalls { - var args map[string]interface{} - json.Unmarshal([]byte(tc.Function.Arguments), &args) - - result = append(result, ToolCall{ - ID: tc.ID, - Type: tc.Type, - Name: tc.Function.Name, - Arguments: args, - Function: &FunctionCall{ - Name: tc.Function.Name, - Arguments: tc.Function.Arguments, - }, - }) - } - - return result + return extractToolCallsFromText(text) } -// stripToolCallsJSON removes tool call JSON from response text. +// stripToolCallsJSON delegates to the shared stripToolCallsFromText function. func (p *ClaudeCliProvider) stripToolCallsJSON(text string) string { - start := strings.Index(text, `{"tool_calls"`) - if start == -1 { - return text - } - - end := findMatchingBrace(text, start) - if end == start { - return text - } - - return strings.TrimSpace(text[:start] + text[end:]) + return stripToolCallsFromText(text) } // findMatchingBrace finds the index after the closing brace matching the opening brace at pos. @@ -254,22 +204,22 @@ func findMatchingBrace(text string, pos int) int { // claudeCliJSONResponse represents the JSON output from the claude CLI. // Matches the real claude CLI v2.x output format. 
type claudeCliJSONResponse struct { - Type string `json:"type"` - Subtype string `json:"subtype"` - IsError bool `json:"is_error"` - Result string `json:"result"` - SessionID string `json:"session_id"` - TotalCostUSD float64 `json:"total_cost_usd"` - DurationMS int `json:"duration_ms"` - DurationAPI int `json:"duration_api_ms"` - NumTurns int `json:"num_turns"` - Usage claudeCliUsageInfo `json:"usage"` + Type string `json:"type"` + Subtype string `json:"subtype"` + IsError bool `json:"is_error"` + Result string `json:"result"` + SessionID string `json:"session_id"` + TotalCostUSD float64 `json:"total_cost_usd"` + DurationMS int `json:"duration_ms"` + DurationAPI int `json:"duration_api_ms"` + NumTurns int `json:"num_turns"` + Usage claudeCliUsageInfo `json:"usage"` } // claudeCliUsageInfo represents token usage from the claude CLI response. type claudeCliUsageInfo struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CacheCreationInputTokens int `json:"cache_creation_input_tokens"` - CacheReadInputTokens int `json:"cache_read_input_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` } diff --git a/pkg/providers/claude_cli_provider_integration_test.go b/pkg/providers/claude_cli_provider_integration_test.go index 9d1131ac4..f6e0d787a 100644 --- a/pkg/providers/claude_cli_provider_integration_test.go +++ b/pkg/providers/claude_cli_provider_integration_test.go @@ -28,7 +28,6 @@ func TestIntegration_RealClaudeCLI(t *testing.T) { resp, err := p.Chat(ctx, []Message{ {Role: "user", Content: "Respond with only the word 'pong'. Nothing else."}, }, nil, "", nil) - if err != nil { t.Fatalf("Chat() with real CLI error = %v", err) } @@ -75,7 +74,6 @@ func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) { {Role: "system", Content: "You are a calculator. 
Only respond with numbers. No text."}, {Role: "user", Content: "What is 2+2?"}, }, nil, "", nil) - if err != nil { t.Fatalf("Chat() error = %v", err) } diff --git a/pkg/providers/claude_cli_provider_test.go b/pkg/providers/claude_cli_provider_test.go index 4d75e60e2..3a3cafaca 100644 --- a/pkg/providers/claude_cli_provider_test.go +++ b/pkg/providers/claude_cli_provider_test.go @@ -30,12 +30,12 @@ func createMockCLI(t *testing.T, stdout, stderr string, exitCode int) string { dir := t.TempDir() if stdout != "" { - if err := os.WriteFile(filepath.Join(dir, "stdout.txt"), []byte(stdout), 0644); err != nil { + if err := os.WriteFile(filepath.Join(dir, "stdout.txt"), []byte(stdout), 0o644); err != nil { t.Fatal(err) } } if stderr != "" { - if err := os.WriteFile(filepath.Join(dir, "stderr.txt"), []byte(stderr), 0644); err != nil { + if err := os.WriteFile(filepath.Join(dir, "stderr.txt"), []byte(stderr), 0o644); err != nil { t.Fatal(err) } } @@ -51,7 +51,7 @@ func createMockCLI(t *testing.T, stdout, stderr string, exitCode int) string { sb.WriteString(fmt.Sprintf("exit %d\n", exitCode)) script := filepath.Join(dir, "claude") - if err := os.WriteFile(script, []byte(sb.String()), 0755); err != nil { + if err := os.WriteFile(script, []byte(sb.String()), 0o755); err != nil { t.Fatal(err) } return script @@ -67,7 +67,7 @@ func createSlowMockCLI(t *testing.T, sleepSeconds int) string { dir := t.TempDir() script := filepath.Join(dir, "claude") content := fmt.Sprintf("#!/bin/sh\nsleep %d\necho '{\"type\":\"result\",\"result\":\"late\"}'\n", sleepSeconds) - if err := os.WriteFile(script, []byte(content), 0755); err != nil { + if err := os.WriteFile(script, []byte(content), 0o755); err != nil { t.Fatal(err) } return script @@ -88,7 +88,7 @@ cat <<'EOFMOCK' {"type":"result","result":"ok","session_id":"test"} EOFMOCK `, argsFile) - if err := os.WriteFile(script, []byte(content), 0755); err != nil { + if err := os.WriteFile(script, []byte(content), 0o755); err != nil { t.Fatal(err) 
} return script @@ -137,7 +137,6 @@ func TestChat_Success(t *testing.T) { resp, err := p.Chat(context.Background(), []Message{ {Role: "user", Content: "Hello"}, }, nil, "", nil) - if err != nil { t.Fatalf("Chat() error = %v", err) } @@ -193,7 +192,6 @@ func TestChat_WithToolCallsInResponse(t *testing.T) { resp, err := p.Chat(context.Background(), []Message{ {Role: "user", Content: "What's the weather?"}, }, nil, "", nil) - if err != nil { t.Fatalf("Chat() error = %v", err) } @@ -336,7 +334,7 @@ func TestChat_PassesModelFlag(t *testing.T) { _, err := p.Chat(context.Background(), []Message{ {Role: "user", Content: "Hi"}, - }, nil, "claude-sonnet-4-5-20250929", nil) + }, nil, "claude-sonnet-4.6", nil) if err != nil { t.Fatalf("Chat() error = %v", err) } @@ -346,7 +344,7 @@ func TestChat_PassesModelFlag(t *testing.T) { if !strings.Contains(args, "--model") { t.Errorf("CLI args missing --model, got: %s", args) } - if !strings.Contains(args, "claude-sonnet-4-5-20250929") { + if !strings.Contains(args, "claude-sonnet-4.6") { t.Errorf("CLI args missing model name, got: %s", args) } } @@ -403,7 +401,6 @@ func TestChat_EmptyWorkspaceDoesNotSetDir(t *testing.T) { resp, err := p.Chat(context.Background(), []Message{ {Role: "user", Content: "Hello"}, }, nil, "", nil) - if err != nil { t.Fatalf("Chat() with empty workspace error = %v", err) } @@ -416,10 +413,12 @@ func TestChat_EmptyWorkspaceDoesNotSetDir(t *testing.T) { func TestCreateProvider_ClaudeCli(t *testing.T) { cfg := config.DefaultConfig() - cfg.Agents.Defaults.Provider = "claude-cli" - cfg.Agents.Defaults.Workspace = "/test/ws" + cfg.ModelList = []config.ModelConfig{ + {ModelName: "claude-sonnet-4.6", Model: "claude-cli/claude-sonnet-4.6", Workspace: "/test/ws"}, + } + cfg.Agents.Defaults.Model = "claude-sonnet-4.6" - provider, err := CreateProvider(cfg) + provider, _, err := CreateProvider(cfg) if err != nil { t.Fatalf("CreateProvider(claude-cli) error = %v", err) } @@ -435,9 +434,12 @@ func 
TestCreateProvider_ClaudeCli(t *testing.T) { func TestCreateProvider_ClaudeCode(t *testing.T) { cfg := config.DefaultConfig() - cfg.Agents.Defaults.Provider = "claude-code" + cfg.ModelList = []config.ModelConfig{ + {ModelName: "claude-code", Model: "claude-cli/claude-code"}, + } + cfg.Agents.Defaults.Model = "claude-code" - provider, err := CreateProvider(cfg) + provider, _, err := CreateProvider(cfg) if err != nil { t.Fatalf("CreateProvider(claude-code) error = %v", err) } @@ -448,9 +450,12 @@ func TestCreateProvider_ClaudeCode(t *testing.T) { func TestCreateProvider_ClaudeCodec(t *testing.T) { cfg := config.DefaultConfig() - cfg.Agents.Defaults.Provider = "claudecode" + cfg.ModelList = []config.ModelConfig{ + {ModelName: "claudecode", Model: "claude-cli/claudecode"}, + } + cfg.Agents.Defaults.Model = "claudecode" - provider, err := CreateProvider(cfg) + provider, _, err := CreateProvider(cfg) if err != nil { t.Fatalf("CreateProvider(claudecode) error = %v", err) } @@ -461,10 +466,13 @@ func TestCreateProvider_ClaudeCodec(t *testing.T) { func TestCreateProvider_ClaudeCliDefaultWorkspace(t *testing.T) { cfg := config.DefaultConfig() - cfg.Agents.Defaults.Provider = "claude-cli" + cfg.ModelList = []config.ModelConfig{ + {ModelName: "claude-cli", Model: "claude-cli/claude-sonnet"}, + } + cfg.Agents.Defaults.Model = "claude-cli" cfg.Agents.Defaults.Workspace = "" - provider, err := CreateProvider(cfg) + provider, _, err := CreateProvider(cfg) if err != nil { t.Fatalf("CreateProvider error = %v", err) } @@ -611,10 +619,10 @@ func TestBuildSystemPrompt_WithTools(t *testing.T) { Function: ToolFunctionDefinition{ Name: "get_weather", Description: "Get weather for a location", - Parameters: map[string]interface{}{ + Parameters: map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "location": map[string]interface{}{"type": "string"}, + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, }, }, }, @@ -967,9 +975,9 @@ func 
TestFindMatchingBrace(t *testing.T) { {`{"a":1}`, 0, 7}, {`{"a":{"b":2}}`, 0, 13}, {`text {"a":1} more`, 5, 12}, - {`{unclosed`, 0, 0}, // no match returns pos - {`{}`, 0, 2}, // empty object - {`{{{}}}`, 0, 6}, // deeply nested + {`{unclosed`, 0, 0}, // no match returns pos + {`{}`, 0, 2}, // empty object + {`{{{}}}`, 0, 6}, // deeply nested {`{"a":"b{c}d"}`, 0, 13}, // braces in strings (simplified matcher) } for _, tt := range tests { diff --git a/pkg/providers/claude_provider.go b/pkg/providers/claude_provider.go index ae6aca96d..60639ca18 100644 --- a/pkg/providers/claude_provider.go +++ b/pkg/providers/claude_provider.go @@ -2,200 +2,62 @@ package providers import ( "context" - "encoding/json" "fmt" - "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/option" - "github.com/sipeed/picoclaw/pkg/auth" + anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic" ) type ClaudeProvider struct { - client *anthropic.Client - tokenSource func() (string, error) + delegate *anthropicprovider.Provider } func NewClaudeProvider(token string) *ClaudeProvider { - client := anthropic.NewClient( - option.WithAuthToken(token), - option.WithBaseURL("https://api.anthropic.com"), - ) - return &ClaudeProvider{client: &client} -} - -func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider { - p := NewClaudeProvider(token) - p.tokenSource = tokenSource - return p -} - -func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - var opts []option.RequestOption - if p.tokenSource != nil { - tok, err := p.tokenSource() - if err != nil { - return nil, fmt.Errorf("refreshing token: %w", err) - } - opts = append(opts, option.WithAuthToken(tok)) - } - - params, err := buildClaudeParams(messages, tools, model, options) - if err != nil { - return nil, err + return &ClaudeProvider{ + delegate: 
anthropicprovider.NewProvider(token), } - - resp, err := p.client.Messages.New(ctx, params, opts...) - if err != nil { - return nil, fmt.Errorf("claude API call: %w", err) - } - - return parseClaudeResponse(resp), nil } -func (p *ClaudeProvider) GetDefaultModel() string { - return "claude-sonnet-4-5-20250929" -} - -func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { - var system []anthropic.TextBlockParam - var anthropicMessages []anthropic.MessageParam - - for _, msg := range messages { - switch msg.Role { - case "system": - system = append(system, anthropic.TextBlockParam{Text: msg.Content}) - case "user": - if msg.ToolCallID != "" { - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), - ) - } else { - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), - ) - } - case "assistant": - if len(msg.ToolCalls) > 0 { - var blocks []anthropic.ContentBlockParamUnion - if msg.Content != "" { - blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) - } - for _, tc := range msg.ToolCalls { - blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) - } - anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) - } else { - anthropicMessages = append(anthropicMessages, - anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), - ) - } - case "tool": - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), - ) - } - } - - maxTokens := int64(4096) - if mt, ok := options["max_tokens"].(int); ok { - maxTokens = int64(mt) - } - - params := anthropic.MessageNewParams{ - Model: anthropic.Model(model), - Messages: anthropicMessages, - MaxTokens: maxTokens, - } - - if len(system) > 0 { - 
params.System = system - } - - if temp, ok := options["temperature"].(float64); ok { - params.Temperature = anthropic.Float(temp) +func NewClaudeProviderWithBaseURL(token, apiBase string) *ClaudeProvider { + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithBaseURL(token, apiBase), } +} - if len(tools) > 0 { - params.Tools = translateToolsForClaude(tools) +func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider { + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithTokenSource(token, tokenSource), } - - return params, nil } -func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam { - result := make([]anthropic.ToolUnionParam, 0, len(tools)) - for _, t := range tools { - tool := anthropic.ToolParam{ - Name: t.Function.Name, - InputSchema: anthropic.ToolInputSchemaParam{ - Properties: t.Function.Parameters["properties"], - }, - } - if desc := t.Function.Description; desc != "" { - tool.Description = anthropic.String(desc) - } - if req, ok := t.Function.Parameters["required"].([]interface{}); ok { - required := make([]string, 0, len(req)) - for _, r := range req { - if s, ok := r.(string); ok { - required = append(required, s) - } - } - tool.InputSchema.Required = required - } - result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) +func NewClaudeProviderWithTokenSourceAndBaseURL( + token string, tokenSource func() (string, error), apiBase string, +) *ClaudeProvider { + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithTokenSourceAndBaseURL(token, tokenSource, apiBase), } - return result } -func parseClaudeResponse(resp *anthropic.Message) *LLMResponse { - var content string - var toolCalls []ToolCall - - for _, block := range resp.Content { - switch block.Type { - case "text": - tb := block.AsText() - content += tb.Text - case "tool_use": - tu := block.AsToolUse() - var args map[string]interface{} - if err := 
json.Unmarshal(tu.Input, &args); err != nil { - args = map[string]interface{}{"raw": string(tu.Input)} - } - toolCalls = append(toolCalls, ToolCall{ - ID: tu.ID, - Name: tu.Name, - Arguments: args, - }) - } - } +func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *ClaudeProvider { + return &ClaudeProvider{delegate: delegate} +} - finishReason := "stop" - switch resp.StopReason { - case anthropic.StopReasonToolUse: - finishReason = "tool_calls" - case anthropic.StopReasonMaxTokens: - finishReason = "length" - case anthropic.StopReasonEndTurn: - finishReason = "stop" +func (p *ClaudeProvider) Chat( + ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any, +) (*LLMResponse, error) { + resp, err := p.delegate.Chat(ctx, messages, tools, model, options) + if err != nil { + return nil, err } + return resp, nil +} - return &LLMResponse{ - Content: content, - ToolCalls: toolCalls, - FinishReason: finishReason, - Usage: &UsageInfo{ - PromptTokens: int(resp.Usage.InputTokens), - CompletionTokens: int(resp.Usage.OutputTokens), - TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens), - }, - } +func (p *ClaudeProvider) GetDefaultModel() string { + return p.delegate.GetDefaultModel() } func createClaudeTokenSource() func() (string, error) { return func() (string, error) { - cred, err := auth.GetCredential("anthropic") + cred, err := getCredential("anthropic") if err != nil { return "", fmt.Errorf("loading auth credentials: %w", err) } diff --git a/pkg/providers/claude_provider_test.go b/pkg/providers/claude_provider_test.go index bbad2d269..98e07bb80 100644 --- a/pkg/providers/claude_provider_test.go +++ b/pkg/providers/claude_provider_test.go @@ -8,139 +8,9 @@ import ( "github.com/anthropics/anthropic-sdk-go" anthropicoption "github.com/anthropics/anthropic-sdk-go/option" -) - -func TestBuildClaudeParams_BasicMessage(t *testing.T) { - messages := []Message{ - {Role: "user", Content: "Hello"}, - } - 
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{ - "max_tokens": 1024, - }) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if string(params.Model) != "claude-sonnet-4-5-20250929" { - t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929") - } - if params.MaxTokens != 1024 { - t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) - } - if len(params.Messages) != 1 { - t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) - } -} - -func TestBuildClaudeParams_SystemMessage(t *testing.T) { - messages := []Message{ - {Role: "system", Content: "You are helpful"}, - {Role: "user", Content: "Hi"}, - } - params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.System) != 1 { - t.Fatalf("len(System) = %d, want 1", len(params.System)) - } - if params.System[0].Text != "You are helpful" { - t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful") - } - if len(params.Messages) != 1 { - t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) - } -} -func TestBuildClaudeParams_ToolCallMessage(t *testing.T) { - messages := []Message{ - {Role: "user", Content: "What's the weather?"}, - { - Role: "assistant", - Content: "", - ToolCalls: []ToolCall{ - { - ID: "call_1", - Name: "get_weather", - Arguments: map[string]interface{}{"city": "SF"}, - }, - }, - }, - {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, - } - params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.Messages) != 3 { - t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) - } -} - -func TestBuildClaudeParams_WithTools(t *testing.T) { - tools := []ToolDefinition{ - { - Type: "function", - 
Function: ToolFunctionDefinition{ - Name: "get_weather", - Description: "Get weather for a city", - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "city": map[string]interface{}{"type": "string"}, - }, - "required": []interface{}{"city"}, - }, - }, - }, - } - params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.Tools) != 1 { - t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) - } -} - -func TestParseClaudeResponse_TextOnly(t *testing.T) { - resp := &anthropic.Message{ - Content: []anthropic.ContentBlockUnion{}, - Usage: anthropic.Usage{ - InputTokens: 10, - OutputTokens: 20, - }, - } - result := parseClaudeResponse(resp) - if result.Usage.PromptTokens != 10 { - t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) - } - if result.Usage.CompletionTokens != 20 { - t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens) - } - if result.FinishReason != "stop" { - t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") - } -} - -func TestParseClaudeResponse_StopReasons(t *testing.T) { - tests := []struct { - stopReason anthropic.StopReason - want string - }{ - {anthropic.StopReasonEndTurn, "stop"}, - {anthropic.StopReasonMaxTokens, "length"}, - {anthropic.StopReasonToolUse, "tool_calls"}, - } - for _, tt := range tests { - resp := &anthropic.Message{ - StopReason: tt.stopReason, - } - result := parseClaudeResponse(resp) - if result.FinishReason != tt.want { - t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) - } - } -} + anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic" +) func TestClaudeProvider_ChatRoundTrip(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -153,19 +23,19 
@@ func TestClaudeProvider_ChatRoundTrip(t *testing.T) { return } - var reqBody map[string]interface{} + var reqBody map[string]any json.NewDecoder(r.Body).Decode(&reqBody) - resp := map[string]interface{}{ + resp := map[string]any{ "id": "msg_test", "type": "message", "role": "assistant", "model": reqBody["model"], "stop_reason": "end_turn", - "content": []map[string]interface{}{ + "content": []map[string]any{ {"type": "text", "text": "Hello! How can I help you?"}, }, - "usage": map[string]interface{}{ + "usage": map[string]any{ "input_tokens": 15, "output_tokens": 8, }, @@ -175,11 +45,11 @@ func TestClaudeProvider_ChatRoundTrip(t *testing.T) { })) defer server.Close() - provider := NewClaudeProvider("test-token") - provider.client = createAnthropicTestClient(server.URL, "test-token") + delegate := anthropicprovider.NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token")) + provider := newClaudeProviderWithDelegate(delegate) messages := []Message{{Role: "user", Content: "Hello"}} - resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024}) + resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4.6", map[string]any{"max_tokens": 1024}) if err != nil { t.Fatalf("Chat() error: %v", err) } @@ -196,8 +66,8 @@ func TestClaudeProvider_ChatRoundTrip(t *testing.T) { func TestClaudeProvider_GetDefaultModel(t *testing.T) { p := NewClaudeProvider("test-token") - if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" { - t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929") + if got := p.GetDefaultModel(); got != "claude-sonnet-4.6" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4.6") } } diff --git a/pkg/providers/codex_cli_credentials.go b/pkg/providers/codex_cli_credentials.go new file mode 100644 index 000000000..40f3ee2a1 --- /dev/null +++ b/pkg/providers/codex_cli_credentials.go @@ -0,0 +1,81 @@ +package 
providers + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" +) + +// CodexCliAuth represents the ~/.codex/auth.json file structure. +type CodexCliAuth struct { + Tokens struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + AccountID string `json:"account_id"` + } `json:"tokens"` +} + +// ReadCodexCliCredentials reads OAuth tokens from the Codex CLI's auth.json file. +// Expiry is estimated as file modification time + 1 hour (same approach as moltbot). +func ReadCodexCliCredentials() (accessToken, accountID string, expiresAt time.Time, err error) { + authPath, err := resolveCodexAuthPath() + if err != nil { + return "", "", time.Time{}, err + } + + data, err := os.ReadFile(authPath) + if err != nil { + return "", "", time.Time{}, fmt.Errorf("reading %s: %w", authPath, err) + } + + var auth CodexCliAuth + if err = json.Unmarshal(data, &auth); err != nil { + return "", "", time.Time{}, fmt.Errorf("parsing %s: %w", authPath, err) + } + + if auth.Tokens.AccessToken == "" { + return "", "", time.Time{}, fmt.Errorf("no access_token in %s", authPath) + } + + stat, err := os.Stat(authPath) + if err != nil { + expiresAt = time.Now().Add(time.Hour) + } else { + expiresAt = stat.ModTime().Add(time.Hour) + } + + return auth.Tokens.AccessToken, auth.Tokens.AccountID, expiresAt, nil +} + +// CreateCodexCliTokenSource creates a token source that reads from ~/.codex/auth.json. +// This allows the existing CodexProvider to reuse Codex CLI credentials. +func CreateCodexCliTokenSource() func() (string, string, error) { + return func() (string, string, error) { + token, accountID, expiresAt, err := ReadCodexCliCredentials() + if err != nil { + return "", "", fmt.Errorf("reading codex cli credentials: %w", err) + } + + if time.Now().After(expiresAt) { + return "", "", fmt.Errorf( + "codex cli credentials expired (auth.json last modified > 1h ago). 
Run: codex login", + ) + } + + return token, accountID, nil + } +} + +func resolveCodexAuthPath() (string, error) { + codexHome := os.Getenv("CODEX_HOME") + if codexHome == "" { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("getting home dir: %w", err) + } + codexHome = filepath.Join(home, ".codex") + } + return filepath.Join(codexHome, "auth.json"), nil +} diff --git a/pkg/providers/codex_cli_credentials_test.go b/pkg/providers/codex_cli_credentials_test.go new file mode 100644 index 000000000..43b21700a --- /dev/null +++ b/pkg/providers/codex_cli_credentials_test.go @@ -0,0 +1,181 @@ +package providers + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestReadCodexCliCredentials_Valid(t *testing.T) { + tmpDir := t.TempDir() + authPath := filepath.Join(tmpDir, "auth.json") + + authJSON := `{ + "tokens": { + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "account_id": "org-test123" + } + }` + if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil { + t.Fatal(err) + } + + t.Setenv("CODEX_HOME", tmpDir) + + token, accountID, expiresAt, err := ReadCodexCliCredentials() + if err != nil { + t.Fatalf("ReadCodexCliCredentials() error: %v", err) + } + if token != "test-access-token" { + t.Errorf("token = %q, want %q", token, "test-access-token") + } + if accountID != "org-test123" { + t.Errorf("accountID = %q, want %q", accountID, "org-test123") + } + // Expiry should be within ~1 hour from now (file was just written) + if expiresAt.Before(time.Now()) { + t.Errorf("expiresAt = %v, should be in the future", expiresAt) + } + if expiresAt.After(time.Now().Add(2 * time.Hour)) { + t.Errorf("expiresAt = %v, should be within ~1 hour", expiresAt) + } +} + +func TestReadCodexCliCredentials_MissingFile(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("CODEX_HOME", tmpDir) + + _, _, _, err := ReadCodexCliCredentials() + if err == nil { + t.Fatal("expected error for missing auth.json") + } 
+} + +func TestReadCodexCliCredentials_EmptyToken(t *testing.T) { + tmpDir := t.TempDir() + authPath := filepath.Join(tmpDir, "auth.json") + + authJSON := `{"tokens": {"access_token": "", "refresh_token": "r", "account_id": "a"}}` + if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil { + t.Fatal(err) + } + + t.Setenv("CODEX_HOME", tmpDir) + + _, _, _, err := ReadCodexCliCredentials() + if err == nil { + t.Fatal("expected error for empty access_token") + } +} + +func TestReadCodexCliCredentials_InvalidJSON(t *testing.T) { + tmpDir := t.TempDir() + authPath := filepath.Join(tmpDir, "auth.json") + + if err := os.WriteFile(authPath, []byte("not json"), 0o600); err != nil { + t.Fatal(err) + } + + t.Setenv("CODEX_HOME", tmpDir) + + _, _, _, err := ReadCodexCliCredentials() + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +func TestReadCodexCliCredentials_NoAccountID(t *testing.T) { + tmpDir := t.TempDir() + authPath := filepath.Join(tmpDir, "auth.json") + + authJSON := `{"tokens": {"access_token": "tok123", "refresh_token": "ref456"}}` + if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil { + t.Fatal(err) + } + + t.Setenv("CODEX_HOME", tmpDir) + + token, accountID, _, err := ReadCodexCliCredentials() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token != "tok123" { + t.Errorf("token = %q, want %q", token, "tok123") + } + if accountID != "" { + t.Errorf("accountID = %q, want empty", accountID) + } +} + +func TestReadCodexCliCredentials_CodexHomeEnv(t *testing.T) { + tmpDir := t.TempDir() + customDir := filepath.Join(tmpDir, "custom-codex") + if err := os.MkdirAll(customDir, 0o755); err != nil { + t.Fatal(err) + } + + authJSON := `{"tokens": {"access_token": "custom-token", "refresh_token": "r"}}` + if err := os.WriteFile(filepath.Join(customDir, "auth.json"), []byte(authJSON), 0o600); err != nil { + t.Fatal(err) + } + + t.Setenv("CODEX_HOME", customDir) + + token, _, _, err := 
ReadCodexCliCredentials() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token != "custom-token" { + t.Errorf("token = %q, want %q", token, "custom-token") + } +} + +func TestCreateCodexCliTokenSource_Valid(t *testing.T) { + tmpDir := t.TempDir() + authPath := filepath.Join(tmpDir, "auth.json") + + authJSON := `{"tokens": {"access_token": "fresh-token", "refresh_token": "r", "account_id": "acc"}}` + if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil { + t.Fatal(err) + } + + t.Setenv("CODEX_HOME", tmpDir) + + source := CreateCodexCliTokenSource() + token, accountID, err := source() + if err != nil { + t.Fatalf("token source error: %v", err) + } + if token != "fresh-token" { + t.Errorf("token = %q, want %q", token, "fresh-token") + } + if accountID != "acc" { + t.Errorf("accountID = %q, want %q", accountID, "acc") + } +} + +func TestCreateCodexCliTokenSource_Expired(t *testing.T) { + tmpDir := t.TempDir() + authPath := filepath.Join(tmpDir, "auth.json") + + authJSON := `{"tokens": {"access_token": "old-token", "refresh_token": "r"}}` + if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil { + t.Fatal(err) + } + + // Set file modification time to 2 hours ago + oldTime := time.Now().Add(-2 * time.Hour) + if err := os.Chtimes(authPath, oldTime, oldTime); err != nil { + t.Fatal(err) + } + + t.Setenv("CODEX_HOME", tmpDir) + + source := CreateCodexCliTokenSource() + _, _, err := source() + if err == nil { + t.Fatal("expected error for expired credentials") + } +} diff --git a/pkg/providers/codex_cli_provider.go b/pkg/providers/codex_cli_provider.go new file mode 100644 index 000000000..4c783ece5 --- /dev/null +++ b/pkg/providers/codex_cli_provider.go @@ -0,0 +1,255 @@ +package providers + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "os/exec" + "strings" +) + +// CodexCliProvider implements LLMProvider by wrapping the codex CLI as a subprocess. 
+type CodexCliProvider struct { + command string + workspace string +} + +// NewCodexCliProvider creates a new Codex CLI provider. +func NewCodexCliProvider(workspace string) *CodexCliProvider { + return &CodexCliProvider{ + command: "codex", + workspace: workspace, + } +} + +// Chat implements LLMProvider.Chat by executing the codex CLI in non-interactive mode. +func (p *CodexCliProvider) Chat( + ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any, +) (*LLMResponse, error) { + if p.command == "" { + return nil, fmt.Errorf("codex command not configured") + } + + prompt := p.buildPrompt(messages, tools) + + args := []string{ + "exec", + "--json", + "--dangerously-bypass-approvals-and-sandbox", + "--skip-git-repo-check", + "--color", "never", + } + if model != "" && model != "codex-cli" { + args = append(args, "-m", model) + } + if p.workspace != "" { + args = append(args, "-C", p.workspace) + } + args = append(args, "-") // read prompt from stdin + + cmd := exec.CommandContext(ctx, p.command, args...) + cmd.Stdin = bytes.NewReader([]byte(prompt)) + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + + // Parse JSONL from stdout even if exit code is non-zero, + // because codex writes diagnostic noise to stderr (e.g. rollout errors) + // but still produces valid JSONL output. + if stdoutStr := stdout.String(); stdoutStr != "" { + resp, parseErr := p.parseJSONLEvents(stdoutStr) + if parseErr == nil && resp != nil && (resp.Content != "" || len(resp.ToolCalls) > 0) { + return resp, nil + } + } + + if err != nil { + if ctx.Err() == context.Canceled { + return nil, ctx.Err() + } + if stderrStr := stderr.String(); stderrStr != "" { + return nil, fmt.Errorf("codex cli error: %s", stderrStr) + } + return nil, fmt.Errorf("codex cli error: %w", err) + } + + return p.parseJSONLEvents(stdout.String()) +} + +// GetDefaultModel returns the default model identifier. 
+func (p *CodexCliProvider) GetDefaultModel() string { + return "codex-cli" +} + +// buildPrompt converts messages to a prompt string for the Codex CLI. +// System messages are prepended as instructions since Codex CLI has no --system-prompt flag. +func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinition) string { + var systemParts []string + var conversationParts []string + + for _, msg := range messages { + switch msg.Role { + case "system": + systemParts = append(systemParts, msg.Content) + case "user": + conversationParts = append(conversationParts, msg.Content) + case "assistant": + conversationParts = append(conversationParts, "Assistant: "+msg.Content) + case "tool": + conversationParts = append(conversationParts, + fmt.Sprintf("[Tool Result for %s]: %s", msg.ToolCallID, msg.Content)) + } + } + + var sb strings.Builder + + if len(systemParts) > 0 { + sb.WriteString("## System Instructions\n\n") + sb.WriteString(strings.Join(systemParts, "\n\n")) + sb.WriteString("\n\n## Task\n\n") + } + + if len(tools) > 0 { + sb.WriteString(p.buildToolsPrompt(tools)) + sb.WriteString("\n\n") + } + + // Simplify single user message (no prefix) + if len(conversationParts) == 1 && len(systemParts) == 0 && len(tools) == 0 { + return conversationParts[0] + } + + sb.WriteString(strings.Join(conversationParts, "\n")) + return sb.String() +} + +// buildToolsPrompt creates a tool definitions section for the prompt. 
+func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string { + var sb strings.Builder + + sb.WriteString("## Available Tools\n\n") + sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n") + sb.WriteString("```json\n") + sb.WriteString( + `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`, + ) + sb.WriteString("\n```\n\n") + sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n") + sb.WriteString("### Tool Definitions:\n\n") + + for _, tool := range tools { + if tool.Type != "function" { + continue + } + sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name)) + if tool.Function.Description != "" { + sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description)) + } + if len(tool.Function.Parameters) > 0 { + paramsJSON, _ := json.Marshal(tool.Function.Parameters) + sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON))) + } + sb.WriteString("\n") + } + + return sb.String() +} + +// codexEvent represents a single JSONL event from `codex exec --json`. 
+type codexEvent struct { + Type string `json:"type"` + ThreadID string `json:"thread_id,omitempty"` + Message string `json:"message,omitempty"` + Item *codexEventItem `json:"item,omitempty"` + Usage *codexUsage `json:"usage,omitempty"` + Error *codexEventErr `json:"error,omitempty"` +} + +type codexEventItem struct { + ID string `json:"id"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + Command string `json:"command,omitempty"` + Status string `json:"status,omitempty"` + ExitCode *int `json:"exit_code,omitempty"` + Output string `json:"output,omitempty"` +} + +type codexUsage struct { + InputTokens int `json:"input_tokens"` + CachedInputTokens int `json:"cached_input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type codexEventErr struct { + Message string `json:"message"` +} + +// parseJSONLEvents processes the JSONL output from codex exec --json. +func (p *CodexCliProvider) parseJSONLEvents(output string) (*LLMResponse, error) { + var contentParts []string + var usage *UsageInfo + var lastError string + + scanner := bufio.NewScanner(strings.NewReader(output)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + var event codexEvent + if err := json.Unmarshal([]byte(line), &event); err != nil { + continue // skip malformed lines + } + + switch event.Type { + case "item.completed": + if event.Item != nil && event.Item.Type == "agent_message" && event.Item.Text != "" { + contentParts = append(contentParts, event.Item.Text) + } + case "turn.completed": + if event.Usage != nil { + promptTokens := event.Usage.InputTokens + event.Usage.CachedInputTokens + usage = &UsageInfo{ + PromptTokens: promptTokens, + CompletionTokens: event.Usage.OutputTokens, + TotalTokens: promptTokens + event.Usage.OutputTokens, + } + } + case "error": + lastError = event.Message + case "turn.failed": + if event.Error != nil { + lastError = event.Error.Message + } + } + } + + if lastError != "" && 
len(contentParts) == 0 { + return nil, fmt.Errorf("codex cli: %s", lastError) + } + + content := strings.Join(contentParts, "\n") + + // Extract tool calls from response text (same pattern as ClaudeCliProvider) + toolCalls := extractToolCallsFromText(content) + + finishReason := "stop" + if len(toolCalls) > 0 { + finishReason = "tool_calls" + content = stripToolCallsFromText(content) + } + + return &LLMResponse{ + Content: strings.TrimSpace(content), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + }, nil +} diff --git a/pkg/providers/codex_cli_provider_integration_test.go b/pkg/providers/codex_cli_provider_integration_test.go new file mode 100644 index 000000000..17a8305ad --- /dev/null +++ b/pkg/providers/codex_cli_provider_integration_test.go @@ -0,0 +1,117 @@ +//go:build integration + +package providers + +import ( + "context" + exec "os/exec" + "strings" + "testing" + "time" +) + +// TestIntegration_RealCodexCLI tests the CodexCliProvider with a real codex CLI. +// Run with: go test -tags=integration ./pkg/providers/... +func TestIntegration_RealCodexCLI(t *testing.T) { + path, err := exec.LookPath("codex") + if err != nil { + t.Skip("codex CLI not found in PATH, skipping integration test") + } + t.Logf("Using codex CLI at: %s", path) + + p := NewCodexCliProvider(t.TempDir()) + + ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) + defer cancel() + + resp, err := p.Chat(ctx, []Message{ + {Role: "user", Content: "Respond with only the word 'pong'. 
Nothing else."}, + }, nil, "", nil) + if err != nil { + t.Fatalf("Chat() with real CLI error = %v", err) + } + + if resp.Content == "" { + t.Error("Content is empty") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage != nil { + t.Logf("Usage: prompt=%d, completion=%d, total=%d", + resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens) + } + + t.Logf("Response content: %q", resp.Content) + + if !strings.Contains(strings.ToLower(resp.Content), "pong") { + t.Errorf("Content = %q, expected to contain 'pong'", resp.Content) + } +} + +func TestIntegration_RealCodexCLI_WithSystemPrompt(t *testing.T) { + if _, err := exec.LookPath("codex"); err != nil { + t.Skip("codex CLI not found in PATH") + } + + p := NewCodexCliProvider(t.TempDir()) + + ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) + defer cancel() + + resp, err := p.Chat(ctx, []Message{ + {Role: "system", Content: "You are a calculator. Only respond with numbers. 
No text."}, + {Role: "user", Content: "What is 2+2?"}, + }, nil, "", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + t.Logf("Response: %q", resp.Content) + + if !strings.Contains(resp.Content, "4") { + t.Errorf("Content = %q, expected to contain '4'", resp.Content) + } +} + +func TestIntegration_RealCodexCLI_ParsesRealJSONL(t *testing.T) { + if _, err := exec.LookPath("codex"); err != nil { + t.Skip("codex CLI not found in PATH") + } + + // Run codex directly and verify our parser handles real output + cmd := exec.Command("codex", "exec", + "--json", + "--dangerously-bypass-approvals-and-sandbox", + "--skip-git-repo-check", + "--color", "never", + "-C", t.TempDir(), + "-") + cmd.Stdin = strings.NewReader("Say hi") + + output, err := cmd.Output() + if err != nil { + // codex may write diagnostic noise to stderr but still produce valid output + if len(output) == 0 { + t.Fatalf("codex CLI failed: %v", err) + } + } + + t.Logf("Raw CLI output (first 500 chars): %s", string(output[:min(len(output), 500)])) + + // Verify our parser can handle real output + p := NewCodexCliProvider("") + resp, err := p.parseJSONLEvents(string(output)) + if err != nil { + t.Fatalf("parseJSONLEvents() failed on real CLI output: %v", err) + } + + if resp.Content == "" { + t.Error("parsed Content is empty") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want stop", resp.FinishReason) + } + + t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage) +} diff --git a/pkg/providers/codex_cli_provider_test.go b/pkg/providers/codex_cli_provider_test.go new file mode 100644 index 000000000..414e0844d --- /dev/null +++ b/pkg/providers/codex_cli_provider_test.go @@ -0,0 +1,585 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +// --- JSONL Event Parsing Tests --- + +func TestParseJSONLEvents_AgentMessage(t *testing.T) { + p := 
&CodexCliProvider{} + events := `{"type":"thread.started","thread_id":"abc-123"} +{"type":"turn.started"} +{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Hello from Codex!"}} +{"type":"turn.completed","usage":{"input_tokens":100,"cached_input_tokens":50,"output_tokens":20}}` + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("parseJSONLEvents() error: %v", err) + } + if resp.Content != "Hello from Codex!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello from Codex!") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage == nil { + t.Fatal("Usage should not be nil") + } + if resp.Usage.PromptTokens != 150 { + t.Errorf("PromptTokens = %d, want 150", resp.Usage.PromptTokens) + } + if resp.Usage.CompletionTokens != 20 { + t.Errorf("CompletionTokens = %d, want 20", resp.Usage.CompletionTokens) + } + if resp.Usage.TotalTokens != 170 { + t.Errorf("TotalTokens = %d, want 170", resp.Usage.TotalTokens) + } + if len(resp.ToolCalls) != 0 { + t.Errorf("ToolCalls should be empty, got %d", len(resp.ToolCalls)) + } +} + +func TestParseJSONLEvents_ToolCallExtraction(t *testing.T) { + p := &CodexCliProvider{} + toolCallText := `Let me read that file. 
+{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"/tmp/test.txt\"}"}}]}` + // Build valid JSONL by marshaling the event + item := codexEvent{ + Type: "item.completed", + Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText}, + } + itemJSON, _ := json.Marshal(item) + usageEvt := `{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":0,"output_tokens":20}}` + events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + usageEvt + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("parseJSONLEvents() error: %v", err) + } + if resp.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls") + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("ToolCalls count = %d, want 1", len(resp.ToolCalls)) + } + if resp.ToolCalls[0].Name != "read_file" { + t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file") + } + if resp.ToolCalls[0].ID != "call_1" { + t.Errorf("ToolCalls[0].ID = %q, want %q", resp.ToolCalls[0].ID, "call_1") + } + if resp.ToolCalls[0].Function.Arguments != `{"path":"/tmp/test.txt"}` { + t.Errorf("ToolCalls[0].Function.Arguments = %q", resp.ToolCalls[0].Function.Arguments) + } + // Content should have the tool call JSON stripped + if strings.Contains(resp.Content, "tool_calls") { + t.Errorf("Content should not contain tool_calls JSON, got: %q", resp.Content) + } +} + +func TestParseJSONLEvents_MultipleToolCalls(t *testing.T) { + p := &CodexCliProvider{} + toolCallText := `{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"a.txt\"}"}},{"id":"call_2","type":"function","function":{"name":"write_file","arguments":"{\"path\":\"b.txt\",\"content\":\"hello\"}"}}]}` + item := codexEvent{ + Type: "item.completed", + Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText}, + } + itemJSON, _ := 
json.Marshal(item) + events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + `{"type":"turn.completed"}` + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("parseJSONLEvents() error: %v", err) + } + if len(resp.ToolCalls) != 2 { + t.Fatalf("ToolCalls count = %d, want 2", len(resp.ToolCalls)) + } + if resp.ToolCalls[0].Name != "read_file" { + t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file") + } + if resp.ToolCalls[1].Name != "write_file" { + t.Errorf("ToolCalls[1].Name = %q, want %q", resp.ToolCalls[1].Name, "write_file") + } + if resp.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls") + } +} + +func TestParseJSONLEvents_MultipleMessages(t *testing.T) { + p := &CodexCliProvider{} + events := `{"type":"turn.started"} +{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"First part."}} +{"type":"item.completed","item":{"id":"item_2","type":"command_execution","command":"ls","status":"completed"}} +{"type":"item.completed","item":{"id":"item_3","type":"agent_message","text":"Second part."}} +{"type":"turn.completed"}` + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("parseJSONLEvents() error: %v", err) + } + if resp.Content != "First part.\nSecond part." 
{ + t.Errorf("Content = %q, want %q", resp.Content, "First part.\nSecond part.") + } +} + +func TestParseJSONLEvents_ErrorEvent(t *testing.T) { + p := &CodexCliProvider{} + events := `{"type":"thread.started","thread_id":"abc"} +{"type":"turn.started"} +{"type":"error","message":"token expired"} +{"type":"turn.failed","error":{"message":"token expired"}}` + + _, err := p.parseJSONLEvents(events) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "token expired") { + t.Errorf("error = %q, want to contain 'token expired'", err.Error()) + } +} + +func TestParseJSONLEvents_TurnFailed(t *testing.T) { + p := &CodexCliProvider{} + events := `{"type":"turn.started"} +{"type":"turn.failed","error":{"message":"rate limit exceeded"}}` + + _, err := p.parseJSONLEvents(events) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "rate limit exceeded") { + t.Errorf("error = %q, want to contain 'rate limit exceeded'", err.Error()) + } +} + +func TestParseJSONLEvents_ErrorWithContent(t *testing.T) { + p := &CodexCliProvider{} + // If there's an error but also content, return the content (partial success) + events := `{"type":"turn.started"} +{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Partial result."}} +{"type":"error","message":"connection reset"} +{"type":"turn.failed","error":{"message":"connection reset"}}` + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("should not error when content exists: %v", err) + } + if resp.Content != "Partial result." 
{ + t.Errorf("Content = %q, want %q", resp.Content, "Partial result.") + } +} + +func TestParseJSONLEvents_EmptyOutput(t *testing.T) { + p := &CodexCliProvider{} + resp, err := p.parseJSONLEvents("") + if err != nil { + t.Fatalf("empty output should not error: %v", err) + } + if resp.Content != "" { + t.Errorf("Content = %q, want empty", resp.Content) + } +} + +func TestParseJSONLEvents_MalformedLines(t *testing.T) { + p := &CodexCliProvider{} + events := `not json at all +{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Good line."}} +another bad line +{"type":"turn.completed","usage":{"input_tokens":10,"output_tokens":5}}` + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("should skip malformed lines: %v", err) + } + if resp.Content != "Good line." { + t.Errorf("Content = %q, want %q", resp.Content, "Good line.") + } + if resp.Usage == nil || resp.Usage.TotalTokens != 15 { + t.Errorf("Usage.TotalTokens = %v, want 15", resp.Usage) + } +} + +func TestParseJSONLEvents_CommandExecution(t *testing.T) { + p := &CodexCliProvider{} + events := `{"type":"turn.started"} +{"type":"item.started","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"in_progress"}} +{"type":"item.completed","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"completed","exit_code":0,"output":"file1.go\nfile2.go"}} +{"type":"item.completed","item":{"id":"item_2","type":"agent_message","text":"Found 2 files."}} +{"type":"turn.completed"}` + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("parseJSONLEvents() error: %v", err) + } + // command_execution items should be skipped; only agent_message text is returned + if resp.Content != "Found 2 files." 
{ + t.Errorf("Content = %q, want %q", resp.Content, "Found 2 files.") + } +} + +func TestParseJSONLEvents_NoUsage(t *testing.T) { + p := &CodexCliProvider{} + events := `{"type":"turn.started"} +{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"No usage info."}} +{"type":"turn.completed"}` + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("parseJSONLEvents() error: %v", err) + } + if resp.Usage != nil { + t.Errorf("Usage should be nil when turn.completed has no usage, got %+v", resp.Usage) + } +} + +// --- Prompt Building Tests --- + +func TestBuildPrompt_SystemAsInstructions(t *testing.T) { + p := &CodexCliProvider{} + messages := []Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "Hi there"}, + } + + prompt := p.buildPrompt(messages, nil) + + if !strings.Contains(prompt, "## System Instructions") { + t.Error("prompt should contain '## System Instructions'") + } + if !strings.Contains(prompt, "You are helpful.") { + t.Error("prompt should contain system content") + } + if !strings.Contains(prompt, "## Task") { + t.Error("prompt should contain '## Task'") + } + if !strings.Contains(prompt, "Hi there") { + t.Error("prompt should contain user message") + } +} + +func TestBuildPrompt_NoSystem(t *testing.T) { + p := &CodexCliProvider{} + messages := []Message{ + {Role: "user", Content: "Just a question"}, + } + + prompt := p.buildPrompt(messages, nil) + + if strings.Contains(prompt, "## System Instructions") { + t.Error("prompt should not contain system instructions header") + } + if prompt != "Just a question" { + t.Errorf("prompt = %q, want %q", prompt, "Just a question") + } +} + +func TestBuildPrompt_WithTools(t *testing.T) { + p := &CodexCliProvider{} + messages := []Message{ + {Role: "user", Content: "Get weather"}, + } + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get current weather", + 
Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "city": map[string]any{"type": "string"}, + }, + }, + }, + }, + } + + prompt := p.buildPrompt(messages, tools) + + if !strings.Contains(prompt, "## Available Tools") { + t.Error("prompt should contain tools section") + } + if !strings.Contains(prompt, "get_weather") { + t.Error("prompt should contain tool name") + } + if !strings.Contains(prompt, "Get current weather") { + t.Error("prompt should contain tool description") + } +} + +func TestBuildPrompt_MultipleMessages(t *testing.T) { + p := &CodexCliProvider{} + messages := []Message{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi! How can I help?"}, + {Role: "user", Content: "Tell me about Go"}, + } + + prompt := p.buildPrompt(messages, nil) + + if !strings.Contains(prompt, "Hello") { + t.Error("prompt should contain first user message") + } + if !strings.Contains(prompt, "Assistant: Hi! How can I help?") { + t.Error("prompt should contain assistant message with prefix") + } + if !strings.Contains(prompt, "Tell me about Go") { + t.Error("prompt should contain second user message") + } +} + +func TestBuildPrompt_ToolResults(t *testing.T) { + p := &CodexCliProvider{} + messages := []Message{ + {Role: "user", Content: "Weather?"}, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + + prompt := p.buildPrompt(messages, nil) + + if !strings.Contains(prompt, "[Tool Result for call_1]") { + t.Error("prompt should contain tool result") + } + if !strings.Contains(prompt, `{"temp": 72}`) { + t.Error("prompt should contain tool result content") + } +} + +func TestBuildPrompt_SystemAndTools(t *testing.T) { + p := &CodexCliProvider{} + messages := []Message{ + {Role: "system", Content: "Be concise."}, + {Role: "user", Content: "Do something"}, + } + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "my_tool", + Description: "A tool", + }, + }, + } + + prompt 
:= p.buildPrompt(messages, tools) + + // System instructions should come first + sysIdx := strings.Index(prompt, "## System Instructions") + toolIdx := strings.Index(prompt, "## Available Tools") + taskIdx := strings.Index(prompt, "## Task") + + if sysIdx == -1 || toolIdx == -1 || taskIdx == -1 { + t.Fatal("prompt should contain all sections") + } + if sysIdx >= taskIdx { + t.Error("system instructions should come before task") + } + if taskIdx >= toolIdx { + t.Error("task section should come before tools in the output") + } +} + +// --- CLI Argument Tests --- + +func TestCodexCliProvider_GetDefaultModel(t *testing.T) { + p := NewCodexCliProvider("") + if got := p.GetDefaultModel(); got != "codex-cli" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "codex-cli") + } +} + +// --- Mock CLI Integration Test --- + +func createMockCodexCLI(t *testing.T, events []string) string { + t.Helper() + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "codex") + + var sb strings.Builder + sb.WriteString("#!/bin/bash\n") + for _, event := range events { + sb.WriteString(fmt.Sprintf("echo '%s'\n", event)) + } + + if err := os.WriteFile(scriptPath, []byte(sb.String()), 0o755); err != nil { + t.Fatal(err) + } + return scriptPath +} + +func TestCodexCliProvider_MockCLI_Success(t *testing.T) { + scriptPath := createMockCodexCLI(t, []string{ + `{"type":"thread.started","thread_id":"test-123"}`, + `{"type":"turn.started"}`, + `{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Mock response from Codex CLI"}}`, + `{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":10,"output_tokens":15}}`, + }) + + p := &CodexCliProvider{ + command: scriptPath, + workspace: "", + } + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := p.Chat(context.Background(), messages, nil, "", nil) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Mock response from Codex CLI" { + t.Errorf("Content = 
%q, want %q", resp.Content, "Mock response from Codex CLI") + } + if resp.Usage == nil { + t.Fatal("Usage should not be nil") + } + if resp.Usage.PromptTokens != 60 { + t.Errorf("PromptTokens = %d, want 60", resp.Usage.PromptTokens) + } + if resp.Usage.CompletionTokens != 15 { + t.Errorf("CompletionTokens = %d, want 15", resp.Usage.CompletionTokens) + } +} + +func TestCodexCliProvider_MockCLI_Error(t *testing.T) { + scriptPath := createMockCodexCLI(t, []string{ + `{"type":"thread.started","thread_id":"test-err"}`, + `{"type":"turn.started"}`, + `{"type":"error","message":"auth token expired"}`, + `{"type":"turn.failed","error":{"message":"auth token expired"}}`, + }) + + p := &CodexCliProvider{ + command: scriptPath, + workspace: "", + } + + messages := []Message{{Role: "user", Content: "Hello"}} + _, err := p.Chat(context.Background(), messages, nil, "", nil) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "auth token expired") { + t.Errorf("error = %q, want to contain 'auth token expired'", err.Error()) + } +} + +func TestCodexCliProvider_MockCLI_WithModel(t *testing.T) { + // Mock script that captures args to verify model flag is passed + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "codex") + script := `#!/bin/bash +# Write args to a file for verification +echo "$@" > "` + filepath.Join(tmpDir, "args.txt") + `" +echo '{"type":"item.completed","item":{"id":"1","type":"agent_message","text":"ok"}}' +echo '{"type":"turn.completed"}'` + + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatal(err) + } + + p := &CodexCliProvider{ + command: scriptPath, + workspace: "/tmp/test-workspace", + } + + messages := []Message{{Role: "user", Content: "test"}} + _, err := p.Chat(context.Background(), messages, nil, "gpt-5.2-codex", nil) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + + // Verify the args + argsData, err := os.ReadFile(filepath.Join(tmpDir, "args.txt")) + if err != nil { 
+ t.Fatalf("reading args: %v", err) + } + args := string(argsData) + + if !strings.Contains(args, "-m gpt-5.2-codex") { + t.Errorf("args should contain model flag, got: %s", args) + } + if !strings.Contains(args, "-C /tmp/test-workspace") { + t.Errorf("args should contain workspace flag, got: %s", args) + } + if !strings.Contains(args, "--json") { + t.Errorf("args should contain --json, got: %s", args) + } + if !strings.Contains(args, "--dangerously-bypass-approvals-and-sandbox") { + t.Errorf("args should contain bypass flag, got: %s", args) + } +} + +func TestCodexCliProvider_MockCLI_ContextCancel(t *testing.T) { + // Script that sleeps forever + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "codex") + script := "#!/bin/bash\nsleep 60" + + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatal(err) + } + + p := &CodexCliProvider{ + command: scriptPath, + workspace: "", + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + messages := []Message{{Role: "user", Content: "test"}} + _, err := p.Chat(ctx, messages, nil, "", nil) + if err == nil { + t.Fatal("expected error on canceled context") + } +} + +func TestCodexCliProvider_EmptyCommand(t *testing.T) { + p := &CodexCliProvider{command: ""} + + messages := []Message{{Role: "user", Content: "test"}} + _, err := p.Chat(context.Background(), messages, nil, "", nil) + if err == nil { + t.Fatal("expected error for empty command") + } +} + +// --- Integration Test (requires real codex CLI with valid auth) --- + +func TestCodexCliProvider_Integration(t *testing.T) { + if os.Getenv("PICOCLAW_INTEGRATION_TESTS") == "" { + t.Skip("skipping integration test (set PICOCLAW_INTEGRATION_TESTS=1 to enable)") + } + + // Verify codex is available + codexPath, err := exec.LookPath("codex") + if err != nil { + t.Skip("codex CLI not found in PATH") + } + + p := &CodexCliProvider{ + command: codexPath, + workspace: "", + } + + messages := []Message{ + 
{Role: "user", Content: "Respond with just the word 'hello' and nothing else."}, + } + + resp, err := p.Chat(context.Background(), messages, nil, "", nil) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + + lower := strings.ToLower(strings.TrimSpace(resp.Content)) + if !strings.Contains(lower, "hello") { + t.Errorf("Content = %q, expected to contain 'hello'", resp.Content) + } +} diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go index 3463389a5..ecc983642 100644 --- a/pkg/providers/codex_provider.go +++ b/pkg/providers/codex_provider.go @@ -3,44 +3,75 @@ package providers import ( "context" "encoding/json" + "errors" "fmt" "strings" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/option" "github.com/openai/openai-go/v3/responses" + "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/logger" +) + +const ( + codexDefaultModel = "gpt-5.2" + codexDefaultInstructions = "You are Codex, a coding assistant." ) type CodexProvider struct { - client *openai.Client - accountID string - tokenSource func() (string, string, error) + client *openai.Client + accountID string + tokenSource func() (string, string, error) + enableWebSearch bool } +const defaultCodexInstructions = "You are Codex, a coding assistant." + func NewCodexProvider(token, accountID string) *CodexProvider { opts := []option.RequestOption{ option.WithBaseURL("https://chatgpt.com/backend-api/codex"), option.WithAPIKey(token), + option.WithHeader("originator", "codex_cli_rs"), + option.WithHeader("OpenAI-Beta", "responses=experimental"), } if accountID != "" { opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID)) } client := openai.NewClient(opts...) 
return &CodexProvider{ - client: &client, - accountID: accountID, + client: &client, + accountID: accountID, + enableWebSearch: true, } } -func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider { +func NewCodexProviderWithTokenSource( + token, accountID string, tokenSource func() (string, string, error), +) *CodexProvider { p := NewCodexProvider(token, accountID) p.tokenSource = tokenSource return p } -func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { +func (p *CodexProvider) Chat( + ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any, +) (*LLMResponse, error) { var opts []option.RequestOption + accountID := p.accountID + resolvedModel, fallbackReason := resolveCodexModel(model) + if fallbackReason != "" { + logger.WarnCF( + "provider.codex", + "Requested model is not compatible with Codex backend, using fallback", + map[string]any{ + "requested_model": model, + "resolved_model": resolvedModel, + "reason": fallbackReason, + }, + ) + } if p.tokenSource != nil { tok, accID, err := p.tokenSource() if err != nil { @@ -48,25 +79,129 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To } opts = append(opts, option.WithAPIKey(tok)) if accID != "" { - opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID)) + accountID = accID } } + if accountID != "" { + opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID)) + } else { + logger.WarnCF( + "provider.codex", + "No account id found for Codex request; backend may reject with 400", + map[string]any{ + "requested_model": model, + "resolved_model": resolvedModel, + }, + ) + } + + params := buildCodexParams(messages, tools, resolvedModel, options, p.enableWebSearch) - params := buildCodexParams(messages, tools, model, options) + stream := 
p.client.Responses.NewStreaming(ctx, params, opts...) + defer stream.Close() - resp, err := p.client.Responses.New(ctx, params, opts...) + var resp *responses.Response + for stream.Next() { + evt := stream.Current() + if evt.Type == "response.completed" || evt.Type == "response.failed" || evt.Type == "response.incomplete" { + evtResp := evt.Response + if evtResp.ID != "" { + copy := evtResp + resp = © + } + } + } + err := stream.Err() if err != nil { + fields := map[string]any{ + "requested_model": model, + "resolved_model": resolvedModel, + "messages_count": len(messages), + "tools_count": len(tools), + "account_id_present": accountID != "", + "error": err.Error(), + } + var apiErr *openai.Error + if errors.As(err, &apiErr) { + fields["status_code"] = apiErr.StatusCode + fields["api_type"] = apiErr.Type + fields["api_code"] = apiErr.Code + fields["api_param"] = apiErr.Param + fields["api_message"] = apiErr.Message + if apiErr.StatusCode == 400 { + fields["hint"] = "verify account id header and model compatibility for codex backend" + } + if apiErr.Response != nil { + fields["request_id"] = apiErr.Response.Header.Get("x-request-id") + } + } + logger.ErrorCF("provider.codex", "Codex API call failed", fields) return nil, fmt.Errorf("codex API call: %w", err) } + if resp == nil { + fields := map[string]any{ + "requested_model": model, + "resolved_model": resolvedModel, + "messages_count": len(messages), + "tools_count": len(tools), + "account_id_present": accountID != "", + } + logger.ErrorCF("provider.codex", "Codex stream ended without completed response event", fields) + return nil, fmt.Errorf("codex API call: stream ended without completed response") + } return parseCodexResponse(resp), nil } func (p *CodexProvider) GetDefaultModel() string { - return "gpt-4o" + return codexDefaultModel +} + +func resolveCodexModel(model string) (string, string) { + m := strings.ToLower(strings.TrimSpace(model)) + if m == "" { + return codexDefaultModel, "empty model" + } + + if 
strings.HasPrefix(m, "openai/") { + m = strings.TrimPrefix(m, "openai/") + } else if strings.Contains(m, "/") { + return codexDefaultModel, "non-openai model namespace" + } + + unsupportedPrefixes := []string{ + "glm", + "claude", + "anthropic", + "gemini", + "google", + "moonshot", + "kimi", + "qwen", + "deepseek", + "llama", + "meta-llama", + "mistral", + "grok", + "xai", + "zhipu", + } + for _, prefix := range unsupportedPrefixes { + if strings.HasPrefix(m, prefix) { + return codexDefaultModel, "unsupported model prefix" + } + } + + if strings.HasPrefix(m, "gpt-") || strings.HasPrefix(m, "o3") || strings.HasPrefix(m, "o4") { + return m, "" + } + + return codexDefaultModel, "unsupported model family" } -func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams { +func buildCodexParams( + messages []Message, tools []ToolDefinition, model string, options map[string]any, enableWebSearch bool, +) responses.ResponseNewParams { var inputItems responses.ResponseInputParam var instructions string @@ -79,7 +214,9 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string, inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ CallID: msg.ToolCallID, - Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)}, + Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{ + OfString: openai.Opt(msg.Content), + }, }, }) } else { @@ -101,12 +238,18 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string, }) } for _, tc := range msg.ToolCalls { - argsJSON, _ := json.Marshal(tc.Arguments) + name, args, ok := resolveCodexToolCall(tc) + if !ok { + logger.WarnCF("provider.codex", "Skipping invalid tool call in history", map[string]any{ + "call_id": tc.ID, + }) + continue + } inputItems = append(inputItems, 
responses.ResponseInputItemUnionParam{ OfFunctionCall: &responses.ResponseFunctionToolCallParam{ CallID: tc.ID, - Name: tc.Name, - Arguments: string(argsJSON), + Name: name, + Arguments: args, }, }) } @@ -122,7 +265,9 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string, inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ CallID: msg.ToolCallID, - Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)}, + Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{ + OfString: openai.Opt(msg.Content), + }, }, }) } @@ -133,31 +278,61 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string, Input: responses.ResponseNewParamsInputUnion{ OfInputItemList: inputItems, }, - Store: openai.Opt(false), + Instructions: openai.Opt(instructions), + Store: openai.Opt(false), } if instructions != "" { params.Instructions = openai.Opt(instructions) + } else { + // ChatGPT Codex backend requires instructions to be present. 
+ params.Instructions = openai.Opt(defaultCodexInstructions) } - if maxTokens, ok := options["max_tokens"].(int); ok { - params.MaxOutputTokens = openai.Opt(int64(maxTokens)) + if len(tools) > 0 || enableWebSearch { + params.Tools = translateToolsForCodex(tools, enableWebSearch) } - if temp, ok := options["temperature"].(float64); ok { - params.Temperature = openai.Opt(temp) + return params +} + +func resolveCodexToolCall(tc ToolCall) (name string, arguments string, ok bool) { + name = tc.Name + if name == "" && tc.Function != nil { + name = tc.Function.Name + } + if name == "" { + return "", "", false + } + + if len(tc.Arguments) > 0 { + argsJSON, err := json.Marshal(tc.Arguments) + if err != nil { + return "", "", false + } + return name, string(argsJSON), true } - if len(tools) > 0 { - params.Tools = translateToolsForCodex(tools) + if tc.Function != nil && tc.Function.Arguments != "" { + return name, tc.Function.Arguments, true } - return params + return name, "{}", true } -func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam { - result := make([]responses.ToolUnionParam, 0, len(tools)) +func translateToolsForCodex(tools []ToolDefinition, enableWebSearch bool) []responses.ToolUnionParam { + capHint := len(tools) + if enableWebSearch { + capHint++ + } + result := make([]responses.ToolUnionParam, 0, capHint) for _, t := range tools { + if t.Type != "function" { + continue + } + if enableWebSearch && strings.EqualFold(t.Function.Name, "web_search") { + continue + } ft := responses.FunctionToolParam{ Name: t.Function.Name, Parameters: t.Function.Parameters, @@ -168,6 +343,9 @@ func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam { } result = append(result, responses.ToolUnionParam{OfFunction: &ft}) } + if enableWebSearch { + result = append(result, responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch)) + } return result } @@ -184,9 +362,9 @@ func parseCodexResponse(resp *responses.Response) 
*LLMResponse { } } case "function_call": - var args map[string]interface{} + var args map[string]any if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil { - args = map[string]interface{}{"raw": item.Arguments} + args = map[string]any{"raw": item.Arguments} } toolCalls = append(toolCalls, ToolCall{ ID: item.CallID, @@ -237,6 +415,9 @@ func createCodexTokenSource() func() (string, string, error) { if err != nil { return "", "", fmt.Errorf("refreshing token: %w", err) } + if refreshed.AccountID == "" { + refreshed.AccountID = cred.AccountID + } if err := auth.SetCredential("openai", refreshed); err != nil { return "", "", fmt.Errorf("saving refreshed token: %w", err) } diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go index 605183d5e..4157e53e9 100644 --- a/pkg/providers/codex_provider_test.go +++ b/pkg/providers/codex_provider_test.go @@ -2,6 +2,7 @@ package providers import ( "encoding/json" + "fmt" "net/http" "net/http/httptest" "testing" @@ -15,12 +16,22 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) { messages := []Message{ {Role: "user", Content: "Hello"}, } - params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{ - "max_tokens": 2048, - }) + params := buildCodexParams(messages, nil, "gpt-4o", map[string]any{ + "max_tokens": 2048, + "temperature": 0.7, + }, true) if params.Model != "gpt-4o" { t.Errorf("Model = %q, want %q", params.Model, "gpt-4o") } + if !params.Instructions.Valid() { + t.Fatal("Instructions should be set") + } + if params.Instructions.Or("") != defaultCodexInstructions { + t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), defaultCodexInstructions) + } + if params.MaxOutputTokens.Valid() { + t.Fatalf("MaxOutputTokens should not be set for Codex backend") + } } func TestBuildCodexParams_SystemAsInstructions(t *testing.T) { @@ -28,7 +39,7 @@ func TestBuildCodexParams_SystemAsInstructions(t *testing.T) { {Role: "system", Content: "You are 
helpful"}, {Role: "user", Content: "Hi"}, } - params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + params := buildCodexParams(messages, nil, "gpt-4o", map[string]any{}, true) if !params.Instructions.Valid() { t.Fatal("Instructions should be set") } @@ -43,12 +54,12 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) { { Role: "assistant", ToolCalls: []ToolCall{ - {ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}}, + {ID: "call_1", Name: "get_weather", Arguments: map[string]any{"city": "SF"}}, }, }, {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, } - params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + params := buildCodexParams(messages, nil, "gpt-4o", map[string]any{}, false) if params.Input.OfInputItemList == nil { t.Fatal("Input.OfInputItemList should not be nil") } @@ -57,6 +68,45 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) { } } +func TestBuildCodexParams_ToolCallFunctionFallback(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Read a file"}, + { + Role: "assistant", + ToolCalls: []ToolCall{ + { + ID: "call_1", + Type: "function", + Function: &FunctionCall{ + Name: "read_file", + Arguments: `{"path":"README.md"}`, + }, + }, + }, + }, + {Role: "tool", Content: "ok", ToolCallID: "call_1"}, + } + + params := buildCodexParams(messages, nil, "gpt-4o", map[string]any{}, false) + if params.Input.OfInputItemList == nil { + t.Fatal("Input.OfInputItemList should not be nil") + } + if len(params.Input.OfInputItemList) != 3 { + t.Fatalf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList)) + } + + fc := params.Input.OfInputItemList[1].OfFunctionCall + if fc == nil { + t.Fatal("assistant tool call should be converted to function_call input item") + } + if fc.Name != "read_file" { + t.Errorf("Function call name = %q, want %q", fc.Name, "read_file") + } + if fc.Arguments != `{"path":"README.md"}` { + 
t.Errorf("Function call arguments = %q, want %q", fc.Arguments, `{"path":"README.md"}`) + } +} + func TestBuildCodexParams_WithTools(t *testing.T) { tools := []ToolDefinition{ { @@ -64,16 +114,16 @@ func TestBuildCodexParams_WithTools(t *testing.T) { Function: ToolFunctionDefinition{ Name: "get_weather", Description: "Get weather", - Parameters: map[string]interface{}{ + Parameters: map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "city": map[string]interface{}{"type": "string"}, + "properties": map[string]any{ + "city": map[string]any{"type": "string"}, }, }, }, }, } - params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}) + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]any{}, false) if len(params.Tools) != 1 { t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) } @@ -86,12 +136,65 @@ func TestBuildCodexParams_WithTools(t *testing.T) { } func TestBuildCodexParams_StoreIsFalse(t *testing.T) { - params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}) + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]any{}, false) if !params.Store.Valid() || params.Store.Or(true) != false { t.Error("Store should be explicitly set to false") } } +func TestBuildCodexParams_DefaultWebSearchEnabled(t *testing.T) { + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]any{}, true) + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } + if params.Tools[0].OfWebSearch == nil { + t.Fatal("Tool should include built-in web_search") + } + if params.Tools[0].OfWebSearch.Type != responses.WebSearchToolTypeWebSearch { + t.Errorf( + "Web search tool type = %q, want %q", + params.Tools[0].OfWebSearch.Type, + responses.WebSearchToolTypeWebSearch, + ) + } +} + +func 
TestBuildCodexParams_WebSearchFunctionReplacedWithBuiltin(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "web_search", + Description: "local web search", + Parameters: map[string]any{ + "type": "object", + }, + }, + }, + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "read_file", + Description: "read file", + Parameters: map[string]any{ + "type": "object", + }, + }, + }, + } + + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]any{}, true) + if len(params.Tools) != 2 { + t.Fatalf("len(Tools) = %d, want 2", len(params.Tools)) + } + if params.Tools[0].OfFunction == nil || params.Tools[0].OfFunction.Name != "read_file" { + t.Fatalf("first tool should be function read_file, got %#v", params.Tools[0]) + } + if params.Tools[1].OfWebSearch == nil { + t.Fatalf("second tool should be built-in web_search, got %#v", params.Tools[1]) + } +} + func TestParseCodexResponse_TextOutput(t *testing.T) { respJSON := `{ "id": "resp_test", @@ -197,31 +300,54 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { return } - resp := map[string]interface{}{ + var reqBody map[string]any + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if reqBody["stream"] != true { + http.Error(w, "stream must be true", http.StatusBadRequest) + return + } + if _, ok := reqBody["max_output_tokens"]; ok { + http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest) + return + } + toolsAny, ok := reqBody["tools"].([]any) + if !ok || len(toolsAny) != 1 { + http.Error(w, "missing default web search tool", http.StatusBadRequest) + return + } + toolObj, ok := toolsAny[0].(map[string]any) + if !ok || toolObj["type"] != "web_search" { + http.Error(w, "expected web_search tool", http.StatusBadRequest) + return + } + + resp := map[string]any{ "id": "resp_test", "object": "response", 
"status": "completed", - "output": []map[string]interface{}{ + "output": []map[string]any{ { "id": "msg_1", "type": "message", "role": "assistant", "status": "completed", - "content": []map[string]interface{}{ + "content": []map[string]any{ {"type": "output_text", "text": "Hi from Codex!"}, }, }, }, - "usage": map[string]interface{}{ + "usage": map[string]any{ "input_tokens": 12, "output_tokens": 6, "total_tokens": 18, - "input_tokens_details": map[string]interface{}{"cached_tokens": 0}, - "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0}, + "input_tokens_details": map[string]any{"cached_tokens": 0}, + "output_tokens_details": map[string]any{"reasoning_tokens": 0}, }, } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + writeCompletedSSE(w, resp) })) defer server.Close() @@ -229,7 +355,7 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123") messages := []Message{{Role: "user", Content: "Hello"}} - resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024}) + resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]any{"max_tokens": 1024}) if err != nil { t.Fatalf("Chat() error: %v", err) } @@ -244,10 +370,252 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { } } +func TestCodexProvider_ChatRoundTrip_WebSearchDisabled(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + + var reqBody map[string]any + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if _, ok := reqBody["tools"]; ok { + http.Error(w, "tools should be absent when web search disabled", http.StatusBadRequest) + return + } + + resp := 
map[string]any{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]any{ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]any{ + {"type": "output_text", "text": "Hi from Codex!"}, + }, + }, + }, + "usage": map[string]any{ + "input_tokens": 4, + "output_tokens": 3, + "total_tokens": 7, + "input_tokens_details": map[string]any{"cached_tokens": 0}, + "output_tokens_details": map[string]any{"reasoning_tokens": 0}, + }, + } + writeCompletedSSE(w, resp) + })) + defer server.Close() + + provider := NewCodexProvider("test-token", "acc-123") + provider.enableWebSearch = false + provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]any{}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hi from Codex!" 
{ + t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!") + } +} + +func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer refreshed-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Header.Get("Chatgpt-Account-Id") != "acc-123" { + http.Error(w, "missing account id", http.StatusBadRequest) + return + } + + var reqBody map[string]any + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if _, ok := reqBody["instructions"]; !ok { + http.Error(w, "missing instructions", http.StatusBadRequest) + return + } + if reqBody["instructions"] == "" { + http.Error(w, "instructions must not be empty", http.StatusBadRequest) + return + } + if _, ok := reqBody["temperature"]; ok { + http.Error(w, "temperature is not supported", http.StatusBadRequest) + return + } + if _, ok := reqBody["max_output_tokens"]; ok { + http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest) + return + } + if reqBody["stream"] != true { + http.Error(w, "stream must be true", http.StatusBadRequest) + return + } + + resp := map[string]any{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]any{ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]any{ + {"type": "output_text", "text": "Hi from Codex!"}, + }, + }, + }, + "usage": map[string]any{ + "input_tokens": 8, + "output_tokens": 4, + "total_tokens": 12, + "input_tokens_details": map[string]any{"cached_tokens": 0}, + "output_tokens_details": map[string]any{"reasoning_tokens": 0}, + }, + } + writeCompletedSSE(w, resp) + })) + 
defer server.Close() + + provider := NewCodexProvider("stale-token", "acc-123") + provider.client = createOpenAITestClient(server.URL, "stale-token", "") + provider.tokenSource = func() (string, string, error) { + return "refreshed-token", "", nil + } + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]any{"temperature": 0.7}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hi from Codex!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!") + } +} + +func TestCodexProvider_ChatRoundTrip_ModelFallbackFromUnsupported(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + + var reqBody map[string]any + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if reqBody["model"] != codexDefaultModel { + http.Error(w, "unsupported model", http.StatusBadRequest) + return + } + if reqBody["stream"] != true { + http.Error(w, "stream must be true", http.StatusBadRequest) + return + } + if reqBody["instructions"] != codexDefaultInstructions { + http.Error(w, "missing default instructions", http.StatusBadRequest) + return + } + + resp := map[string]any{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]any{ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]any{ + {"type": "output_text", "text": "Hi from Codex!"}, + }, + }, + }, + "usage": map[string]any{ + "input_tokens": 8, + "output_tokens": 4, + "total_tokens": 12, + "input_tokens_details": map[string]any{"cached_tokens": 0}, + "output_tokens_details": map[string]any{"reasoning_tokens": 0}, + }, + } + writeCompletedSSE(w, resp) + })) + 
defer server.Close() + + provider := NewCodexProvider("test-token", "acc-123") + provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "gpt-5.2", nil) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hi from Codex!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!") + } +} + func TestCodexProvider_GetDefaultModel(t *testing.T) { p := NewCodexProvider("test-token", "") - if got := p.GetDefaultModel(); got != "gpt-4o" { - t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o") + if got := p.GetDefaultModel(); got != codexDefaultModel { + t.Errorf("GetDefaultModel() = %q, want %q", got, codexDefaultModel) + } +} + +func TestResolveCodexModel(t *testing.T) { + tests := []struct { + name string + input string + wantModel string + wantFallback bool + }{ + {name: "empty", input: "", wantModel: codexDefaultModel, wantFallback: true}, + { + name: "unsupported namespace", + input: "anthropic/claude-3.5", + wantModel: codexDefaultModel, + wantFallback: true, + }, + {name: "non-openai prefixed", input: "glm-4.7", wantModel: codexDefaultModel, wantFallback: true}, + {name: "openai prefix", input: "openai/gpt-5.2", wantModel: "gpt-5.2", wantFallback: false}, + {name: "direct gpt", input: "gpt-4o", wantModel: "gpt-4o", wantFallback: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotModel, reason := resolveCodexModel(tt.input) + if gotModel != tt.wantModel { + t.Fatalf("resolveCodexModel(%q) model = %q, want %q", tt.input, gotModel, tt.wantModel) + } + if tt.wantFallback && reason == "" { + t.Fatalf("resolveCodexModel(%q) expected fallback reason", tt.input) + } + if !tt.wantFallback && reason != "" { + t.Fatalf("resolveCodexModel(%q) unexpected fallback reason: %q", tt.input, reason) + } + }) } } @@ -262,3 +630,16 @@ func createOpenAITestClient(baseURL, 
token, accountID string) *openai.Client { c := openai.NewClient(opts...) return &c } + +func writeCompletedSSE(w http.ResponseWriter, response map[string]any) { + event := map[string]any{ + "type": "response.completed", + "sequence_number": 1, + "response": response, + } + b, _ := json.Marshal(event) + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprintf(w, "event: response.completed\n") + fmt.Fprintf(w, "data: %s\n\n", string(b)) + fmt.Fprintf(w, "data: [DONE]\n\n") +} diff --git a/pkg/providers/cooldown.go b/pkg/providers/cooldown.go new file mode 100644 index 000000000..b0d8608dc --- /dev/null +++ b/pkg/providers/cooldown.go @@ -0,0 +1,207 @@ +package providers + +import ( + "math" + "sync" + "time" +) + +const ( + defaultFailureWindow = 24 * time.Hour +) + +// CooldownTracker manages per-provider cooldown state for the fallback chain. +// Thread-safe via sync.RWMutex. In-memory only (resets on restart). +type CooldownTracker struct { + mu sync.RWMutex + entries map[string]*cooldownEntry + failureWindow time.Duration + nowFunc func() time.Time // for testing +} + +type cooldownEntry struct { + ErrorCount int + FailureCounts map[FailoverReason]int + CooldownEnd time.Time // standard cooldown expiry + DisabledUntil time.Time // billing-specific disable expiry + DisabledReason FailoverReason // reason for disable (billing) + LastFailure time.Time +} + +// NewCooldownTracker creates a tracker with default 24h failure window. +func NewCooldownTracker() *CooldownTracker { + return &CooldownTracker{ + entries: make(map[string]*cooldownEntry), + failureWindow: defaultFailureWindow, + nowFunc: time.Now, + } +} + +// MarkFailure records a failure for a provider and sets appropriate cooldown. +// Resets error counts if last failure was more than failureWindow ago. 
+func (ct *CooldownTracker) MarkFailure(provider string, reason FailoverReason) { + ct.mu.Lock() + defer ct.mu.Unlock() + + now := ct.nowFunc() + entry := ct.getOrCreate(provider) + + // 24h failure window reset: if no failure in failureWindow, reset counters. + if !entry.LastFailure.IsZero() && now.Sub(entry.LastFailure) > ct.failureWindow { + entry.ErrorCount = 0 + entry.FailureCounts = make(map[FailoverReason]int) + } + + entry.ErrorCount++ + entry.FailureCounts[reason]++ + entry.LastFailure = now + + if reason == FailoverBilling { + billingCount := entry.FailureCounts[FailoverBilling] + entry.DisabledUntil = now.Add(calculateBillingCooldown(billingCount)) + entry.DisabledReason = FailoverBilling + } else { + entry.CooldownEnd = now.Add(calculateStandardCooldown(entry.ErrorCount)) + } +} + +// MarkSuccess resets all counters and cooldowns for a provider. +func (ct *CooldownTracker) MarkSuccess(provider string) { + ct.mu.Lock() + defer ct.mu.Unlock() + + entry := ct.entries[provider] + if entry == nil { + return + } + + entry.ErrorCount = 0 + entry.FailureCounts = make(map[FailoverReason]int) + entry.CooldownEnd = time.Time{} + entry.DisabledUntil = time.Time{} + entry.DisabledReason = "" +} + +// IsAvailable returns true if the provider is not in cooldown or disabled. +func (ct *CooldownTracker) IsAvailable(provider string) bool { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return true + } + + now := ct.nowFunc() + + // Billing disable takes precedence (longer cooldown). + if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) { + return false + } + + // Standard cooldown. + if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) { + return false + } + + return true +} + +// CooldownRemaining returns how long until the provider becomes available. +// Returns 0 if already available. 
+func (ct *CooldownTracker) CooldownRemaining(provider string) time.Duration { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return 0 + } + + now := ct.nowFunc() + var remaining time.Duration + + if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) { + d := entry.DisabledUntil.Sub(now) + if d > remaining { + remaining = d + } + } + + if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) { + d := entry.CooldownEnd.Sub(now) + if d > remaining { + remaining = d + } + } + + return remaining +} + +// ErrorCount returns the current error count for a provider. +func (ct *CooldownTracker) ErrorCount(provider string) int { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return 0 + } + return entry.ErrorCount +} + +// FailureCount returns the failure count for a specific reason. +func (ct *CooldownTracker) FailureCount(provider string, reason FailoverReason) int { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return 0 + } + return entry.FailureCounts[reason] +} + +func (ct *CooldownTracker) getOrCreate(provider string) *cooldownEntry { + entry := ct.entries[provider] + if entry == nil { + entry = &cooldownEntry{ + FailureCounts: make(map[FailoverReason]int), + } + ct.entries[provider] = entry + } + return entry +} + +// calculateStandardCooldown computes standard exponential backoff. +// Formula from OpenClaw: min(1h, 1min * 5^min(n-1, 3)) +// +// 1 error → 1 min +// 2 errors → 5 min +// 3 errors → 25 min +// 4+ errors → 1 hour (cap) +func calculateStandardCooldown(errorCount int) time.Duration { + n := max(1, errorCount) + exp := min(n-1, 3) + ms := 60_000 * int(math.Pow(5, float64(exp))) + ms = min(3_600_000, ms) // cap at 1 hour + return time.Duration(ms) * time.Millisecond +} + +// calculateBillingCooldown computes billing-specific exponential backoff. 
+// Formula from OpenClaw: min(24h, 5h * 2^min(n-1, 10)) +// +// 1 error → 5 hours +// 2 errors → 10 hours +// 3 errors → 20 hours +// 4+ errors → 24 hours (cap) +func calculateBillingCooldown(billingErrorCount int) time.Duration { + const baseMs = 5 * 60 * 60 * 1000 // 5 hours + const maxMs = 24 * 60 * 60 * 1000 // 24 hours + + n := max(1, billingErrorCount) + exp := min(n-1, 10) + raw := float64(baseMs) * math.Pow(2, float64(exp)) + ms := int(math.Min(float64(maxMs), raw)) + return time.Duration(ms) * time.Millisecond +} diff --git a/pkg/providers/cooldown_test.go b/pkg/providers/cooldown_test.go new file mode 100644 index 000000000..47f43ad5c --- /dev/null +++ b/pkg/providers/cooldown_test.go @@ -0,0 +1,269 @@ +package providers + +import ( + "sync" + "testing" + "time" +) + +func newTestTracker(now time.Time) (*CooldownTracker, *time.Time) { + current := now + ct := NewCooldownTracker() + ct.nowFunc = func() time.Time { return current } + return ct, ¤t +} + +func TestCooldown_InitiallyAvailable(t *testing.T) { + ct := NewCooldownTracker() + if !ct.IsAvailable("openai") { + t.Error("new provider should be available") + } + if ct.ErrorCount("openai") != 0 { + t.Error("new provider should have 0 errors") + } +} + +func TestCooldown_StandardEscalation(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // 1st error → 1 min cooldown + ct.MarkFailure("openai", FailoverRateLimit) + if ct.IsAvailable("openai") { + t.Error("should be in cooldown after 1st error") + } + + // Advance 61 seconds → available + *current = now.Add(61 * time.Second) + if !ct.IsAvailable("openai") { + t.Error("should be available after 1 min cooldown") + } + + // 2nd error → 5 min cooldown + ct.MarkFailure("openai", FailoverRateLimit) + *current = now.Add(61*time.Second + 4*time.Minute) + if ct.IsAvailable("openai") { + t.Error("should be in cooldown (5 min) after 2nd error") + } + *current = now.Add(61*time.Second + 6*time.Minute) + if !ct.IsAvailable("openai") { + 
t.Error("should be available after 5 min cooldown") + } +} + +func TestCooldown_StandardCap(t *testing.T) { + // Verify formula: 1m, 5m, 25m, 1h, 1h, 1h... + expected := []time.Duration{ + 1 * time.Minute, + 5 * time.Minute, + 25 * time.Minute, + 1 * time.Hour, + 1 * time.Hour, + } + + for i, want := range expected { + got := calculateStandardCooldown(i + 1) + if got != want { + t.Errorf("calculateStandardCooldown(%d) = %v, want %v", i+1, got, want) + } + } +} + +func TestCooldown_BillingEscalation(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // 1st billing error → 5h cooldown + ct.MarkFailure("openai", FailoverBilling) + if ct.IsAvailable("openai") { + t.Error("should be disabled after billing error") + } + + // Advance 4h → still disabled + *current = now.Add(4 * time.Hour) + if ct.IsAvailable("openai") { + t.Error("should still be disabled (5h cooldown)") + } + + // Advance 5h + 1s → available + *current = now.Add(5*time.Hour + 1*time.Second) + if !ct.IsAvailable("openai") { + t.Error("should be available after 5h billing cooldown") + } +} + +func TestCooldown_BillingCap(t *testing.T) { + expected := []time.Duration{ + 5 * time.Hour, + 10 * time.Hour, + 20 * time.Hour, + 24 * time.Hour, + 24 * time.Hour, + } + + for i, want := range expected { + got := calculateBillingCooldown(i + 1) + if got != want { + t.Errorf("calculateBillingCooldown(%d) = %v, want %v", i+1, got, want) + } + } +} + +func TestCooldown_SuccessReset(t *testing.T) { + ct := NewCooldownTracker() + + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("openai", FailoverBilling) + if ct.ErrorCount("openai") != 2 { + t.Errorf("error count = %d, want 2", ct.ErrorCount("openai")) + } + + ct.MarkSuccess("openai") + if ct.ErrorCount("openai") != 0 { + t.Errorf("error count after success = %d, want 0", ct.ErrorCount("openai")) + } + if !ct.IsAvailable("openai") { + t.Error("should be available after success") + } + if ct.FailureCount("openai", FailoverRateLimit) 
!= 0 { + t.Error("failure counts should be reset after success") + } + if ct.FailureCount("openai", FailoverBilling) != 0 { + t.Error("billing failure count should be reset after success") + } +} + +func TestCooldown_FailureWindowReset(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // 4 errors → 1h cooldown + for i := 0; i < 4; i++ { + ct.MarkFailure("openai", FailoverRateLimit) + *current = current.Add(2 * time.Second) // small advance between errors + } + if ct.ErrorCount("openai") != 4 { + t.Errorf("error count = %d, want 4", ct.ErrorCount("openai")) + } + + // Advance 25 hours (past 24h failure window) + *current = now.Add(25 * time.Hour) + + // Next error should reset counters first, then increment to 1 + ct.MarkFailure("openai", FailoverRateLimit) + if ct.ErrorCount("openai") != 1 { + t.Errorf("error count after window reset = %d, want 1 (reset + 1)", ct.ErrorCount("openai")) + } +} + +func TestCooldown_PerReasonTracking(t *testing.T) { + ct := NewCooldownTracker() + + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("openai", FailoverBilling) + ct.MarkFailure("openai", FailoverAuth) + + if ct.FailureCount("openai", FailoverRateLimit) != 2 { + t.Errorf("rate_limit count = %d, want 2", ct.FailureCount("openai", FailoverRateLimit)) + } + if ct.FailureCount("openai", FailoverBilling) != 1 { + t.Errorf("billing count = %d, want 1", ct.FailureCount("openai", FailoverBilling)) + } + if ct.FailureCount("openai", FailoverAuth) != 1 { + t.Errorf("auth count = %d, want 1", ct.FailureCount("openai", FailoverAuth)) + } + if ct.ErrorCount("openai") != 4 { + t.Errorf("total error count = %d, want 4", ct.ErrorCount("openai")) + } +} + +func TestCooldown_BillingTakesPrecedence(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // Standard cooldown (1 min) + billing disable (5h) + ct.MarkFailure("openai", FailoverRateLimit) // 1 min cooldown + 
ct.MarkFailure("openai", FailoverBilling) // 5h disable + + // After 2 min: standard cooldown expired but billing still active + *current = now.Add(2 * time.Minute) + if ct.IsAvailable("openai") { + t.Error("billing disable should take precedence over standard cooldown") + } + + // After 5h + 1s: both expired + *current = now.Add(5*time.Hour + 1*time.Second) + if !ct.IsAvailable("openai") { + t.Error("should be available after all cooldowns expire") + } +} + +func TestCooldown_CooldownRemaining(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // No failures → 0 remaining + if ct.CooldownRemaining("openai") != 0 { + t.Error("expected 0 remaining for new provider") + } + + ct.MarkFailure("openai", FailoverRateLimit) + + *current = now.Add(30 * time.Second) + remaining := ct.CooldownRemaining("openai") + if remaining <= 0 || remaining > 1*time.Minute { + t.Errorf("remaining = %v, expected ~30s", remaining) + } +} + +func TestCooldown_SuccessOnUnknownProvider(t *testing.T) { + ct := NewCooldownTracker() + // Should not panic + ct.MarkSuccess("nonexistent") + if !ct.IsAvailable("nonexistent") { + t.Error("nonexistent provider should be available") + } +} + +func TestCooldown_ConcurrentAccess(t *testing.T) { + ct := NewCooldownTracker() + var wg sync.WaitGroup + + for i := 0; i < 100; i++ { + wg.Add(3) + go func() { + defer wg.Done() + ct.MarkFailure("openai", FailoverRateLimit) + }() + go func() { + defer wg.Done() + ct.IsAvailable("openai") + }() + go func() { + defer wg.Done() + ct.MarkSuccess("openai") + }() + } + + wg.Wait() + // If we got here without panic, concurrent access is safe +} + +func TestCooldown_MultipleProviders(t *testing.T) { + ct := NewCooldownTracker() + + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("anthropic", FailoverBilling) + + if ct.IsAvailable("openai") { + t.Error("openai should be in cooldown") + } + if ct.IsAvailable("anthropic") { + t.Error("anthropic should be in cooldown") + } + // groq was 
never touched + if !ct.IsAvailable("groq") { + t.Error("groq should be available") + } +} diff --git a/pkg/providers/error_classifier.go b/pkg/providers/error_classifier.go new file mode 100644 index 000000000..a0f003006 --- /dev/null +++ b/pkg/providers/error_classifier.go @@ -0,0 +1,253 @@ +package providers + +import ( + "context" + "regexp" + "strings" +) + +// errorPattern defines a single pattern (string or regex) for error classification. +type errorPattern struct { + substring string + regex *regexp.Regexp +} + +func substr(s string) errorPattern { return errorPattern{substring: s} } +func rxp(r string) errorPattern { return errorPattern{regex: regexp.MustCompile("(?i)" + r)} } + +// Error patterns organized by FailoverReason, matching OpenClaw production (~40 patterns). +var ( + rateLimitPatterns = []errorPattern{ + rxp(`rate[_ ]limit`), + substr("too many requests"), + substr("429"), + substr("exceeded your current quota"), + rxp(`exceeded.*quota`), + rxp(`resource has been exhausted`), + rxp(`resource.*exhausted`), + substr("resource_exhausted"), + substr("quota exceeded"), + substr("usage limit"), + } + + overloadedPatterns = []errorPattern{ + rxp(`overloaded_error`), + rxp(`"type"\s*:\s*"overloaded_error"`), + substr("overloaded"), + } + + timeoutPatterns = []errorPattern{ + substr("timeout"), + substr("timed out"), + substr("deadline exceeded"), + substr("context deadline exceeded"), + } + + billingPatterns = []errorPattern{ + rxp(`\b402\b`), + substr("payment required"), + substr("insufficient credits"), + substr("credit balance"), + substr("plans & billing"), + substr("insufficient balance"), + } + + authPatterns = []errorPattern{ + rxp(`invalid[_ ]?api[_ ]?key`), + substr("incorrect api key"), + substr("invalid token"), + substr("authentication"), + substr("re-authenticate"), + substr("oauth token refresh failed"), + substr("unauthorized"), + substr("forbidden"), + substr("access denied"), + substr("expired"), + substr("token has expired"), + 
rxp(`\b401\b`), + rxp(`\b403\b`), + substr("no credentials found"), + substr("no api key found"), + } + + formatPatterns = []errorPattern{ + substr("string should match pattern"), + substr("tool_use.id"), + substr("tool_use_id"), + substr("messages.1.content.1.tool_use.id"), + substr("invalid request format"), + } + + imageDimensionPatterns = []errorPattern{ + rxp(`image dimensions exceed max`), + } + + imageSizePatterns = []errorPattern{ + rxp(`image exceeds.*mb`), + } + + // Transient HTTP status codes that map to timeout (server-side failures). + transientStatusCodes = map[int]bool{ + 500: true, 502: true, 503: true, + 521: true, 522: true, 523: true, 524: true, + 529: true, + } +) + +// ClassifyError classifies an error into a FailoverError with reason. +// Returns nil if the error is not classifiable (unknown errors should not trigger fallback). +func ClassifyError(err error, provider, model string) *FailoverError { + if err == nil { + return nil + } + + // Context cancellation: user abort, never fallback. + if err == context.Canceled { + return nil + } + + // Context deadline exceeded: treat as timeout, always fallback. + if err == context.DeadlineExceeded { + return &FailoverError{ + Reason: FailoverTimeout, + Provider: provider, + Model: model, + Wrapped: err, + } + } + + msg := strings.ToLower(err.Error()) + + // Image dimension/size errors: non-retriable, non-fallback. + if IsImageDimensionError(msg) || IsImageSizeError(msg) { + return &FailoverError{ + Reason: FailoverFormat, + Provider: provider, + Model: model, + Wrapped: err, + } + } + + // Try HTTP status code extraction first. + if status := extractHTTPStatus(msg); status > 0 { + if reason := classifyByStatus(status); reason != "" { + return &FailoverError{ + Reason: reason, + Provider: provider, + Model: model, + Status: status, + Wrapped: err, + } + } + } + + // Message pattern matching (priority order from OpenClaw). 
+ if reason := classifyByMessage(msg); reason != "" { + return &FailoverError{ + Reason: reason, + Provider: provider, + Model: model, + Wrapped: err, + } + } + + return nil +} + +// classifyByStatus maps HTTP status codes to FailoverReason. +func classifyByStatus(status int) FailoverReason { + switch { + case status == 401 || status == 403: + return FailoverAuth + case status == 402: + return FailoverBilling + case status == 408: + return FailoverTimeout + case status == 429: + return FailoverRateLimit + case status == 400: + return FailoverFormat + case transientStatusCodes[status]: + return FailoverTimeout + } + return "" +} + +// classifyByMessage matches error messages against patterns. +// Priority order matters (from OpenClaw classifyFailoverReason). +func classifyByMessage(msg string) FailoverReason { + if matchesAny(msg, rateLimitPatterns) { + return FailoverRateLimit + } + if matchesAny(msg, overloadedPatterns) { + return FailoverRateLimit // Overloaded treated as rate_limit + } + if matchesAny(msg, billingPatterns) { + return FailoverBilling + } + if matchesAny(msg, timeoutPatterns) { + return FailoverTimeout + } + if matchesAny(msg, authPatterns) { + return FailoverAuth + } + if matchesAny(msg, formatPatterns) { + return FailoverFormat + } + return "" +} + +// extractHTTPStatus extracts an HTTP status code from an error message. +// Looks for patterns like "status: 429", "status 429", "HTTP 429", or standalone "429". +func extractHTTPStatus(msg string) int { + // Common patterns in Go HTTP error messages + patterns := []*regexp.Regexp{ + regexp.MustCompile(`status[:\s]+(\d{3})`), + regexp.MustCompile(`HTTP[/\s]+\d*\.?\d*\s+(\d{3})`), + } + + for _, p := range patterns { + if m := p.FindStringSubmatch(msg); len(m) > 1 { + return parseDigits(m[1]) + } + } + + return 0 +} + +// IsImageDimensionError returns true if the message indicates an image dimension error. 
+func IsImageDimensionError(msg string) bool { + return matchesAny(msg, imageDimensionPatterns) +} + +// IsImageSizeError returns true if the message indicates an image file size error. +func IsImageSizeError(msg string) bool { + return matchesAny(msg, imageSizePatterns) +} + +// matchesAny checks if msg matches any of the patterns. +func matchesAny(msg string, patterns []errorPattern) bool { + for _, p := range patterns { + if p.regex != nil { + if p.regex.MatchString(msg) { + return true + } + } else if p.substring != "" { + if strings.Contains(msg, p.substring) { + return true + } + } + } + return false +} + +// parseDigits converts a string of digits to an int. +func parseDigits(s string) int { + n := 0 + for _, c := range s { + if c >= '0' && c <= '9' { + n = n*10 + int(c-'0') + } + } + return n +} diff --git a/pkg/providers/error_classifier_test.go b/pkg/providers/error_classifier_test.go new file mode 100644 index 000000000..865aea57a --- /dev/null +++ b/pkg/providers/error_classifier_test.go @@ -0,0 +1,337 @@ +package providers + +import ( + "context" + "errors" + "fmt" + "testing" +) + +func TestClassifyError_Nil(t *testing.T) { + result := ClassifyError(nil, "openai", "gpt-4") + if result != nil { + t.Errorf("expected nil for nil error, got %+v", result) + } +} + +func TestClassifyError_ContextCanceled(t *testing.T) { + result := ClassifyError(context.Canceled, "openai", "gpt-4") + if result != nil { + t.Errorf("expected nil for context.Canceled (user abort), got %+v", result) + } +} + +func TestClassifyError_ContextDeadlineExceeded(t *testing.T) { + result := ClassifyError(context.DeadlineExceeded, "openai", "gpt-4") + if result == nil { + t.Fatal("expected non-nil for deadline exceeded") + } + if result.Reason != FailoverTimeout { + t.Errorf("reason = %q, want timeout", result.Reason) + } +} + +func TestClassifyError_StatusCodes(t *testing.T) { + tests := []struct { + status int + reason FailoverReason + }{ + {401, FailoverAuth}, + {403, FailoverAuth}, 
+ {402, FailoverBilling}, + {408, FailoverTimeout}, + {429, FailoverRateLimit}, + {400, FailoverFormat}, + {500, FailoverTimeout}, + {502, FailoverTimeout}, + {503, FailoverTimeout}, + {521, FailoverTimeout}, + {522, FailoverTimeout}, + {523, FailoverTimeout}, + {524, FailoverTimeout}, + {529, FailoverTimeout}, + } + + for _, tt := range tests { + err := fmt.Errorf("API error: status: %d something went wrong", tt.status) + result := ClassifyError(err, "test", "model") + if result == nil { + t.Errorf("status %d: expected non-nil", tt.status) + continue + } + if result.Reason != tt.reason { + t.Errorf("status %d: reason = %q, want %q", tt.status, result.Reason, tt.reason) + } + } +} + +func TestClassifyError_RateLimitPatterns(t *testing.T) { + patterns := []string{ + "rate limit exceeded", + "rate_limit reached", + "too many requests", + "exceeded your current quota", + "resource has been exhausted", + "resource_exhausted", + "quota exceeded", + "usage limit reached", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverRateLimit { + t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason) + } + } +} + +func TestClassifyError_OverloadedPatterns(t *testing.T) { + patterns := []string{ + "overloaded_error", + `{"type": "overloaded_error"}`, + "server is overloaded", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "anthropic", "claude") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + // Overloaded is treated as rate_limit + if result.Reason != FailoverRateLimit { + t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason) + } + } +} + +func TestClassifyError_BillingPatterns(t *testing.T) { + patterns := []string{ + "payment required", + "insufficient credits", + "credit balance too 
low", + "plans & billing page", + "insufficient balance", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverBilling { + t.Errorf("pattern %q: reason = %q, want billing", msg, result.Reason) + } + } +} + +func TestClassifyError_TimeoutPatterns(t *testing.T) { + patterns := []string{ + "request timeout", + "connection timed out", + "deadline exceeded", + "context deadline exceeded", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverTimeout { + t.Errorf("pattern %q: reason = %q, want timeout", msg, result.Reason) + } + } +} + +func TestClassifyError_AuthPatterns(t *testing.T) { + patterns := []string{ + "invalid api key", + "invalid_api_key", + "incorrect api key", + "invalid token", + "authentication failed", + "re-authenticate", + "oauth token refresh failed", + "unauthorized access", + "forbidden", + "access denied", + "expired", + "token has expired", + "no credentials found", + "no api key found", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverAuth { + t.Errorf("pattern %q: reason = %q, want auth", msg, result.Reason) + } + } +} + +func TestClassifyError_FormatPatterns(t *testing.T) { + patterns := []string{ + "string should match pattern", + "tool_use.id is required", + "invalid tool_use_id", + "messages.1.content.1.tool_use.id must be valid", + "invalid request format", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "anthropic", "claude") + if result == nil { + t.Errorf("pattern %q: 
expected non-nil", msg) + continue + } + if result.Reason != FailoverFormat { + t.Errorf("pattern %q: reason = %q, want format", msg, result.Reason) + } + } +} + +func TestClassifyError_ImageDimensionError(t *testing.T) { + err := errors.New("image dimensions exceed max allowed 2048x2048") + result := ClassifyError(err, "openai", "gpt-4o") + if result == nil { + t.Fatal("expected non-nil for image dimension error") + } + if result.Reason != FailoverFormat { + t.Errorf("reason = %q, want format", result.Reason) + } + if result.IsRetriable() { + t.Error("image dimension error should not be retriable") + } +} + +func TestClassifyError_ImageSizeError(t *testing.T) { + err := errors.New("image exceeds 20 mb limit") + result := ClassifyError(err, "openai", "gpt-4o") + if result == nil { + t.Fatal("expected non-nil for image size error") + } + if result.Reason != FailoverFormat { + t.Errorf("reason = %q, want format", result.Reason) + } +} + +func TestClassifyError_UnknownError(t *testing.T) { + err := errors.New("some completely random error") + result := ClassifyError(err, "openai", "gpt-4") + if result != nil { + t.Errorf("expected nil for unknown error, got %+v", result) + } +} + +func TestClassifyError_ProviderModelPropagation(t *testing.T) { + err := errors.New("rate limit exceeded") + result := ClassifyError(err, "my-provider", "my-model") + if result == nil { + t.Fatal("expected non-nil") + } + if result.Provider != "my-provider" { + t.Errorf("provider = %q, want my-provider", result.Provider) + } + if result.Model != "my-model" { + t.Errorf("model = %q, want my-model", result.Model) + } +} + +func TestFailoverError_IsRetriable(t *testing.T) { + tests := []struct { + reason FailoverReason + retriable bool + }{ + {FailoverAuth, true}, + {FailoverRateLimit, true}, + {FailoverBilling, true}, + {FailoverTimeout, true}, + {FailoverOverloaded, true}, + {FailoverFormat, false}, + {FailoverUnknown, true}, + } + + for _, tt := range tests { + fe := &FailoverError{Reason: 
tt.reason} + if fe.IsRetriable() != tt.retriable { + t.Errorf("IsRetriable(%q) = %v, want %v", tt.reason, fe.IsRetriable(), tt.retriable) + } + } +} + +func TestFailoverError_ErrorString(t *testing.T) { + fe := &FailoverError{ + Reason: FailoverRateLimit, + Provider: "openai", + Model: "gpt-4", + Status: 429, + Wrapped: errors.New("too many requests"), + } + s := fe.Error() + if s == "" { + t.Error("expected non-empty error string") + } +} + +func TestFailoverError_Unwrap(t *testing.T) { + inner := errors.New("inner error") + fe := &FailoverError{Reason: FailoverTimeout, Wrapped: inner} + if fe.Unwrap() != inner { + t.Error("Unwrap should return wrapped error") + } +} + +func TestExtractHTTPStatus(t *testing.T) { + tests := []struct { + msg string + want int + }{ + {"status: 429 rate limited", 429}, + {"status 401 unauthorized", 401}, + {"HTTP/1.1 502 Bad Gateway", 502}, + {"no status code here", 0}, + {"random number 12345", 0}, + } + + for _, tt := range tests { + got := extractHTTPStatus(tt.msg) + if got != tt.want { + t.Errorf("extractHTTPStatus(%q) = %d, want %d", tt.msg, got, tt.want) + } + } +} + +func TestIsImageDimensionError(t *testing.T) { + if !IsImageDimensionError("image dimensions exceed max 4096x4096") { + t.Error("should match image dimensions exceed max") + } + if IsImageDimensionError("normal error message") { + t.Error("should not match normal error") + } +} + +func TestIsImageSizeError(t *testing.T) { + if !IsImageSizeError("image exceeds 20 mb") { + t.Error("should match image exceeds mb") + } + if IsImageSizeError("normal error message") { + t.Error("should not match normal error") + } +} diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go new file mode 100644 index 000000000..b6f1b5e21 --- /dev/null +++ b/pkg/providers/factory.go @@ -0,0 +1,307 @@ +package providers + +import ( + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/config" +) + +const defaultAnthropicAPIBase = 
"https://api.anthropic.com/v1" + +var getCredential = auth.GetCredential + +type providerType int + +const ( + providerTypeHTTPCompat providerType = iota + providerTypeClaudeAuth + providerTypeCodexAuth + providerTypeCodexCLIToken + providerTypeClaudeCLI + providerTypeCodexCLI + providerTypeGitHubCopilot +) + +type providerSelection struct { + providerType providerType + apiKey string + apiBase string + proxy string + model string + workspace string + connectMode string + enableWebSearch bool +} + +func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { + model := cfg.Agents.Defaults.Model + providerName := strings.ToLower(cfg.Agents.Defaults.Provider) + lowerModel := strings.ToLower(model) + + sel := providerSelection{ + providerType: providerTypeHTTPCompat, + model: model, + } + + // First, prefer explicit provider configuration. + if providerName != "" { + switch providerName { + case "groq": + if cfg.Providers.Groq.APIKey != "" { + sel.apiKey = cfg.Providers.Groq.APIKey + sel.apiBase = cfg.Providers.Groq.APIBase + sel.proxy = cfg.Providers.Groq.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.groq.com/openai/v1" + } + } + case "openai", "gpt": + if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { + sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch + if cfg.Providers.OpenAI.AuthMethod == "codex-cli" { + sel.providerType = providerTypeCodexCLIToken + return sel, nil + } + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + sel.providerType = providerTypeCodexAuth + return sel, nil + } + sel.apiKey = cfg.Providers.OpenAI.APIKey + sel.apiBase = cfg.Providers.OpenAI.APIBase + sel.proxy = cfg.Providers.OpenAI.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.openai.com/v1" + } + } + case "anthropic", "claude": + if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { + if cfg.Providers.Anthropic.AuthMethod == "oauth" || 
cfg.Providers.Anthropic.AuthMethod == "token" { + sel.apiBase = cfg.Providers.Anthropic.APIBase + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } + sel.providerType = providerTypeClaudeAuth + return sel, nil + } + sel.apiKey = cfg.Providers.Anthropic.APIKey + sel.apiBase = cfg.Providers.Anthropic.APIBase + sel.proxy = cfg.Providers.Anthropic.Proxy + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } + } + case "openrouter": + if cfg.Providers.OpenRouter.APIKey != "" { + sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + } + case "zhipu", "glm": + if cfg.Providers.Zhipu.APIKey != "" { + sel.apiKey = cfg.Providers.Zhipu.APIKey + sel.apiBase = cfg.Providers.Zhipu.APIBase + sel.proxy = cfg.Providers.Zhipu.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + } + case "gemini", "google": + if cfg.Providers.Gemini.APIKey != "" { + sel.apiKey = cfg.Providers.Gemini.APIKey + sel.apiBase = cfg.Providers.Gemini.APIBase + sel.proxy = cfg.Providers.Gemini.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + } + case "vllm": + if cfg.Providers.VLLM.APIBase != "" { + sel.apiKey = cfg.Providers.VLLM.APIKey + sel.apiBase = cfg.Providers.VLLM.APIBase + sel.proxy = cfg.Providers.VLLM.Proxy + } + case "shengsuanyun": + if cfg.Providers.ShengSuanYun.APIKey != "" { + sel.apiKey = cfg.Providers.ShengSuanYun.APIKey + sel.apiBase = cfg.Providers.ShengSuanYun.APIBase + sel.proxy = cfg.Providers.ShengSuanYun.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://router.shengsuanyun.com/api/v1" + } + } + case "nvidia": + if cfg.Providers.Nvidia.APIKey != "" { + sel.apiKey = cfg.Providers.Nvidia.APIKey + sel.apiBase = cfg.Providers.Nvidia.APIBase + sel.proxy = 
cfg.Providers.Nvidia.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://integrate.api.nvidia.com/v1" + } + } + case "claude-cli", "claude-code", "claudecode": + workspace := cfg.WorkspacePath() + if workspace == "" { + workspace = "." + } + sel.providerType = providerTypeClaudeCLI + sel.workspace = workspace + return sel, nil + case "codex-cli", "codex-code": + workspace := cfg.WorkspacePath() + if workspace == "" { + workspace = "." + } + sel.providerType = providerTypeCodexCLI + sel.workspace = workspace + return sel, nil + case "deepseek": + if cfg.Providers.DeepSeek.APIKey != "" { + sel.apiKey = cfg.Providers.DeepSeek.APIKey + sel.apiBase = cfg.Providers.DeepSeek.APIBase + sel.proxy = cfg.Providers.DeepSeek.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.deepseek.com/v1" + } + if model != "deepseek-chat" && model != "deepseek-reasoner" { + sel.model = "deepseek-chat" + } + } + case "github_copilot", "copilot": + sel.providerType = providerTypeGitHubCopilot + if cfg.Providers.GitHubCopilot.APIBase != "" { + sel.apiBase = cfg.Providers.GitHubCopilot.APIBase + } else { + sel.apiBase = "localhost:4321" + } + sel.connectMode = cfg.Providers.GitHubCopilot.ConnectMode + return sel, nil + } + } + + // Fallback: infer provider from model and configured keys. 
+ if sel.apiKey == "" && sel.apiBase == "" { + switch { + case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": + sel.apiKey = cfg.Providers.Moonshot.APIKey + sel.apiBase = cfg.Providers.Moonshot.APIBase + sel.proxy = cfg.Providers.Moonshot.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.moonshot.cn/v1" + } + case strings.HasPrefix(model, "openrouter/") || + strings.HasPrefix(model, "anthropic/") || + strings.HasPrefix(model, "openai/") || + strings.HasPrefix(model, "meta-llama/") || + strings.HasPrefix(model, "deepseek/") || + strings.HasPrefix(model, "google/"): + sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && + (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + sel.apiBase = cfg.Providers.Anthropic.APIBase + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } + sel.providerType = providerTypeClaudeAuth + return sel, nil + } + sel.apiKey = cfg.Providers.Anthropic.APIKey + sel.apiBase = cfg.Providers.Anthropic.APIBase + sel.proxy = cfg.Providers.Anthropic.Proxy + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } + case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && + (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): + sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch + if cfg.Providers.OpenAI.AuthMethod == "codex-cli" { + sel.providerType = providerTypeCodexCLIToken + return sel, nil + } + if cfg.Providers.OpenAI.AuthMethod 
== "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + sel.providerType = providerTypeCodexAuth + return sel, nil + } + sel.apiKey = cfg.Providers.OpenAI.APIKey + sel.apiBase = cfg.Providers.OpenAI.APIBase + sel.proxy = cfg.Providers.OpenAI.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.openai.com/v1" + } + case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": + sel.apiKey = cfg.Providers.Gemini.APIKey + sel.apiBase = cfg.Providers.Gemini.APIBase + sel.proxy = cfg.Providers.Gemini.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": + sel.apiKey = cfg.Providers.Zhipu.APIKey + sel.apiBase = cfg.Providers.Zhipu.APIBase + sel.proxy = cfg.Providers.Zhipu.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": + sel.apiKey = cfg.Providers.Groq.APIKey + sel.apiBase = cfg.Providers.Groq.APIBase + sel.proxy = cfg.Providers.Groq.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.groq.com/openai/v1" + } + case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": + sel.apiKey = cfg.Providers.Nvidia.APIKey + sel.apiBase = cfg.Providers.Nvidia.APIBase + sel.proxy = cfg.Providers.Nvidia.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://integrate.api.nvidia.com/v1" + } + case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "": + sel.apiKey = cfg.Providers.Ollama.APIKey + sel.apiBase = cfg.Providers.Ollama.APIBase + sel.proxy = cfg.Providers.Ollama.Proxy + if sel.apiBase == "" { + sel.apiBase = 
"http://localhost:11434/v1" + } + case cfg.Providers.VLLM.APIBase != "": + sel.apiKey = cfg.Providers.VLLM.APIKey + sel.apiBase = cfg.Providers.VLLM.APIBase + sel.proxy = cfg.Providers.VLLM.Proxy + default: + if cfg.Providers.OpenRouter.APIKey != "" { + sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + } else { + return providerSelection{}, fmt.Errorf("no API key configured for model: %s", model) + } + } + } + + if sel.providerType == providerTypeHTTPCompat { + if sel.apiKey == "" && !strings.HasPrefix(model, "bedrock/") { + return providerSelection{}, fmt.Errorf("no API key configured for provider (model: %s)", model) + } + if sel.apiBase == "" { + return providerSelection{}, fmt.Errorf("no API base configured for provider (model: %s)", model) + } + } + + return sel, nil +} diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go new file mode 100644 index 000000000..74fe8a36c --- /dev/null +++ b/pkg/providers/factory_provider.go @@ -0,0 +1,192 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package providers + +import ( + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/config" +) + +// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store. +func createClaudeAuthProvider() (LLMProvider, error) { + cred, err := getCredential("anthropic") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for anthropic. 
Run: picoclaw auth login --provider anthropic")
	}
	return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
}

// createCodexAuthProvider creates a Codex provider using OAuth credentials from auth store.
func createCodexAuthProvider() (LLMProvider, error) {
	cred, lookupErr := getCredential("openai")
	if lookupErr != nil {
		return nil, fmt.Errorf("loading auth credentials: %w", lookupErr)
	}
	if cred == nil {
		return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
	}
	return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
}

// ExtractProtocol extracts the protocol prefix and model identifier from a model string.
// If no prefix is specified, it defaults to "openai".
// Examples:
//   - "openai/gpt-4o" -> ("openai", "gpt-4o")
//   - "anthropic/claude-sonnet-4.6" -> ("anthropic", "claude-sonnet-4.6")
//   - "gpt-4o" -> ("openai", "gpt-4o")  // default protocol
func ExtractProtocol(model string) (protocol, modelID string) {
	trimmed := strings.TrimSpace(model)
	// Only the first "/" separates protocol from model; the rest stays in the ID.
	if idx := strings.Index(trimmed, "/"); idx >= 0 {
		return trimmed[:idx], trimmed[idx+1:]
	}
	return "openai", trimmed
}

// CreateProviderFromConfig creates a provider based on the ModelConfig.
// It uses the protocol prefix in the Model field to determine which provider to create.
// Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot,
// plus the OpenAI-compatible HTTP protocols (openrouter, groq, zhipu, gemini, ...).
// Returns the provider, the model ID (without protocol prefix), and any error.
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
	if cfg == nil {
		return nil, "", fmt.Errorf("config is nil")
	}
	if cfg.Model == "" {
		return nil, "", fmt.Errorf("model is required")
	}

	protocol, modelID := ExtractProtocol(cfg.Model)

	// Shared construction path for every OpenAI-compatible HTTP protocol:
	// require a key or base, then fall back to the protocol's default endpoint.
	newHTTP := func() (LLMProvider, string, error) {
		if cfg.APIKey == "" && cfg.APIBase == "" {
			return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
		}
		base := cfg.APIBase
		if base == "" {
			base = getDefaultAPIBase(protocol)
		}
		return NewHTTPProviderWithMaxTokensField(cfg.APIKey, base, cfg.Proxy, cfg.MaxTokensField), modelID, nil
	}

	switch protocol {
	case "openai":
		// OpenAI with OAuth/token auth (Codex-style)
		if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
			provider, err := createCodexAuthProvider()
			if err != nil {
				return nil, "", err
			}
			return provider, modelID, nil
		}
		// OpenAI with API key
		return newHTTP()

	case "openrouter", "groq", "zhipu", "gemini", "nvidia",
		"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
		"volcengine", "vllm", "qwen":
		// All other OpenAI-compatible HTTP providers
		return newHTTP()

	case "anthropic":
		if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
			// Use OAuth credentials from auth store
			provider, err := createClaudeAuthProvider()
			if err != nil {
				return nil, "", err
			}
			return provider, modelID, nil
		}
		// Use API key with HTTP API; anthropic always requires a key.
		if cfg.APIKey == "" {
			return nil, "", fmt.Errorf("api_key is required for anthropic protocol (model: %s)", cfg.Model)
		}
		base := cfg.APIBase
		if base == "" {
			base = "https://api.anthropic.com/v1"
		}
		return NewHTTPProviderWithMaxTokensField(cfg.APIKey, base, cfg.Proxy, cfg.MaxTokensField), modelID, nil

	case "antigravity":
		return NewAntigravityProvider(), modelID, nil

	case "claude-cli", "claudecli":
		ws := cfg.Workspace
		if ws == "" {
			ws = "."
		}
		return NewClaudeCliProvider(ws), modelID, nil

	case "codex-cli", "codexcli":
		ws := cfg.Workspace
		if ws == "" {
			ws = "."
		}
		return NewCodexCliProvider(ws), modelID, nil

	case "github-copilot", "copilot":
		base := cfg.APIBase
		if base == "" {
			base = "localhost:4321"
		}
		mode := cfg.ConnectMode
		if mode == "" {
			mode = "grpc"
		}
		provider, err := NewGitHubCopilotProvider(base, mode, modelID)
		if err != nil {
			return nil, "", err
		}
		return provider, modelID, nil

	default:
		return nil, "", fmt.Errorf("unknown protocol %q in model %q", protocol, cfg.Model)
	}
}

// getDefaultAPIBase returns the default API base URL for a given protocol.
+func getDefaultAPIBase(protocol string) string { + switch protocol { + case "openai": + return "https://api.openai.com/v1" + case "openrouter": + return "https://openrouter.ai/api/v1" + case "groq": + return "https://api.groq.com/openai/v1" + case "zhipu": + return "https://open.bigmodel.cn/api/paas/v4" + case "gemini": + return "https://generativelanguage.googleapis.com/v1beta" + case "nvidia": + return "https://integrate.api.nvidia.com/v1" + case "ollama": + return "http://localhost:11434/v1" + case "moonshot": + return "https://api.moonshot.cn/v1" + case "shengsuanyun": + return "https://router.shengsuanyun.com/api/v1" + case "deepseek": + return "https://api.deepseek.com/v1" + case "cerebras": + return "https://api.cerebras.ai/v1" + case "volcengine": + return "https://ark.cn-beijing.volces.com/api/v3" + case "qwen": + return "https://dashscope.aliyuncs.com/compatible-mode/v1" + case "vllm": + return "http://localhost:8000/v1" + default: + return "" + } +} diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go new file mode 100644 index 000000000..6b133101a --- /dev/null +++ b/pkg/providers/factory_provider_test.go @@ -0,0 +1,249 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package providers + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestExtractProtocol(t *testing.T) { + tests := []struct { + name string + model string + wantProtocol string + wantModelID string + }{ + { + name: "openai with prefix", + model: "openai/gpt-4o", + wantProtocol: "openai", + wantModelID: "gpt-4o", + }, + { + name: "anthropic with prefix", + model: "anthropic/claude-sonnet-4.6", + wantProtocol: "anthropic", + wantModelID: "claude-sonnet-4.6", + }, + { + name: "no prefix - defaults to openai", + model: "gpt-4o", + wantProtocol: "openai", + wantModelID: "gpt-4o", + }, + { + name: "groq with prefix", + model: "groq/llama-3.1-70b", + 
wantProtocol: "groq", + wantModelID: "llama-3.1-70b", + }, + { + name: "empty string", + model: "", + wantProtocol: "openai", + wantModelID: "", + }, + { + name: "with whitespace", + model: " openai/gpt-4 ", + wantProtocol: "openai", + wantModelID: "gpt-4", + }, + { + name: "multiple slashes", + model: "nvidia/meta/llama-3.1-8b", + wantProtocol: "nvidia", + wantModelID: "meta/llama-3.1-8b", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + protocol, modelID := ExtractProtocol(tt.model) + if protocol != tt.wantProtocol { + t.Errorf("ExtractProtocol(%q) protocol = %q, want %q", tt.model, protocol, tt.wantProtocol) + } + if modelID != tt.wantModelID { + t.Errorf("ExtractProtocol(%q) modelID = %q, want %q", tt.model, modelID, tt.wantModelID) + } + }) + } +} + +func TestCreateProviderFromConfig_OpenAI(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-openai", + Model: "openai/gpt-4o", + APIKey: "test-key", + APIBase: "https://api.example.com/v1", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "gpt-4o" { + t.Errorf("modelID = %q, want %q", modelID, "gpt-4o") + } +} + +func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) { + tests := []struct { + name string + protocol string + }{ + {"openai", "openai"}, + {"groq", "groq"}, + {"openrouter", "openrouter"}, + {"cerebras", "cerebras"}, + {"qwen", "qwen"}, + {"vllm", "vllm"}, + {"deepseek", "deepseek"}, + {"ollama", "ollama"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-" + tt.protocol, + Model: tt.protocol + "/test-model", + APIKey: "test-key", + } + + provider, _, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + + // Verify we 
got an HTTPProvider for all these protocols + if _, ok := provider.(*HTTPProvider); !ok { + t.Fatalf("expected *HTTPProvider, got %T", provider) + } + }) + } +} + +func TestCreateProviderFromConfig_Anthropic(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-anthropic", + Model: "anthropic/claude-sonnet-4.6", + APIKey: "test-key", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "claude-sonnet-4.6" { + t.Errorf("modelID = %q, want %q", modelID, "claude-sonnet-4.6") + } +} + +func TestCreateProviderFromConfig_Antigravity(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-antigravity", + Model: "antigravity/gemini-2.0-flash", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "gemini-2.0-flash" { + t.Errorf("modelID = %q, want %q", modelID, "gemini-2.0-flash") + } +} + +func TestCreateProviderFromConfig_ClaudeCLI(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-claude-cli", + Model: "claude-cli/claude-sonnet-4.6", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "claude-sonnet-4.6" { + t.Errorf("modelID = %q, want %q", modelID, "claude-sonnet-4.6") + } +} + +func TestCreateProviderFromConfig_CodexCLI(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-codex-cli", + Model: "codex-cli/codex", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) 
+ } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "codex" { + t.Errorf("modelID = %q, want %q", modelID, "codex") + } +} + +func TestCreateProviderFromConfig_MissingAPIKey(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-no-key", + Model: "openai/gpt-4o", + } + + _, _, err := CreateProviderFromConfig(cfg) + if err == nil { + t.Fatal("CreateProviderFromConfig() expected error for missing API key") + } +} + +func TestCreateProviderFromConfig_UnknownProtocol(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-unknown", + Model: "unknown-protocol/model", + APIKey: "test-key", + } + + _, _, err := CreateProviderFromConfig(cfg) + if err == nil { + t.Fatal("CreateProviderFromConfig() expected error for unknown protocol") + } +} + +func TestCreateProviderFromConfig_NilConfig(t *testing.T) { + _, _, err := CreateProviderFromConfig(nil) + if err == nil { + t.Fatal("CreateProviderFromConfig(nil) expected error") + } +} + +func TestCreateProviderFromConfig_EmptyModel(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-empty", + Model: "", + } + + _, _, err := CreateProviderFromConfig(cfg) + if err == nil { + t.Fatal("CreateProviderFromConfig() expected error for empty model") + } +} diff --git a/pkg/providers/factory_test.go b/pkg/providers/factory_test.go new file mode 100644 index 000000000..5680f23b3 --- /dev/null +++ b/pkg/providers/factory_test.go @@ -0,0 +1,299 @@ +package providers + +import ( + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestResolveProviderSelection(t *testing.T) { + tests := []struct { + name string + setup func(*config.Config) + wantType providerType + wantAPIBase string + wantProxy string + wantErrSubstr string + }{ + { + name: "explicit claude-cli provider routes to cli provider type", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "claude-cli" + 
cfg.Agents.Defaults.Workspace = "/tmp/ws" + }, + wantType: providerTypeClaudeCLI, + }, + { + name: "explicit copilot provider routes to github copilot type", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "copilot" + }, + wantType: providerTypeGitHubCopilot, + wantAPIBase: "localhost:4321", + }, + { + name: "explicit deepseek provider uses deepseek defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "deepseek" + cfg.Agents.Defaults.Model = "deepseek/deepseek-chat" + cfg.Providers.DeepSeek.APIKey = "deepseek-key" + cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.deepseek.com/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "explicit shengsuanyun provider uses defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "shengsuanyun" + cfg.Providers.ShengSuanYun.APIKey = "ssy-key" + cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://router.shengsuanyun.com/api/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "explicit nvidia provider uses defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "nvidia" + cfg.Providers.Nvidia.APIKey = "nvapi-test" + cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://integrate.api.nvidia.com/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "openrouter model uses openrouter defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "openrouter/auto" + cfg.Providers.OpenRouter.APIKey = "sk-or-test" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://openrouter.ai/api/v1", + }, + { + name: "anthropic oauth routes to claude auth provider", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "claude-sonnet-4.6" + cfg.Providers.Anthropic.AuthMethod = "oauth" + 
}, + wantType: providerTypeClaudeAuth, + }, + { + name: "openai oauth routes to codex auth provider", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "gpt-4o" + cfg.Providers.OpenAI.AuthMethod = "oauth" + }, + wantType: providerTypeCodexAuth, + }, + { + name: "openai codex-cli auth routes to codex cli token provider", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "gpt-4o" + cfg.Providers.OpenAI.AuthMethod = "codex-cli" + }, + wantType: providerTypeCodexCLIToken, + }, + { + name: "explicit codex-code provider routes to codex cli provider type", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "codex-code" + cfg.Agents.Defaults.Workspace = "/tmp/ws" + }, + wantType: providerTypeCodexCLI, + }, + { + name: "zhipu model uses zhipu base default", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "glm-4.7" + cfg.Providers.Zhipu.APIKey = "zhipu-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://open.bigmodel.cn/api/paas/v4", + }, + { + name: "groq model uses groq base default", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "groq/llama-3.3-70b" + cfg.Providers.Groq.APIKey = "gsk-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.groq.com/openai/v1", + }, + { + name: "ollama model uses ollama base default", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "ollama/qwen2.5:14b" + cfg.Providers.Ollama.APIKey = "ollama-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "http://localhost:11434/v1", + }, + { + name: "moonshot model keeps proxy and default base", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "moonshot/kimi-k2.5" + cfg.Providers.Moonshot.APIKey = "moonshot-key" + cfg.Providers.Moonshot.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.moonshot.cn/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "missing keys returns model 
config error", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "custom-model" + }, + wantErrSubstr: "no API key configured for model", + }, + { + name: "openrouter prefix without key returns provider key error", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "openrouter/auto" + }, + wantErrSubstr: "no API key configured for provider", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := config.DefaultConfig() + tt.setup(cfg) + + got, err := resolveProviderSelection(cfg) + if tt.wantErrSubstr != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr) + } + if !strings.Contains(err.Error(), tt.wantErrSubstr) { + t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErrSubstr) + } + return + } + + if err != nil { + t.Fatalf("resolveProviderSelection() error = %v", err) + } + if got.providerType != tt.wantType { + t.Fatalf("providerType = %v, want %v", got.providerType, tt.wantType) + } + if tt.wantAPIBase != "" && got.apiBase != tt.wantAPIBase { + t.Fatalf("apiBase = %q, want %q", got.apiBase, tt.wantAPIBase) + } + if tt.wantProxy != "" && got.proxy != tt.wantProxy { + t.Fatalf("proxy = %q, want %q", got.proxy, tt.wantProxy) + } + }) + } +} + +func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Model = "test-openrouter" + cfg.ModelList = []config.ModelConfig{ + { + ModelName: "test-openrouter", + Model: "openrouter/auto", + APIKey: "sk-or-test", + APIBase: "https://openrouter.ai/api/v1", + }, + } + + provider, _, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*HTTPProvider); !ok { + t.Fatalf("provider type = %T, want *HTTPProvider", provider) + } +} + +func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Model = "test-codex" + 
cfg.ModelList = []config.ModelConfig{ + { + ModelName: "test-codex", + Model: "codex-cli/codex-model", + Workspace: "/tmp/workspace", + }, + } + + provider, _, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*CodexCliProvider); !ok { + t.Fatalf("provider type = %T, want *CodexCliProvider", provider) + } +} + +func TestCreateProviderReturnsClaudeCliProviderForClaudeCli(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Model = "test-claude-cli" + cfg.ModelList = []config.ModelConfig{ + { + ModelName: "test-claude-cli", + Model: "claude-cli/claude-sonnet", + Workspace: "/tmp/workspace", + }, + } + + provider, _, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*ClaudeCliProvider); !ok { + t.Fatalf("provider type = %T, want *ClaudeCliProvider", provider) + } +} + +func TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) { + originalGetCredential := getCredential + t.Cleanup(func() { getCredential = originalGetCredential }) + + getCredential = func(provider string) (*auth.AuthCredential, error) { + if provider != "anthropic" { + t.Fatalf("provider = %q, want anthropic", provider) + } + return &auth.AuthCredential{ + AccessToken: "anthropic-token", + }, nil + } + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Model = "test-claude-oauth" + cfg.ModelList = []config.ModelConfig{ + { + ModelName: "test-claude-oauth", + Model: "anthropic/claude-sonnet-4.6", + AuthMethod: "oauth", + }, + } + + provider, _, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*ClaudeProvider); !ok { + t.Fatalf("provider type = %T, want *ClaudeProvider", provider) + } + // TODO: Test custom APIBase when createClaudeAuthProvider supports it +} + +func TestCreateProviderReturnsCodexProviderForOpenAIOAuth(t *testing.T) { + // TODO: This test 
requires openai protocol to support auth_method: "oauth" + // which is not yet implemented in the new factory_provider.go + t.Skip("OpenAI OAuth via model_list not yet implemented") +} diff --git a/pkg/providers/fallback.go b/pkg/providers/fallback.go new file mode 100644 index 000000000..ecd451ec9 --- /dev/null +++ b/pkg/providers/fallback.go @@ -0,0 +1,287 @@ +package providers + +import ( + "context" + "fmt" + "strings" + "time" +) + +// FallbackChain orchestrates model fallback across multiple candidates. +type FallbackChain struct { + cooldown *CooldownTracker +} + +// FallbackCandidate represents one model/provider to try. +type FallbackCandidate struct { + Provider string + Model string +} + +// FallbackResult contains the successful response and metadata about all attempts. +type FallbackResult struct { + Response *LLMResponse + Provider string + Model string + Attempts []FallbackAttempt +} + +// FallbackAttempt records one attempt in the fallback chain. +type FallbackAttempt struct { + Provider string + Model string + Error error + Reason FailoverReason + Duration time.Duration + Skipped bool // true if skipped due to cooldown +} + +// NewFallbackChain creates a new fallback chain with the given cooldown tracker. +func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain { + return &FallbackChain{cooldown: cooldown} +} + +// ResolveCandidates parses model config into a deduplicated candidate list. +func ResolveCandidates(cfg ModelConfig, defaultProvider string) []FallbackCandidate { + seen := make(map[string]bool) + var candidates []FallbackCandidate + + addCandidate := func(raw string) { + ref := ParseModelRef(raw, defaultProvider) + if ref == nil { + return + } + key := ModelKey(ref.Provider, ref.Model) + if seen[key] { + return + } + seen[key] = true + candidates = append(candidates, FallbackCandidate{ + Provider: ref.Provider, + Model: ref.Model, + }) + } + + // Primary first. + addCandidate(cfg.Primary) + + // Then fallbacks. 
+ for _, fb := range cfg.Fallbacks { + addCandidate(fb) + } + + return candidates +} + +// Execute runs the fallback chain for text/chat requests. +// It tries each candidate in order, respecting cooldowns and error classification. +// +// Behavior: +// - Candidates in cooldown are skipped (logged as skipped attempt). +// - context.Canceled aborts immediately (user abort, no fallback). +// - Non-retriable errors (format) abort immediately. +// - Retriable errors trigger fallback to next candidate. +// - Success marks provider as good (resets cooldown). +// - If all fail, returns aggregate error with all attempts. +func (fc *FallbackChain) Execute( + ctx context.Context, + candidates []FallbackCandidate, + run func(ctx context.Context, provider, model string) (*LLMResponse, error), +) (*FallbackResult, error) { + if len(candidates) == 0 { + return nil, fmt.Errorf("fallback: no candidates configured") + } + + result := &FallbackResult{ + Attempts: make([]FallbackAttempt, 0, len(candidates)), + } + + for i, candidate := range candidates { + // Check context before each attempt. + if ctx.Err() == context.Canceled { + return nil, context.Canceled + } + + // Check cooldown. + if !fc.cooldown.IsAvailable(candidate.Provider) { + remaining := fc.cooldown.CooldownRemaining(candidate.Provider) + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Skipped: true, + Reason: FailoverRateLimit, + Error: fmt.Errorf( + "provider %s in cooldown (%s remaining)", + candidate.Provider, + remaining.Round(time.Second), + ), + }) + continue + } + + // Execute the run function. + start := time.Now() + resp, err := run(ctx, candidate.Provider, candidate.Model) + elapsed := time.Since(start) + + if err == nil { + // Success. 
+ fc.cooldown.MarkSuccess(candidate.Provider) + result.Response = resp + result.Provider = candidate.Provider + result.Model = candidate.Model + return result, nil + } + + // Context cancellation: abort immediately, no fallback. + if ctx.Err() == context.Canceled { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + return nil, context.Canceled + } + + // Classify the error. + failErr := ClassifyError(err, candidate.Provider, candidate.Model) + + if failErr == nil { + // Unclassifiable error: do not fallback, return immediately. + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + return nil, fmt.Errorf("fallback: unclassified error from %s/%s: %w", + candidate.Provider, candidate.Model, err) + } + + // Non-retriable error: abort immediately. + if !failErr.IsRetriable() { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: failErr, + Reason: failErr.Reason, + Duration: elapsed, + }) + return nil, failErr + } + + // Retriable error: mark failure and continue to next candidate. + fc.cooldown.MarkFailure(candidate.Provider, failErr.Reason) + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: failErr, + Reason: failErr.Reason, + Duration: elapsed, + }) + + // If this was the last candidate, return aggregate error. + if i == len(candidates)-1 { + return nil, &FallbackExhaustedError{Attempts: result.Attempts} + } + } + + // All candidates were skipped (all in cooldown). + return nil, &FallbackExhaustedError{Attempts: result.Attempts} +} + +// ExecuteImage runs the fallback chain for image/vision requests. +// Simpler than Execute: no cooldown checks (image endpoints have different rate limits). 
+// Image dimension/size errors abort immediately (non-retriable). +func (fc *FallbackChain) ExecuteImage( + ctx context.Context, + candidates []FallbackCandidate, + run func(ctx context.Context, provider, model string) (*LLMResponse, error), +) (*FallbackResult, error) { + if len(candidates) == 0 { + return nil, fmt.Errorf("image fallback: no candidates configured") + } + + result := &FallbackResult{ + Attempts: make([]FallbackAttempt, 0, len(candidates)), + } + + for i, candidate := range candidates { + if ctx.Err() == context.Canceled { + return nil, context.Canceled + } + + start := time.Now() + resp, err := run(ctx, candidate.Provider, candidate.Model) + elapsed := time.Since(start) + + if err == nil { + result.Response = resp + result.Provider = candidate.Provider + result.Model = candidate.Model + return result, nil + } + + if ctx.Err() == context.Canceled { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + return nil, context.Canceled + } + + // Image dimension/size errors are non-retriable. + errMsg := strings.ToLower(err.Error()) + if IsImageDimensionError(errMsg) || IsImageSizeError(errMsg) { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Reason: FailoverFormat, + Duration: elapsed, + }) + return nil, &FailoverError{ + Reason: FailoverFormat, + Provider: candidate.Provider, + Model: candidate.Model, + Wrapped: err, + } + } + + // Any other error: record and try next. 
+ result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + + if i == len(candidates)-1 { + return nil, &FallbackExhaustedError{Attempts: result.Attempts} + } + } + + return nil, &FallbackExhaustedError{Attempts: result.Attempts} +} + +// FallbackExhaustedError indicates all fallback candidates were tried and failed. +type FallbackExhaustedError struct { + Attempts []FallbackAttempt +} + +func (e *FallbackExhaustedError) Error() string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("fallback: all %d candidates failed:", len(e.Attempts))) + for i, a := range e.Attempts { + if a.Skipped { + sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: skipped (cooldown)", i+1, a.Provider, a.Model)) + } else { + sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: %v (reason=%s, %s)", + i+1, a.Provider, a.Model, a.Error, a.Reason, a.Duration.Round(time.Millisecond))) + } + } + return sb.String() +} diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go new file mode 100644 index 000000000..e872c672e --- /dev/null +++ b/pkg/providers/fallback_test.go @@ -0,0 +1,479 @@ +package providers + +import ( + "context" + "errors" + "testing" + "time" +) + +func makeCandidate(provider, model string) FallbackCandidate { + return FallbackCandidate{Provider: provider, Model: model} +} + +func successRun(content string) func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return &LLMResponse{Content: content, FinishReason: "stop"}, nil + } +} + +func failRun(err error) func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return nil, err + } +} + +func TestFallback_SingleCandidate_Success(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + 
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + result, err := fc.Execute(context.Background(), candidates, successRun("hello")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Response.Content != "hello" { + t.Errorf("content = %q, want hello", result.Response.Content) + } + if result.Provider != "openai" || result.Model != "gpt-4" { + t.Errorf("provider/model = %s/%s, want openai/gpt-4", result.Provider, result.Model) + } +} + +func TestFallback_SecondCandidateSuccess(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude-opus"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + return nil, errors.New("rate limit exceeded") + } + return &LLMResponse{Content: "from claude", FinishReason: "stop"}, nil + } + + result, err := fc.Execute(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", result.Provider) + } + if result.Response.Content != "from claude" { + t.Errorf("content = %q, want 'from claude'", result.Response.Content) + } + if len(result.Attempts) != 1 { + t.Errorf("attempts = %d, want 1 (failed attempt recorded)", len(result.Attempts)) + } +} + +func TestFallback_AllFail(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + makeCandidate("groq", "llama"), + } + + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return nil, errors.New("rate limit exceeded") + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error when all candidates fail") + } + 
var exhausted *FallbackExhaustedError + if !errors.As(err, &exhausted) { + t.Errorf("expected FallbackExhaustedError, got %T: %v", err, err) + } + if len(exhausted.Attempts) != 3 { + t.Errorf("attempts = %d, want 3", len(exhausted.Attempts)) + } +} + +func TestFallback_ContextCanceled(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + ctx, cancel := context.WithCancel(context.Background()) + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + cancel() // cancel context + return nil, context.Canceled + } + t.Error("should not reach second candidate after cancel") + return nil, nil + } + + _, err := fc.Execute(ctx, candidates, run) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %v", err) + } +} + +func TestFallback_NonRetriableError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("string should match pattern") + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for non-retriable") + } + var fe *FailoverError + if !errors.As(err, &fe) { + t.Fatalf("expected FailoverError, got %T", err) + } + if fe.Reason != FailoverFormat { + t.Errorf("reason = %q, want format", fe.Reason) + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (non-retriable should not try next)", attempt) + } +} + +func TestFallback_CooldownSkip(t *testing.T) { + now := time.Now() + ct, _ := newTestTracker(now) + fc := NewFallbackChain(ct) + + // Put openai in cooldown + ct.MarkFailure("openai", FailoverRateLimit) + 
+ candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + if provider == "openai" { + t.Error("should not call openai (in cooldown)") + } + return &LLMResponse{Content: "claude response", FinishReason: "stop"}, nil + } + + result, err := fc.Execute(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", result.Provider) + } + // Should have 1 skipped attempt + skipped := 0 + for _, a := range result.Attempts { + if a.Skipped { + skipped++ + } + } + if skipped != 1 { + t.Errorf("skipped = %d, want 1", skipped) + } +} + +func TestFallback_AllInCooldown(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + // Put all providers in cooldown + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("anthropic", FailoverBilling) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + _, err := fc.Execute(context.Background(), candidates, + func(ctx context.Context, provider, model string) (*LLMResponse, error) { + t.Error("should not call any provider (all in cooldown)") + return nil, nil + }) + + if err == nil { + t.Fatal("expected error when all in cooldown") + } + var exhausted *FallbackExhaustedError + if !errors.As(err, &exhausted) { + t.Fatalf("expected FallbackExhaustedError, got %T", err) + } +} + +func TestFallback_NoCandidates(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + _, err := fc.Execute(context.Background(), nil, successRun("ok")) + if err == nil { + t.Error("expected error for empty candidates") + } +} + +func TestFallback_EmptyFallbacks(t *testing.T) { + // Single primary, no fallbacks: should work like direct call + ct := NewCooldownTracker() + fc := 
NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + result, err := fc.Execute(context.Background(), candidates, successRun("ok")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Response.Content != "ok" { + t.Error("expected success with single candidate") + } +} + +func TestFallback_UnclassifiedError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("completely unknown internal error") + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for unclassified error") + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (should not fallback on unclassified)", attempt) + } +} + +func TestFallback_SuccessResetsCooldown(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + ct.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere + } + return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ct.IsAvailable("openai") { + t.Error("success should reset cooldown") + } +} + +// --- Image Fallback Tests --- + +func TestImageFallback_Success(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4o")} + result, err := fc.ExecuteImage(context.Background(), candidates, successRun("image result")) + if err != 
nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Response.Content != "image result" { + t.Error("expected image result") + } +} + +func TestImageFallback_DimensionError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4o"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("image dimensions exceed max 4096x4096") + } + + _, err := fc.ExecuteImage(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for image dimension error") + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (image dimension error should not retry)", attempt) + } +} + +func TestImageFallback_SizeError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4o"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("image exceeds 20 mb") + } + + _, err := fc.ExecuteImage(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for image size error") + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (image size error should not retry)", attempt) + } +} + +func TestImageFallback_RetryOnOtherErrors(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4o"), + makeCandidate("anthropic", "claude-sonnet"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + return nil, errors.New("rate limit exceeded") + } + return &LLMResponse{Content: "image ok", FinishReason: "stop"}, nil + } + + result, err := 
fc.ExecuteImage(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", result.Provider) + } +} + +func TestImageFallback_NoCandidates(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + _, err := fc.ExecuteImage(context.Background(), nil, successRun("ok")) + if err == nil { + t.Error("expected error for empty candidates") + } +} + +// --- ResolveCandidates Tests --- + +func TestResolveCandidates_Simple(t *testing.T) { + cfg := ModelConfig{ + Primary: "gpt-4", + Fallbacks: []string{"anthropic/claude-opus", "groq/llama-3"}, + } + + candidates := ResolveCandidates(cfg, "openai") + if len(candidates) != 3 { + t.Fatalf("candidates = %d, want 3", len(candidates)) + } + + if candidates[0].Provider != "openai" || candidates[0].Model != "gpt-4" { + t.Errorf("candidate[0] = %s/%s, want openai/gpt-4", candidates[0].Provider, candidates[0].Model) + } + if candidates[1].Provider != "anthropic" || candidates[1].Model != "claude-opus" { + t.Errorf("candidate[1] = %s/%s, want anthropic/claude-opus", candidates[1].Provider, candidates[1].Model) + } + if candidates[2].Provider != "groq" || candidates[2].Model != "llama-3" { + t.Errorf("candidate[2] = %s/%s, want groq/llama-3", candidates[2].Provider, candidates[2].Model) + } +} + +func TestResolveCandidates_Deduplication(t *testing.T) { + cfg := ModelConfig{ + Primary: "openai/gpt-4", + Fallbacks: []string{"openai/gpt-4", "anthropic/claude"}, + } + + candidates := ResolveCandidates(cfg, "default") + if len(candidates) != 2 { + t.Errorf("candidates = %d, want 2 (duplicate removed)", len(candidates)) + } +} + +func TestResolveCandidates_EmptyFallbacks(t *testing.T) { + cfg := ModelConfig{ + Primary: "gpt-4", + Fallbacks: nil, + } + + candidates := ResolveCandidates(cfg, "openai") + if len(candidates) != 1 { + t.Errorf("candidates = %d, want 1", len(candidates)) + } +} + +func 
TestResolveCandidates_EmptyPrimary(t *testing.T) { + cfg := ModelConfig{ + Primary: "", + Fallbacks: []string{"anthropic/claude"}, + } + + candidates := ResolveCandidates(cfg, "openai") + if len(candidates) != 1 { + t.Errorf("candidates = %d, want 1", len(candidates)) + } +} + +func TestFallbackExhaustedError_Message(t *testing.T) { + e := &FallbackExhaustedError{ + Attempts: []FallbackAttempt{ + { + Provider: "openai", + Model: "gpt-4", + Error: errors.New("rate limited"), + Reason: FailoverRateLimit, + Duration: 500 * time.Millisecond, + }, + {Provider: "anthropic", Model: "claude", Skipped: true}, + }, + } + msg := e.Error() + if msg == "" { + t.Error("expected non-empty error message") + } +} diff --git a/pkg/providers/github_copilot_provider.go b/pkg/providers/github_copilot_provider.go new file mode 100644 index 000000000..6124881f7 --- /dev/null +++ b/pkg/providers/github_copilot_provider.go @@ -0,0 +1,82 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + + copilot "github.com/github/copilot-sdk/go" +) + +type GitHubCopilotProvider struct { + uri string + connectMode string // `stdio` or `grpc`` + + session *copilot.Session +} + +func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*GitHubCopilotProvider, error) { + var session *copilot.Session + if connectMode == "" { + connectMode = "grpc" + } + switch connectMode { + + case "stdio": + // todo + case "grpc": + client := copilot.NewClient(&copilot.ClientOptions{ + CLIUrl: uri, + }) + if err := client.Start(context.Background()); err != nil { + return nil, fmt.Errorf( + "Can't connect to Github Copilot, https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server for details", + ) + } + defer client.Stop() + session, _ = client.CreateSession(context.Background(), &copilot.SessionConfig{ + Model: model, + Hooks: &copilot.SessionHooks{}, + }) + + } + + return &GitHubCopilotProvider{ + uri: uri, + connectMode: 
connectMode, + session: session, + }, nil +} + +// Chat sends a chat request to GitHub Copilot +func (p *GitHubCopilotProvider) Chat( + ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any, +) (*LLMResponse, error) { + type tempMessage struct { + Role string `json:"role"` + Content string `json:"content"` + } + out := make([]tempMessage, 0, len(messages)) + + for _, msg := range messages { + out = append(out, tempMessage{ + Role: msg.Role, + Content: msg.Content, + }) + } + + fullcontent, _ := json.Marshal(out) + + content, _ := p.session.Send(ctx, copilot.MessageOptions{ + Prompt: string(fullcontent), + }) + + return &LLMResponse{ + FinishReason: "stop", + Content: content, + }, nil +} + +func (p *GitHubCopilotProvider) GetDefaultModel() string { + return "gpt-4.1" +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 7179c4cc5..d0c4344f3 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -7,398 +7,37 @@ package providers import ( - "bytes" "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "github.com/sipeed/picoclaw/pkg/auth" - "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers/openai_compat" ) type HTTPProvider struct { - apiKey string - apiBase string - httpClient *http.Client + delegate *openai_compat.Provider } func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider { - client := &http.Client{ - Timeout: 0, - } - - if proxy != "" { - proxyURL, err := url.Parse(proxy) - if err == nil { - client.Transport = &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - } - } - } - return &HTTPProvider{ - apiKey: apiKey, - apiBase: apiBase, - httpClient: client, + delegate: openai_compat.NewProvider(apiKey, apiBase, proxy), } } -func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) 
{ - if p.apiBase == "" { - return nil, fmt.Errorf("API base not configured") - } - - // Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5) - if idx := strings.Index(model, "/"); idx != -1 { - prefix := model[:idx] - if prefix == "moonshot" || prefix == "nvidia" { - model = model[idx+1:] - } - } - - requestBody := map[string]interface{}{ - "model": model, - "messages": messages, - } - - if len(tools) > 0 { - requestBody["tools"] = tools - requestBody["tool_choice"] = "auto" - } - - if maxTokens, ok := options["max_tokens"].(int); ok { - lowerModel := strings.ToLower(model) - if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") { - requestBody["max_completion_tokens"] = maxTokens - } else { - requestBody["max_tokens"] = maxTokens - } - } - - if temperature, ok := options["temperature"].(float64); ok { - lowerModel := strings.ToLower(model) - // Kimi k2 models only support temperature=1 - if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { - requestBody["temperature"] = 1.0 - } else { - requestBody["temperature"] = temperature - } - } - - jsonData, err := json.Marshal(requestBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - if p.apiKey != "" { - req.Header.Set("Authorization", "Bearer "+p.apiKey) - } - - resp, err := p.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API error: %s", string(body)) +func 
NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider { + return &HTTPProvider{ + delegate: openai_compat.NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField), } - - return p.parseResponse(body) } -func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) { - var apiResponse struct { - Choices []struct { - Message struct { - Content string `json:"content"` - ToolCalls []struct { - ID string `json:"id"` - Type string `json:"type"` - Function *struct { - Name string `json:"name"` - Arguments string `json:"arguments"` - } `json:"function"` - } `json:"tool_calls"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage *UsageInfo `json:"usage"` - } - - if err := json.Unmarshal(body, &apiResponse); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - if len(apiResponse.Choices) == 0 { - return &LLMResponse{ - Content: "", - FinishReason: "stop", - }, nil - } - - choice := apiResponse.Choices[0] - - toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) - for _, tc := range choice.Message.ToolCalls { - arguments := make(map[string]interface{}) - name := "" - - // Handle OpenAI format with nested function object - if tc.Type == "function" && tc.Function != nil { - name = tc.Function.Name - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { - arguments["raw"] = tc.Function.Arguments - } - } - } else if tc.Function != nil { - // Legacy format without type field - name = tc.Function.Name - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { - arguments["raw"] = tc.Function.Arguments - } - } - } - - toolCalls = append(toolCalls, ToolCall{ - ID: tc.ID, - Name: name, - Arguments: arguments, - }) - } - - return &LLMResponse{ - Content: choice.Message.Content, - ToolCalls: toolCalls, - FinishReason: 
choice.FinishReason, - Usage: apiResponse.Usage, - }, nil +func (p *HTTPProvider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (*LLMResponse, error) { + return p.delegate.Chat(ctx, messages, tools, model, options) } func (p *HTTPProvider) GetDefaultModel() string { return "" } - -func createClaudeAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("anthropic") - if err != nil { - return nil, fmt.Errorf("loading auth credentials: %w", err) - } - if cred == nil { - return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") - } - return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil -} - -func createCodexAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("openai") - if err != nil { - return nil, fmt.Errorf("loading auth credentials: %w", err) - } - if cred == nil { - return nil, fmt.Errorf("no credentials for openai. 
Run: picoclaw auth login --provider openai") - } - return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil -} - -func CreateProvider(cfg *config.Config) (LLMProvider, error) { - model := cfg.Agents.Defaults.Model - providerName := strings.ToLower(cfg.Agents.Defaults.Provider) - - var apiKey, apiBase, proxy string - - lowerModel := strings.ToLower(model) - - // First, try to use explicitly configured provider - if providerName != "" { - switch providerName { - case "groq": - if cfg.Providers.Groq.APIKey != "" { - apiKey = cfg.Providers.Groq.APIKey - apiBase = cfg.Providers.Groq.APIBase - if apiBase == "" { - apiBase = "https://api.groq.com/openai/v1" - } - } - case "openai", "gpt": - if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { - if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - return createCodexAuthProvider() - } - apiKey = cfg.Providers.OpenAI.APIKey - apiBase = cfg.Providers.OpenAI.APIBase - if apiBase == "" { - apiBase = "https://api.openai.com/v1" - } - } - case "anthropic", "claude": - if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - return createClaudeAuthProvider() - } - apiKey = cfg.Providers.Anthropic.APIKey - apiBase = cfg.Providers.Anthropic.APIBase - if apiBase == "" { - apiBase = "https://api.anthropic.com/v1" - } - } - case "openrouter": - if cfg.Providers.OpenRouter.APIKey != "" { - apiKey = cfg.Providers.OpenRouter.APIKey - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - } - case "zhipu", "glm": - if cfg.Providers.Zhipu.APIKey != "" { - apiKey = cfg.Providers.Zhipu.APIKey - apiBase = cfg.Providers.Zhipu.APIBase - if apiBase == "" { - apiBase = "https://open.bigmodel.cn/api/paas/v4" - 
} - } - case "gemini", "google": - if cfg.Providers.Gemini.APIKey != "" { - apiKey = cfg.Providers.Gemini.APIKey - apiBase = cfg.Providers.Gemini.APIBase - if apiBase == "" { - apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - } - case "vllm": - if cfg.Providers.VLLM.APIBase != "" { - apiKey = cfg.Providers.VLLM.APIKey - apiBase = cfg.Providers.VLLM.APIBase - } - case "claude-cli", "claudecode", "claude-code": - workspace := cfg.Agents.Defaults.Workspace - if workspace == "" { - workspace = "." - } - return NewClaudeCliProvider(workspace), nil - } - } - - // Fallback: detect provider from model name - if apiKey == "" && apiBase == "" { - switch { - case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": - apiKey = cfg.Providers.Moonshot.APIKey - apiBase = cfg.Providers.Moonshot.APIBase - proxy = cfg.Providers.Moonshot.Proxy - if apiBase == "" { - apiBase = "https://api.moonshot.cn/v1" - } - - case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): - apiKey = cfg.Providers.OpenRouter.APIKey - proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - - case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - return createClaudeAuthProvider() - } - apiKey = cfg.Providers.Anthropic.APIKey - apiBase = cfg.Providers.Anthropic.APIBase - proxy = cfg.Providers.Anthropic.Proxy - if apiBase == "" { - apiBase = 
"https://api.anthropic.com/v1" - } - - case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): - if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - return createCodexAuthProvider() - } - apiKey = cfg.Providers.OpenAI.APIKey - apiBase = cfg.Providers.OpenAI.APIBase - proxy = cfg.Providers.OpenAI.Proxy - if apiBase == "" { - apiBase = "https://api.openai.com/v1" - } - - case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": - apiKey = cfg.Providers.Gemini.APIKey - apiBase = cfg.Providers.Gemini.APIBase - proxy = cfg.Providers.Gemini.Proxy - if apiBase == "" { - apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - - case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": - apiKey = cfg.Providers.Zhipu.APIKey - apiBase = cfg.Providers.Zhipu.APIBase - proxy = cfg.Providers.Zhipu.Proxy - if apiBase == "" { - apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - - case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": - apiKey = cfg.Providers.Groq.APIKey - apiBase = cfg.Providers.Groq.APIBase - proxy = cfg.Providers.Groq.Proxy - if apiBase == "" { - apiBase = "https://api.groq.com/openai/v1" - } - - case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": - apiKey = cfg.Providers.Nvidia.APIKey - apiBase = cfg.Providers.Nvidia.APIBase - proxy = cfg.Providers.Nvidia.Proxy - if apiBase == "" { - apiBase = "https://integrate.api.nvidia.com/v1" - } - - case cfg.Providers.VLLM.APIBase != "": - apiKey = cfg.Providers.VLLM.APIKey - apiBase = cfg.Providers.VLLM.APIBase - proxy = cfg.Providers.VLLM.Proxy - - default: - if 
cfg.Providers.OpenRouter.APIKey != "" { - apiKey = cfg.Providers.OpenRouter.APIKey - proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - } else { - return nil, fmt.Errorf("no API key configured for model: %s", model) - } - } - } - - if apiKey == "" && !strings.HasPrefix(model, "bedrock/") { - return nil, fmt.Errorf("no API key configured for provider (model: %s)", model) - } - - if apiBase == "" { - return nil, fmt.Errorf("no API base configured for provider (model: %s)", model) - } - - return NewHTTPProvider(apiKey, apiBase, proxy), nil -} diff --git a/pkg/providers/legacy_provider.go b/pkg/providers/legacy_provider.go new file mode 100644 index 000000000..eb13cec65 --- /dev/null +++ b/pkg/providers/legacy_provider.go @@ -0,0 +1,49 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package providers + +import ( + "fmt" + + "github.com/sipeed/picoclaw/pkg/config" +) + +// CreateProvider creates a provider based on the configuration. +// It uses the model_list configuration (new format) to create providers. +// The old providers config is automatically converted to model_list during config loading. +// Returns the provider, the model ID to use, and any error. +func CreateProvider(cfg *config.Config) (LLMProvider, string, error) { + model := cfg.Agents.Defaults.Model + + // Ensure model_list is populated (should be done by LoadConfig, but handle edge cases) + if len(cfg.ModelList) == 0 && cfg.HasProvidersConfig() { + cfg.ModelList = config.ConvertProvidersToModelList(cfg) + } + + // Must have model_list at this point + if len(cfg.ModelList) == 0 { + return nil, "", fmt.Errorf("no providers configured. 
Please add entries to model_list in your config") + } + + // Get model config from model_list + modelCfg, err := cfg.GetModelConfig(model) + if err != nil { + return nil, "", fmt.Errorf("model %q not found in model_list: %w", model, err) + } + + // Inject global workspace if not set in model config + if modelCfg.Workspace == "" { + modelCfg.Workspace = cfg.WorkspacePath() + } + + // Use factory to create provider + provider, modelID, err := CreateProviderFromConfig(modelCfg) + if err != nil { + return nil, "", fmt.Errorf("failed to create provider for model %q: %w", model, err) + } + + return provider, modelID, nil +} diff --git a/pkg/providers/model_ref.go b/pkg/providers/model_ref.go new file mode 100644 index 000000000..0d1b02d16 --- /dev/null +++ b/pkg/providers/model_ref.go @@ -0,0 +1,64 @@ +package providers + +import "strings" + +// ModelRef represents a parsed model reference with provider and model name. +type ModelRef struct { + Provider string + Model string +} + +// ParseModelRef parses "anthropic/claude-opus" into {Provider: "anthropic", Model: "claude-opus"}. +// If no slash present, uses defaultProvider. +// Returns nil for empty input. +func ParseModelRef(raw string, defaultProvider string) *ModelRef { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + + if idx := strings.Index(raw, "/"); idx > 0 { + provider := NormalizeProvider(raw[:idx]) + model := strings.TrimSpace(raw[idx+1:]) + if model == "" { + return nil + } + return &ModelRef{Provider: provider, Model: model} + } + + return &ModelRef{ + Provider: NormalizeProvider(defaultProvider), + Model: raw, + } +} + +// NormalizeProvider normalizes provider identifiers to canonical form. 
+func NormalizeProvider(provider string) string { + p := strings.ToLower(strings.TrimSpace(provider)) + + switch p { + case "z.ai", "z-ai": + return "zai" + case "opencode-zen": + return "opencode" + case "qwen": + return "qwen-portal" + case "kimi-code": + return "kimi-coding" + case "gpt": + return "openai" + case "claude": + return "anthropic" + case "glm": + return "zhipu" + case "google": + return "gemini" + } + + return p +} + +// ModelKey returns a canonical "provider/model" key for deduplication. +func ModelKey(provider, model string) string { + return NormalizeProvider(provider) + "/" + strings.ToLower(strings.TrimSpace(model)) +} diff --git a/pkg/providers/model_ref_test.go b/pkg/providers/model_ref_test.go new file mode 100644 index 000000000..6dd25167f --- /dev/null +++ b/pkg/providers/model_ref_test.go @@ -0,0 +1,125 @@ +package providers + +import "testing" + +func TestParseModelRef_WithSlash(t *testing.T) { + ref := ParseModelRef("anthropic/claude-opus", "openai") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", ref.Provider) + } + if ref.Model != "claude-opus" { + t.Errorf("model = %q, want claude-opus", ref.Model) + } +} + +func TestParseModelRef_WithoutSlash(t *testing.T) { + ref := ParseModelRef("gpt-4", "openai") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "openai" { + t.Errorf("provider = %q, want openai", ref.Provider) + } + if ref.Model != "gpt-4" { + t.Errorf("model = %q, want gpt-4", ref.Model) + } +} + +func TestParseModelRef_Empty(t *testing.T) { + ref := ParseModelRef("", "openai") + if ref != nil { + t.Errorf("expected nil for empty string, got %+v", ref) + } +} + +func TestParseModelRef_EmptyModelAfterSlash(t *testing.T) { + ref := ParseModelRef("openai/", "default") + if ref != nil { + t.Errorf("expected nil for empty model, got %+v", ref) + } +} + +func TestParseModelRef_WhitespaceHandling(t *testing.T) { + ref := 
ParseModelRef(" anthropic / claude-opus ", "openai") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", ref.Provider) + } + if ref.Model != "claude-opus" { + t.Errorf("model = %q, want claude-opus", ref.Model) + } +} + +func TestNormalizeProvider(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"OpenAI", "openai"}, + {"ANTHROPIC", "anthropic"}, + {"z.ai", "zai"}, + {"z-ai", "zai"}, + {"Z.AI", "zai"}, + {"opencode-zen", "opencode"}, + {"qwen", "qwen-portal"}, + {"kimi-code", "kimi-coding"}, + {"gpt", "openai"}, + {"claude", "anthropic"}, + {"glm", "zhipu"}, + {"google", "gemini"}, + {"groq", "groq"}, + {"", ""}, + } + + for _, tt := range tests { + got := NormalizeProvider(tt.input) + if got != tt.want { + t.Errorf("NormalizeProvider(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestModelKey(t *testing.T) { + tests := []struct { + provider string + model string + want string + }{ + {"openai", "gpt-4", "openai/gpt-4"}, + {"Anthropic", "Claude-Opus", "anthropic/claude-opus"}, + {"claude", "sonnet", "anthropic/sonnet"}, + {"z.ai", "Model-X", "zai/model-x"}, + } + + for _, tt := range tests { + got := ModelKey(tt.provider, tt.model) + if got != tt.want { + t.Errorf("ModelKey(%q, %q) = %q, want %q", tt.provider, tt.model, got, tt.want) + } + } +} + +func TestParseModelRef_ProviderNormalization(t *testing.T) { + ref := ParseModelRef("Z.AI/model-x", "default") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "zai" { + t.Errorf("provider = %q, want zai", ref.Provider) + } +} + +func TestParseModelRef_DefaultProviderNormalization(t *testing.T) { + ref := ParseModelRef("gpt-4o", "GPT") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "openai" { + t.Errorf("provider = %q, want openai (normalized from GPT)", ref.Provider) + } +} diff --git a/pkg/providers/openai_compat/provider.go 
b/pkg/providers/openai_compat/provider.go new file mode 100644 index 000000000..b8528953a --- /dev/null +++ b/pkg/providers/openai_compat/provider.go @@ -0,0 +1,278 @@ +package openai_compat + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +type ( + ToolCall = protocoltypes.ToolCall + FunctionCall = protocoltypes.FunctionCall + LLMResponse = protocoltypes.LLMResponse + UsageInfo = protocoltypes.UsageInfo + Message = protocoltypes.Message + ToolDefinition = protocoltypes.ToolDefinition + ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition + ExtraContent = protocoltypes.ExtraContent + GoogleExtra = protocoltypes.GoogleExtra +) + +type Provider struct { + apiKey string + apiBase string + maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models) + httpClient *http.Client +} + +func NewProvider(apiKey, apiBase, proxy string) *Provider { + return NewProviderWithMaxTokensField(apiKey, apiBase, proxy, "") +} + +func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider { + client := &http.Client{ + Timeout: 120 * time.Second, + } + + if proxy != "" { + parsed, err := url.Parse(proxy) + if err == nil { + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(parsed), + } + } else { + log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err) + } + } + + return &Provider{ + apiKey: apiKey, + apiBase: strings.TrimRight(apiBase, "/"), + maxTokensField: maxTokensField, + httpClient: client, + } +} + +func (p *Provider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, +) (*LLMResponse, error) { + if p.apiBase == "" { + return nil, fmt.Errorf("API base not configured") + } + + model = normalizeModel(model, p.apiBase) + + requestBody := map[string]any{ + "model": model, + 
"messages": messages, + } + + if len(tools) > 0 { + requestBody["tools"] = tools + requestBody["tool_choice"] = "auto" + } + + if maxTokens, ok := asInt(options["max_tokens"]); ok { + // Use configured maxTokensField if specified, otherwise fallback to model-based detection + fieldName := p.maxTokensField + if fieldName == "" { + // Fallback: detect from model name for backward compatibility + lowerModel := strings.ToLower(model) + if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") || + strings.Contains(lowerModel, "gpt-5") { + fieldName = "max_completion_tokens" + } else { + fieldName = "max_tokens" + } + } + requestBody[fieldName] = maxTokens + } + + if temperature, ok := asFloat(options["temperature"]); ok { + lowerModel := strings.ToLower(model) + // Kimi k2 models only support temperature=1. + if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { + requestBody["temperature"] = 1.0 + } else { + requestBody["temperature"] = temperature + } + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if p.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+p.apiKey) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body)) + } + + return parseResponse(body) +} + +func parseResponse(body []byte) (*LLMResponse, error) { + var 
apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + ToolCalls []struct { + ID string `json:"id"` + Type string `json:"type"` + Function *struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + ExtraContent *struct { + Google *struct { + ThoughtSignature string `json:"thought_signature"` + } `json:"google"` + } `json:"extra_content"` + } `json:"tool_calls"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage *UsageInfo `json:"usage"` + } + + if err := json.Unmarshal(body, &apiResponse); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + if len(apiResponse.Choices) == 0 { + return &LLMResponse{ + Content: "", + FinishReason: "stop", + }, nil + } + + choice := apiResponse.Choices[0] + toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) + for _, tc := range choice.Message.ToolCalls { + arguments := make(map[string]any) + name := "" + + // Extract thought_signature from Gemini/Google-specific extra content + thoughtSignature := "" + if tc.ExtraContent != nil && tc.ExtraContent.Google != nil { + thoughtSignature = tc.ExtraContent.Google.ThoughtSignature + } + + if tc.Function != nil { + name = tc.Function.Name + if tc.Function.Arguments != "" { + if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { + log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err) + arguments["raw"] = tc.Function.Arguments + } + } + } + + // Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence + toolCall := ToolCall{ + ID: tc.ID, + Name: name, + Arguments: arguments, + ThoughtSignature: thoughtSignature, + } + + if thoughtSignature != "" { + toolCall.ExtraContent = &ExtraContent{ + Google: &GoogleExtra{ + ThoughtSignature: thoughtSignature, + }, + } + } + + toolCalls = append(toolCalls, toolCall) + } + + return &LLMResponse{ + Content: 
choice.Message.Content, + ToolCalls: toolCalls, + FinishReason: choice.FinishReason, + Usage: apiResponse.Usage, + }, nil +} + +func normalizeModel(model, apiBase string) string { + idx := strings.Index(model, "/") + if idx == -1 { + return model + } + + if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") { + return model + } + + prefix := strings.ToLower(model[:idx]) + switch prefix { + case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu": + return model[idx+1:] + default: + return model + } +} + +func asInt(v any) (int, bool) { + switch val := v.(type) { + case int: + return val, true + case int64: + return int(val), true + case float64: + return int(val), true + case float32: + return int(val), true + default: + return 0, false + } +} + +func asFloat(v any) (float64, bool) { + switch val := v.(type) { + case float64: + return val, true + case float32: + return float64(val), true + case int: + return float64(val), true + case int64: + return float64(val), true + default: + return 0, false + } +} diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go new file mode 100644 index 000000000..42f9d42ab --- /dev/null +++ b/pkg/providers/openai_compat/provider_test.go @@ -0,0 +1,283 @@ +package openai_compat + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/completions" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } 
+ w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "glm-4.7", + map[string]any{"max_tokens": 1234}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if _, ok := requestBody["max_completion_tokens"]; !ok { + t.Fatalf("expected max_completion_tokens in request body") + } + if _, ok := requestBody["max_tokens"]; ok { + t.Fatalf("did not expect max_tokens key for glm model") + } +} + +func TestProviderChat_ParsesToolCalls(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{ + "content": "", + "tool_calls": []map[string]any{ + { + "id": "call_1", + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "arguments": "{\"city\":\"SF\"}", + }, + }, + }, + }, + "finish_reason": "tool_calls", + }, + }, + "usage": map[string]any{ + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].Name != "get_weather" { + t.Fatalf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather") + } + if out.ToolCalls[0].Arguments["city"] != "SF" { + t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"]) + } +} + +func TestProviderChat_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + http.Error(w, "bad request", http.StatusBadRequest) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "moonshot/kimi-k2.5", + map[string]any{"temperature": 0.3}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if requestBody["model"] != "kimi-k2.5" { + t.Fatalf("model = %v, want kimi-k2.5", requestBody["model"]) + } + if requestBody["temperature"] != 1.0 { + t.Fatalf("temperature = %v, want 1.0", requestBody["temperature"]) + } +} + +func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) { + tests := []struct { + name string + input string + wantModel string + }{ + { + name: "strips groq prefix and keeps nested model", + input: "groq/openai/gpt-oss-120b", + wantModel: "openai/gpt-oss-120b", + }, + { + name: "strips ollama prefix", + input: "ollama/qwen2.5:14b", + wantModel: "qwen2.5:14b", + }, + { + name: "strips deepseek prefix", + input: "deepseek/deepseek-chat", + wantModel: "deepseek-chat", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var 
requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if requestBody["model"] != tt.wantModel { + t.Fatalf("model = %v, want %s", requestBody["model"], tt.wantModel) + } + }) + } +} + +func TestProvider_ProxyConfigured(t *testing.T) { + proxyURL := "http://127.0.0.1:8080" + p := NewProvider("key", "https://example.com", proxyURL) + + transport, ok := p.httpClient.Transport.(*http.Transport) + if !ok || transport == nil { + t.Fatalf("expected http transport with proxy, got %T", p.httpClient.Transport) + } + + req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}} + gotProxy, err := transport.Proxy(req) + if err != nil { + t.Fatalf("proxy function returned error: %v", err) + } + if gotProxy == nil || gotProxy.String() != proxyURL { + t.Fatalf("proxy = %v, want %s", gotProxy, proxyURL) + } +} + +func TestProviderChat_AcceptsNumericOptionTypes(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + 
w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "gpt-4o", + map[string]any{"max_tokens": float64(512), "temperature": 1}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if requestBody["max_tokens"] != float64(512) { + t.Fatalf("max_tokens = %v, want 512", requestBody["max_tokens"]) + } + if requestBody["temperature"] != float64(1) { + t.Fatalf("temperature = %v, want 1", requestBody["temperature"]) + } +} + +func TestNormalizeModel_UsesAPIBase(t *testing.T) { + if got := normalizeModel("deepseek/deepseek-chat", "https://api.deepseek.com/v1"); got != "deepseek-chat" { + t.Fatalf("normalizeModel(deepseek) = %q, want %q", got, "deepseek-chat") + } + if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" { + t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto") + } +} diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go new file mode 100644 index 000000000..3a089ca47 --- /dev/null +++ b/pkg/providers/protocoltypes/types.go @@ -0,0 +1,56 @@ +package protocoltypes + +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type,omitempty"` + Function *FunctionCall `json:"function,omitempty"` + Name string `json:"name,omitempty"` + Arguments map[string]any `json:"arguments,omitempty"` + ThoughtSignature string `json:"-"` // Internal use only + ExtraContent *ExtraContent `json:"extra_content,omitempty"` +} + +type ExtraContent struct { + Google *GoogleExtra `json:"google,omitempty"` +} + +type GoogleExtra struct { + ThoughtSignature string `json:"thought_signature,omitempty"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + ThoughtSignature string `json:"thought_signature,omitempty"` +} + 
+type LLMResponse struct { + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + FinishReason string `json:"finish_reason"` + Usage *UsageInfo `json:"usage,omitempty"` +} + +type UsageInfo struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +type ToolDefinition struct { + Type string `json:"type"` + Function ToolFunctionDefinition `json:"function"` +} + +type ToolFunctionDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]any `json:"parameters"` +} diff --git a/pkg/providers/tool_call_extract.go b/pkg/providers/tool_call_extract.go new file mode 100644 index 000000000..7ddea0e99 --- /dev/null +++ b/pkg/providers/tool_call_extract.go @@ -0,0 +1,72 @@ +package providers + +import ( + "encoding/json" + "strings" +) + +// extractToolCallsFromText parses tool call JSON from response text. +// Both ClaudeCliProvider and CodexCliProvider use this to extract +// tool calls that the model outputs in its response text. 
+func extractToolCallsFromText(text string) []ToolCall {
+	start := strings.Index(text, `{"tool_calls"`)
+	if start == -1 {
+		return nil
+	}
+
+	end := findMatchingBrace(text, start)
+	if end == start {
+		return nil
+	}
+
+	jsonStr := text[start:end]
+
+	var wrapper struct {
+		ToolCalls []struct {
+			ID       string `json:"id"`
+			Type     string `json:"type"`
+			Function struct {
+				Name      string `json:"name"`
+				Arguments string `json:"arguments"`
+			} `json:"function"`
+		} `json:"tool_calls"`
+	}
+
+	if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
+		return nil
+	}
+
+	var result []ToolCall
+	for _, tc := range wrapper.ToolCalls {
+		// Always start from a non-nil map so downstream writes cannot panic.
+		args := make(map[string]any)
+		if tc.Function.Arguments != "" {
+			if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
+				// Preserve an unparseable payload instead of silently dropping it,
+				// matching openai_compat.parseResponse's fallback behavior.
+				args["raw"] = tc.Function.Arguments
+			}
+		}
+
+		result = append(result, ToolCall{
+			ID:        tc.ID,
+			Type:      tc.Type,
+			Name:      tc.Function.Name,
+			Arguments: args,
+			Function: &FunctionCall{
+				Name:      tc.Function.Name,
+				Arguments: tc.Function.Arguments,
+			},
+		})
+	}
+
+	return result
+}
+
+// stripToolCallsFromText removes tool call JSON from response text.
+func stripToolCallsFromText(text string) string {
+	start := strings.Index(text, `{"tool_calls"`)
+	if start == -1 {
+		return text
+	}
+
+	end := findMatchingBrace(text, start)
+	if end == start {
+		return text
+	}
+
+	return strings.TrimSpace(text[:start] + text[end:])
+}
diff --git a/pkg/providers/toolcall_utils.go b/pkg/providers/toolcall_utils.go
new file mode 100644
index 000000000..49218b1b1
--- /dev/null
+++ b/pkg/providers/toolcall_utils.go
@@ -0,0 +1,54 @@
+// PicoClaw - Ultra-lightweight personal AI agent
+// License: MIT
+//
+// Copyright (c) 2026 PicoClaw contributors
+
+package providers
+
+import "encoding/json"
+
+// NormalizeToolCall normalizes a ToolCall to ensure all fields are properly populated.
+// It handles cases where Name/Arguments might be in different locations (top-level vs Function)
+// and ensures both are populated consistently.
+func NormalizeToolCall(tc ToolCall) ToolCall { + normalized := tc + + // Ensure Name is populated from Function if not set + if normalized.Name == "" && normalized.Function != nil { + normalized.Name = normalized.Function.Name + } + + // Ensure Arguments is not nil + if normalized.Arguments == nil { + normalized.Arguments = map[string]any{} + } + + // Parse Arguments from Function.Arguments if not already set + if len(normalized.Arguments) == 0 && normalized.Function != nil && normalized.Function.Arguments != "" { + var parsed map[string]any + if err := json.Unmarshal([]byte(normalized.Function.Arguments), &parsed); err == nil && parsed != nil { + normalized.Arguments = parsed + } + } + + // Ensure Function is populated with consistent values + argsJSON, _ := json.Marshal(normalized.Arguments) + if normalized.Function == nil { + normalized.Function = &FunctionCall{ + Name: normalized.Name, + Arguments: string(argsJSON), + } + } else { + if normalized.Function.Name == "" { + normalized.Function.Name = normalized.Name + } + if normalized.Name == "" { + normalized.Name = normalized.Function.Name + } + if normalized.Function.Arguments == "" { + normalized.Function.Arguments = string(argsJSON) + } + } + + return normalized +} diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 88b62e975..f711e7803 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -1,52 +1,74 @@ package providers -import "context" - -type ToolCall struct { - ID string `json:"id"` - Type string `json:"type,omitempty"` - Function *FunctionCall `json:"function,omitempty"` - Name string `json:"name,omitempty"` - Arguments map[string]interface{} `json:"arguments,omitempty"` -} +import ( + "context" + "fmt" -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +type ( + ToolCall = protocoltypes.ToolCall + FunctionCall = protocoltypes.FunctionCall + LLMResponse = 
protocoltypes.LLMResponse + UsageInfo = protocoltypes.UsageInfo + Message = protocoltypes.Message + ToolDefinition = protocoltypes.ToolDefinition + ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition + ExtraContent = protocoltypes.ExtraContent + GoogleExtra = protocoltypes.GoogleExtra +) -type LLMResponse struct { - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - FinishReason string `json:"finish_reason"` - Usage *UsageInfo `json:"usage,omitempty"` +type LLMProvider interface { + Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, + ) (*LLMResponse, error) + GetDefaultModel() string } -type UsageInfo struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` +// FailoverReason classifies why an LLM request failed for fallback decisions. +type FailoverReason string + +const ( + FailoverAuth FailoverReason = "auth" + FailoverRateLimit FailoverReason = "rate_limit" + FailoverBilling FailoverReason = "billing" + FailoverTimeout FailoverReason = "timeout" + FailoverFormat FailoverReason = "format" + FailoverOverloaded FailoverReason = "overloaded" + FailoverUnknown FailoverReason = "unknown" +) + +// FailoverError wraps an LLM provider error with classification metadata. 
+type FailoverError struct { + Reason FailoverReason + Provider string + Model string + Status int + Wrapped error } -type Message struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` +func (e *FailoverError) Error() string { + return fmt.Sprintf("failover(%s): provider=%s model=%s status=%d: %v", + e.Reason, e.Provider, e.Model, e.Status, e.Wrapped) } -type LLMProvider interface { - Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) - GetDefaultModel() string +func (e *FailoverError) Unwrap() error { + return e.Wrapped } -type ToolDefinition struct { - Type string `json:"type"` - Function ToolFunctionDefinition `json:"function"` +// IsRetriable returns true if this error should trigger fallback to next candidate. +// Non-retriable: Format errors (bad request structure, image dimension/size). +func (e *FailoverError) IsRetriable() bool { + return e.Reason != FailoverFormat } -type ToolFunctionDefinition struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters map[string]interface{} `json:"parameters"` +// ModelConfig holds primary model and fallback list. 
+type ModelConfig struct { + Primary string + Fallbacks []string } diff --git a/pkg/relation/authorizer.go b/pkg/relation/authorizer.go new file mode 100644 index 000000000..18ae37c80 --- /dev/null +++ b/pkg/relation/authorizer.go @@ -0,0 +1,479 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package relation + +import ( + "fmt" + "sync" +) + +// AuthzRequest represents an authorization request +type AuthzRequest struct { + // SubjectHID is the H-id of the subject requesting access + SubjectHID string + + // SubjectSID is the optional S-id of the subject + SubjectSID string + + // Action is the action being requested + Action Action + + // Resource is the resource being accessed + Resource *ResourceID + + // Context provides additional context for the decision + Context map[string]interface{} +} + +// AuthzResult represents the result of an authorization decision +type AuthzResult struct { + // Allowed is true if access is granted + Allowed bool + + // Reason is a human-readable explanation + Reason string + + // MatchedRelation is the relation that matched (if any) + MatchedRelation *Relation + + // MatchedPolicy is the policy that matched (if any) + MatchedPolicy *Policy +} + +// Authorizer handles authorization decisions +type Authorizer struct { + registry *Registry + policies *PolicySet + mu sync.RWMutex + + // DefaultDeny causes authorization to deny by default + DefaultDeny bool +} + +// NewAuthorizer creates a new authorizer +func NewAuthorizer() *Authorizer { + return &Authorizer{ + registry: NewRegistry(), + policies: DefaultPolicies(), + DefaultDeny: true, + } +} + +// NewAuthorizerWithRegistry creates an authorizer with a specific registry +func NewAuthorizerWithRegistry(registry *Registry) *Authorizer { + return &Authorizer{ + registry: registry, + policies: DefaultPolicies(), + DefaultDeny: true, + } +} + +// SetRegistry sets the relation registry +func (a *Authorizer) 
SetRegistry(registry *Registry) {
+	a.mu.Lock()
+	defer a.mu.Unlock()
+	a.registry = registry
+}
+
+// GetRegistry returns the relation registry
+func (a *Authorizer) GetRegistry() *Registry {
+	a.mu.RLock()
+	defer a.mu.RUnlock()
+	return a.registry
+}
+
+// SetPolicies sets the policy set
+func (a *Authorizer) SetPolicies(policies *PolicySet) {
+	a.mu.Lock()
+	defer a.mu.Unlock()
+	a.policies = policies
+}
+
+// GetPolicies returns the policy set
+func (a *Authorizer) GetPolicies() *PolicySet {
+	a.mu.RLock()
+	defer a.mu.RUnlock()
+	return a.policies
+}
+
+// Authorize checks if a subject is authorized to perform an action on a resource
+func (a *Authorizer) Authorize(req *AuthzRequest) *AuthzResult {
+	if req == nil {
+		return &AuthzResult{
+			Allowed: false,
+			Reason:  "nil request",
+		}
+	}
+
+	// Guard a nil resource explicitly: isOwner, policy evaluation and relation
+	// lookup all dereference req.Resource and would otherwise panic.
+	if req.Resource == nil {
+		return &AuthzResult{
+			Allowed: false,
+			Reason:  "nil resource",
+		}
+	}
+
+	a.mu.RLock()
+	defer a.mu.RUnlock()
+
+	// Check for owner access first (always allowed)
+	if a.isOwner(req) {
+		return &AuthzResult{
+			Allowed: true,
+			Reason:  "owner has full access",
+		}
+	}
+
+	// Check policy-based authorization
+	policyResult := a.checkPolicies(req)
+	if policyResult != nil {
+		return policyResult
+	}
+
+	// Check relation-based authorization
+	return a.checkRelations(req)
+}
+
+// isOwner checks if the subject is the owner of the resource
+func (a *Authorizer) isOwner(req *AuthzRequest) bool {
+	// Owner check: subject H-id matches resource HID
+	return req.SubjectHID == req.Resource.HID
+}
+
+// checkPolicies checks if any policies allow or deny the request
+func (a *Authorizer) checkPolicies(req *AuthzRequest) *AuthzResult {
+	if a.policies == nil || a.policies.Count() == 0 {
+		return nil
+	}
+
+	effect := a.policies.Evaluate(req.SubjectHID, req.Action, req.Resource)
+
+	if effect == EffectDeny {
+		return &AuthzResult{
+			Allowed: false,
+			Reason:  "denied by policy",
+		}
+	}
+
+	if effect == EffectAllow {
+		return &AuthzResult{
+			Allowed: true,
+			Reason:  "allowed by policy",
+		}
+	}
+
+	return nil
+}
+
+// checkRelations checks if any relations
allow or deny the request +func (a *Authorizer) checkRelations(req *AuthzRequest) *AuthzResult { + if a.registry == nil { + return a.defaultResult() + } + + // Get the highest privilege relation + privilege := a.registry.GetPrivilege(req.SubjectHID, req.SubjectSID, req.Resource) + + if privilege == "" || privilege == RelationAny { + return a.defaultResult() + } + + // Check if the relation type allows the action + if ActionAllowedByRelation(req.Action, privilege) { + return &AuthzResult{ + Allowed: true, + Reason: fmt.Sprintf("allowed by %s relation", privilege), + } + } + + // Check for owner privileges (owner can do anything) + if privilege == RelationOwner { + return &AuthzResult{ + Allowed: true, + Reason: "owner has full access", + } + } + + return &AuthzResult{ + Allowed: false, + Reason: fmt.Sprintf("%s relation does not allow %s", privilege, req.Action), + } +} + +// defaultResult returns the default authorization result +func (a *Authorizer) defaultResult() *AuthzResult { + if a.DefaultDeny { + return &AuthzResult{ + Allowed: false, + Reason: "default deny", + } + } + return &AuthzResult{ + Allowed: true, + Reason: "default allow", + } +} + +// CanRead is a convenience method to check read authorization +func (a *Authorizer) CanRead(subjectHID, subjectSID string, resource *ResourceID) bool { + req := &AuthzRequest{ + SubjectHID: subjectHID, + SubjectSID: subjectSID, + Action: ActionRead, + Resource: resource, + } + return a.Authorize(req).Allowed +} + +// CanWrite is a convenience method to check write authorization +func (a *Authorizer) CanWrite(subjectHID, subjectSID string, resource *ResourceID) bool { + req := &AuthzRequest{ + SubjectHID: subjectHID, + SubjectSID: subjectSID, + Action: ActionWrite, + Resource: resource, + } + return a.Authorize(req).Allowed +} + +// CanDelete is a convenience method to check delete authorization +func (a *Authorizer) CanDelete(subjectHID, subjectSID string, resource *ResourceID) bool { + req := &AuthzRequest{ + 
SubjectHID: subjectHID, + SubjectSID: subjectSID, + Action: ActionDelete, + Resource: resource, + } + return a.Authorize(req).Allowed +} + +// CanExecute is a convenience method to check execute authorization +func (a *Authorizer) CanExecute(subjectHID, subjectSID string, resource *ResourceID) bool { + req := &AuthzRequest{ + SubjectHID: subjectHID, + SubjectSID: subjectSID, + Action: ActionExecute, + Resource: resource, + } + return a.Authorize(req).Allowed +} + +// CanAdmin is a convenience method to check admin authorization +func (a *Authorizer) CanAdmin(subjectHID, subjectSID string, resource *ResourceID) bool { + req := &AuthzRequest{ + SubjectHID: subjectHID, + SubjectSID: subjectSID, + Action: ActionAdmin, + Resource: resource, + } + return a.Authorize(req).Allowed +} + +// CanShare is a convenience method to check share authorization +func (a *Authorizer) CanShare(subjectHID, subjectSID string, resource *ResourceID) bool { + req := &AuthzRequest{ + SubjectHID: subjectHID, + SubjectSID: subjectSID, + Action: ActionShare, + Resource: resource, + } + return a.Authorize(req).Allowed +} + +// Grant grants a relation to a subject for a resource +func (a *Authorizer) Grant(subjectHID, subjectSID string, resource *ResourceID, relType RelationType) error { + a.mu.Lock() + defer a.mu.Unlock() + + if a.registry == nil { + return fmt.Errorf("no registry configured") + } + + rel := NewRelationWithSID(subjectHID, subjectSID, resource, relType) + return a.registry.Add(rel) +} + +// Revoke revokes a relation from a subject for a resource +func (a *Authorizer) Revoke(subjectHID, subjectSID string, resource *ResourceID, relType RelationType) error { + a.mu.Lock() + defer a.mu.Unlock() + + if a.registry == nil { + return fmt.Errorf("no registry configured") + } + + return a.registry.Remove(subjectHID, subjectSID, resource, relType) +} + +// GetRelations returns all relations for a subject +func (a *Authorizer) GetRelations(subjectHID, subjectSID string) []*Relation { + 
a.mu.RLock()
+	defer a.mu.RUnlock()
+
+	if a.registry == nil {
+		return nil
+	}
+
+	return a.registry.GetBySubject(subjectHID, subjectSID)
+}
+
+// GetResources returns all resources a subject has access to
+func (a *Authorizer) GetResources(subjectHID, subjectSID string) []*ResourceID {
+	a.mu.RLock()
+	defer a.mu.RUnlock()
+
+	if a.registry == nil {
+		return nil
+	}
+
+	return a.registry.GetResourcesBySubject(subjectHID, subjectSID)
+}
+
+// FilterAuthorized returns only the resources that the subject is authorized to
+// access for the given action. Locking is delegated to Authorize: acquiring
+// a.mu.RLock here and again inside Authorize is recursive read-locking, which
+// sync.RWMutex prohibits and which deadlocks when a writer is waiting.
+func (a *Authorizer) FilterAuthorized(subjectHID, subjectSID string, resources []*ResourceID, action Action) []*ResourceID {
+	authorized := make([]*ResourceID, 0)
+
+	for _, resource := range resources {
+		req := &AuthzRequest{
+			SubjectHID: subjectHID,
+			SubjectSID: subjectSID,
+			Action:     action,
+			Resource:   resource,
+		}
+
+		if a.Authorize(req).Allowed {
+			authorized = append(authorized, resource)
+		}
+	}
+
+	return authorized
+}
+
+// BatchAuthorize checks authorization for multiple requests. Each request goes
+// through the full decision pipeline (owner check, policies, then relations)
+// so results agree with what Authorize would return for the same request;
+// previously this called checkRelations directly, silently skipping the owner
+// and policy checks.
+func (a *Authorizer) BatchAuthorize(requests []*AuthzRequest) []*AuthzResult {
+	results := make([]*AuthzResult, len(requests))
+
+	for i, req := range requests {
+		results[i] = a.Authorize(req)
+	}
+
+	return results
+}
+
+// TransferOwnership transfers ownership of a resource to a new H-id
+func (a *Authorizer) TransferOwnership(resource *ResourceID, currentOwner, newOwner string) error {
+	a.mu.Lock()
+	defer a.mu.Unlock()
+
+	if a.registry == nil {
+		return fmt.Errorf("no registry configured")
+	}
+
+	// Remove old owner relations
+	oldOwnerFilter := &RelationFilter{
+		SubjectHID:   currentOwner,
+		ResourceType: resource.Type,
+		ResourceID:   resource.ID,
+		RelationType: RelationOwner,
+	}
+
+	matching := a.registry.Apply(oldOwnerFilter)
+	for _, rel := range matching {
+		if rel.RelationType == RelationOwner {
+			a.registry.Remove(rel.SubjectHID, rel.SubjectSID,
rel.Resource, RelationOwner) + } + } + + // Add new owner relation + newOwnerRel := NewRelation(newOwner, resource, RelationOwner) + return a.registry.Add(newOwnerRel) +} + +// AddPolicy adds a policy to the authorizer +func (a *Authorizer) AddPolicy(policy *Policy) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.policies == nil { + a.policies = NewPolicySet() + } + + a.policies.Add(policy) +} + +// RemovePolicy removes a policy from the authorizer +func (a *Authorizer) RemovePolicy(policyID string) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.policies == nil { + return + } + + a.policies.Remove(policyID) +} + +// GetPolicy returns a policy by ID +func (a *Authorizer) GetPolicy(policyID string) (*Policy, bool) { + a.mu.RLock() + defer a.mu.RUnlock() + + if a.policies == nil { + return nil, false + } + + return a.policies.Get(policyID) +} + +// GetAllPolicies returns all policies +func (a *Authorizer) GetAllPolicies() []*Policy { + a.mu.RLock() + defer a.mu.RUnlock() + + if a.policies == nil { + return nil + } + + return a.policies.GetAll() +} + +// ClearPolicies removes all policies +func (a *Authorizer) ClearPolicies() { + a.mu.Lock() + defer a.mu.Unlock() + + a.policies = NewPolicySet() +} + +// Clear clears the registry and policies +func (a *Authorizer) Clear() { + a.mu.Lock() + defer a.mu.Unlock() + + if a.registry != nil { + a.registry.Clear() + } + + a.policies = NewPolicySet() +} + +// Stats returns statistics about the authorizer +func (a *Authorizer) Stats() map[string]interface{} { + a.mu.RLock() + defer a.mu.RUnlock() + + stats := make(map[string]interface{}) + + if a.registry != nil { + stats["relation_count"] = a.registry.Count() + } + + if a.policies != nil { + stats["policy_count"] = a.policies.Count() + } + + stats["default_deny"] = a.DefaultDeny + + return stats +} diff --git a/pkg/relation/policy.go b/pkg/relation/policy.go new file mode 100644 index 000000000..f2d68f763 --- /dev/null +++ b/pkg/relation/policy.go @@ -0,0 +1,420 @@ +// PicoClaw - 
Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package relation + +import ( + "fmt" + "strings" +) + +// Action represents an action that can be performed on a resource +type Action string + +const ( + // ActionRead is the read action + ActionRead Action = "read" + + // ActionWrite is the write action + ActionWrite Action = "write" + + // ActionDelete is the delete action + ActionDelete Action = "delete" + + // ActionExecute is the execute action + ActionExecute Action = "execute" + + // ActionAdmin is the admin action + ActionAdmin Action = "admin" + + // ActionShare is the share/delegate action + ActionShare Action = "share" + + // ActionAny is a wildcard for any action + ActionAny Action = "*" +) + +// ActionMapping defines which relation types can perform which actions +// Note: Admin is a read-only management role with delete/admin but NOT write permissions +var ActionMapping = map[Action][]RelationType{ + ActionRead: { + RelationReader, + RelationWriter, + RelationAdmin, + RelationDelegate, + RelationOwner, + RelationMember, // Members can typically read + }, + ActionWrite: { + RelationWriter, + RelationDelegate, + RelationOwner, + // Admin NOT included - Admin is read-only for content + }, + ActionDelete: { + RelationAdmin, + RelationDelegate, + RelationOwner, + }, + ActionExecute: { + RelationExecutor, + RelationAdmin, + RelationDelegate, + RelationOwner, + }, + ActionAdmin: { + RelationAdmin, + RelationDelegate, + RelationOwner, + }, + ActionShare: { + RelationDelegate, + RelationOwner, + }, +} + +// Effect defines the effect of a policy statement +type Effect int + +const ( + // EffectAllow permits the action + EffectAllow Effect = iota + + // EffectDeny denies the action + EffectDeny +) + +// String returns the string representation of the effect +func (e Effect) String() string { + switch e { + case EffectAllow: + return "allow" + case EffectDeny: + return "deny" + default: + return "unknown" + } 
+} + +// ParseEffect parses a string into an Effect +func ParseEffect(s string) (Effect, error) { + switch strings.ToLower(s) { + case "allow": + return EffectAllow, nil + case "deny": + return EffectDeny, nil + default: + return EffectDeny, fmt.Errorf("invalid effect: %s", s) + } +} + +// Statement is a single policy statement +type Statement struct { + // Effect is either allow or deny + Effect Effect `json:"effect"` + + // Action is the action this statement applies to + Action Action `json:"action"` + + // Resource is the resource this statement applies to + Resource *ResourceID `json:"resource"` + + // Condition is an optional condition that must be met + Condition string `json:"condition,omitempty"` +} + +// NewAllowStatement creates a new allow statement +func NewAllowStatement(action Action, resource *ResourceID) *Statement { + return &Statement{ + Effect: EffectAllow, + Action: action, + Resource: resource, + } +} + +// NewDenyStatement creates a new deny statement +func NewDenyStatement(action Action, resource *ResourceID) *Statement { + return &Statement{ + Effect: EffectDeny, + Action: action, + Resource: resource, + } +} + +// Matches checks if the statement matches the given action and resource +func (s *Statement) Matches(action Action, resource *ResourceID) bool { + if s.Action != ActionAny && action != ActionAny && s.Action != action { + return false + } + if s.Resource != nil && !s.Resource.Matches(resource) { + return false + } + return true +} + +// Policy is a collection of policy statements +type Policy struct { + // ID is a unique identifier for the policy + ID string `json:"id"` + + // Name is a human-readable name + Name string `json:"name,omitempty"` + + // Description describes what this policy does + Description string `json:"description,omitempty"` + + // Statements is the list of policy statements + Statements []*Statement `json:"statements"` + + // Scope limits the policy to specific H-ids (empty = all) + Scope []string 
`json:"scope,omitempty"` +} + +// NewPolicy creates a new policy +func NewPolicy(id string) *Policy { + return &Policy{ + ID: id, + Statements: make([]*Statement, 0), + Scope: make([]string, 0), + } +} + +// AddStatement adds a statement to the policy +func (p *Policy) AddStatement(stmt *Statement) { + p.Statements = append(p.Statements, stmt) +} + +// AddAllow adds an allow statement to the policy +func (p *Policy) AddAllow(action Action, resource *ResourceID) { + p.AddStatement(NewAllowStatement(action, resource)) +} + +// AddDeny adds a deny statement to the policy +func (p *Policy) AddDeny(action Action, resource *ResourceID) { + p.AddStatement(NewDenyStatement(action, resource)) +} + +// Evaluate evaluates the policy for a given action and resource +func (p *Policy) Evaluate(action Action, resource *ResourceID) Effect { + // Default deny + result := EffectDeny + + for _, stmt := range p.Statements { + if !stmt.Matches(action, resource) { + continue + } + + // Deny takes precedence over allow + if stmt.Effect == EffectDeny { + return EffectDeny + } + + result = EffectAllow + } + + return result +} + +// HasMatch returns true if the policy has any statements that match the given action and resource +func (p *Policy) HasMatch(action Action, resource *ResourceID) bool { + for _, stmt := range p.Statements { + if stmt.Matches(action, resource) { + return true + } + } + return false +} + +// IsInScope checks if the given H-id is in the policy's scope +func (p *Policy) IsInScope(hid string) bool { + if len(p.Scope) == 0 { + return true // No scope restriction + } + + for _, s := range p.Scope { + if s == hid { + return true + } + } + + return false +} + +// Clone creates a deep copy of the policy +func (p *Policy) Clone() *Policy { + clone := &Policy{ + ID: p.ID, + Name: p.Name, + Description: p.Description, + Statements: make([]*Statement, len(p.Statements)), + Scope: make([]string, len(p.Scope)), + } + + copy(clone.Scope, p.Scope) + + for i, stmt := range 
p.Statements { + clone.Statements[i] = &Statement{ + Effect: stmt.Effect, + Action: stmt.Action, + Resource: stmt.Resource.Clone(), + Condition: stmt.Condition, + } + } + + return clone +} + +// PolicySet is a collection of policies +type PolicySet struct { + policies map[string]*Policy +} + +// NewPolicySet creates a new policy set +func NewPolicySet() *PolicySet { + return &PolicySet{ + policies: make(map[string]*Policy), + } +} + +// Add adds a policy to the set +func (s *PolicySet) Add(policy *Policy) { + s.policies[policy.ID] = policy +} + +// Get retrieves a policy by ID +func (s *PolicySet) Get(id string) (*Policy, bool) { + policy, ok := s.policies[id] + return policy, ok +} + +// Remove removes a policy from the set +func (s *PolicySet) Remove(id string) { + delete(s.policies, id) +} + +// Evaluate evaluates all policies for a given action, resource, and H-id +func (s *PolicySet) Evaluate(hid string, action Action, resource *ResourceID) Effect { + // Default deny + result := EffectDeny + + for _, policy := range s.policies { + // Check scope + if !policy.IsInScope(hid) { + continue + } + + // Skip policies that don't have matching statements for this action/resource + if !policy.HasMatch(action, resource) { + continue + } + + effect := policy.Evaluate(action, resource) + + // Deny takes precedence + if effect == EffectDeny { + return EffectDeny + } + + result = EffectAllow + } + + return result +} + +// GetAll returns all policies +func (s *PolicySet) GetAll() []*Policy { + policies := make([]*Policy, 0, len(s.policies)) + for _, p := range s.policies { + policies = append(policies, p) + } + return policies +} + +// Count returns the number of policies +func (s *PolicySet) Count() int { + return len(s.policies) +} + +// Clear removes all policies +func (s *PolicySet) Clear() { + s.policies = make(map[string]*Policy) +} + +// DefaultPolicies returns the default policy set +func DefaultPolicies() *PolicySet { + ps := NewPolicySet() + + // Owner can do 
anything on their own resources + // Note: This is checked via isOwner() in Authorizer, not via policies + // Default policies should only allow access when explicitly granted via relations + // So we don't add a blanket allow policy here + + return ps +} + +// GetRequiredRelationType returns the minimum relation type required for an action +func GetRequiredRelationType(action Action) RelationType { + switch action { + case ActionRead: + return RelationReader + case ActionWrite: + return RelationWriter + case ActionDelete: + return RelationAdmin + case ActionExecute: + return RelationExecutor + case ActionAdmin: + return RelationAdmin + case ActionShare: + return RelationDelegate + default: + return RelationOwner + } +} + +// ActionAllowedByRelation checks if an action is allowed by a given relation type +func ActionAllowedByRelation(action Action, relType RelationType) bool { + allowedTypes, ok := ActionMapping[action] + if !ok { + return false + } + + for _, allowed := range allowedTypes { + if allowed == relType { + return true + } + } + + return false +} + +// Built-in policies + +// OwnerPolicy creates a policy that gives owners full access +func OwnerPolicy() *Policy { + p := NewPolicy("owner-policy") + p.Description = "Owners have full access to their own resources" + p.AddAllow(ActionAny, &ResourceID{HID: "$owner", ID: "*"}) + return p +} + +// ReaderPolicy creates a policy that gives read-only access +func ReaderPolicy() *Policy { + p := NewPolicy("reader-policy") + p.Description = "Readers can read resources" + p.AddAllow(ActionRead, &ResourceID{ID: "*"}) + return p +} + +// AdminPolicy creates a policy that gives admin access +func AdminPolicy() *Policy { + p := NewPolicy("admin-policy") + p.Description = "Admins have administrative access" + p.AddAllow(ActionRead, &ResourceID{ID: "*"}) + p.AddAllow(ActionWrite, &ResourceID{ID: "*"}) + p.AddAllow(ActionDelete, &ResourceID{ID: "*"}) + p.AddAllow(ActionAdmin, &ResourceID{ID: "*"}) + return p +} diff --git 
a/pkg/relation/registry.go b/pkg/relation/registry.go new file mode 100644 index 000000000..0b5147a71 --- /dev/null +++ b/pkg/relation/registry.go @@ -0,0 +1,445 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package relation + +import ( + "fmt" + "sync" + "time" +) + +// Registry manages relations between identities and resources +type Registry struct { + mu sync.RWMutex + relations *RelationSet + + // Index for faster lookups + bySubjectResource map[string]*Relation // key: "subject:resource" + + // Change tracking + lastModified int64 +} + +// NewRegistry creates a new relation registry +func NewRegistry() *Registry { + return &Registry{ + relations: NewRelationSet(), + bySubjectResource: make(map[string]*Relation), + lastModified: time.Now().UnixMilli(), + } +} + +// Add adds a relation to the registry +func (r *Registry) Add(rel *Relation) error { + if rel == nil { + return fmt.Errorf("cannot add nil relation") + } + + r.mu.Lock() + defer r.mu.Unlock() + + // Check for duplicates + key := r.relationKey(rel.SubjectHID, rel.SubjectSID, rel.Resource, rel.RelationType) + if _, exists := r.bySubjectResource[key]; exists { + return fmt.Errorf("relation already exists: %s", key) + } + + r.relations.Add(rel) + r.bySubjectResource[key] = rel + r.lastModified = time.Now().UnixMilli() + + return nil +} + +// Update updates an existing relation +func (r *Registry) Update(rel *Relation) error { + if rel == nil { + return fmt.Errorf("cannot update nil relation") + } + + r.mu.Lock() + defer r.mu.Unlock() + + newKey := r.relationKey(rel.SubjectHID, rel.SubjectSID, rel.Resource, rel.RelationType) + + // Try to find the relation with any type for the same subject-resource pair + oldRel, found := r.findBySubjectResource(rel.SubjectHID, rel.SubjectSID, rel.Resource) + if !found { + return fmt.Errorf("relation not found for subject-resource pair") + } + + oldKey := r.relationKey(oldRel.SubjectHID, 
oldRel.SubjectSID, oldRel.Resource, oldRel.RelationType)

	// Remove old relation and add new one
	r.relations.Remove(oldRel.SubjectHID, oldRel.SubjectSID, oldRel.Resource, oldRel.RelationType)
	delete(r.bySubjectResource, oldKey)

	r.relations.Add(rel)
	r.bySubjectResource[newKey] = rel
	r.lastModified = time.Now().UnixMilli()

	return nil
}

// AddOrUpdate adds a relation, replacing an existing relation with the same
// subject, resource, and relation type.
func (r *Registry) AddOrUpdate(rel *Relation) error {
	if rel == nil {
		return fmt.Errorf("cannot add nil relation")
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	key := r.relationKey(rel.SubjectHID, rel.SubjectSID, rel.Resource, rel.RelationType)

	// Check if relation exists
	if _, exists := r.bySubjectResource[key]; exists {
		// Bug fix: remove only the exact (subject, resource, type) tuple being
		// replaced. The previous code removed with RelationAny, which deleted
		// sibling relation types from the set while leaving their entries
		// dangling in the bySubjectResource index.
		r.relations.Remove(rel.SubjectHID, rel.SubjectSID, rel.Resource, rel.RelationType)
		delete(r.bySubjectResource, key)
	}

	r.relations.Add(rel)
	r.bySubjectResource[key] = rel
	r.lastModified = time.Now().UnixMilli()

	return nil
}

// Remove removes relations matching the given criteria from the registry.
// Wildcards (RelationAny, wildcard resource fields) are honored by the
// underlying RelationSet.Remove.
func (r *Registry) Remove(subjectHID, subjectSID string, resource *ResourceID, relType RelationType) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	removed := r.relations.Remove(subjectHID, subjectSID, resource, relType)

	if removed > 0 {
		// Bug fix: a wildcard relType/resource can remove several concrete
		// relations whose index keys differ from the single literal key the
		// old code deleted, leaving dangling index entries. Drop every index
		// entry that matched the removal criteria instead. (Deleting map
		// entries during range is safe in Go.)
		for key, rel := range r.bySubjectResource {
			if rel.Matches(subjectHID, subjectSID, resource, relType) {
				delete(r.bySubjectResource, key)
			}
		}
		r.lastModified = time.Now().UnixMilli()
	}

	return nil
}

// Get retrieves a specific relation, reporting whether it was found.
func (r *Registry) Get(subjectHID, subjectSID string, resource *ResourceID, relType RelationType) (*Relation, bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	rels := r.relations.Find(subjectHID, subjectSID, resource, relType)
	if len(rels) == 0 {
		return nil, false
	}
	return rels[0], true
}

// GetBySubject retrieves all relations for a subject
func (r *Registry) GetBySubject(subjectHID, subjectSID string) []*Relation {
	r.mu.RLock()
	defer r.mu.RUnlock()

return r.relations.FindBySubject(subjectHID, subjectSID) +} + +// GetByResource retrieves all relations for a resource +func (r *Registry) GetByResource(resource *ResourceID) []*Relation { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.relations.FindByResource(resource) +} + +// GetByType retrieves all relations of a specific type +func (r *Registry) GetByType(relType RelationType) []*Relation { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.relations.FindByType(relType) +} + +// HasRelation checks if any relation exists between subject and resource +func (r *Registry) HasRelation(subjectHID string, resource *ResourceID) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.relations.HasAnyRelation(subjectHID, resource) +} + +// GetPrivilege gets the highest privilege relation type for a subject on a resource +func (r *Registry) GetPrivilege(subjectHID, subjectSID string, resource *ResourceID) RelationType { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.relations.GetHighestPrivilege(subjectHID, subjectSID, resource) +} + +// HasPrivilege checks if a subject has at least the required privilege on a resource +func (r *Registry) HasPrivilege(subjectHID, subjectSID string, resource *ResourceID, required RelationType) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + privilege := r.relations.GetHighestPrivilege(subjectHID, subjectSID, resource) + return HasPrivilege(privilege, required) +} + +// GetAll returns all relations in the registry +func (r *Registry) GetAll() []*Relation { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.relations.All() +} + +// Count returns the total number of relations +func (r *Registry) Count() int { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.relations.Len() +} + +// Clear removes all relations from the registry +func (r *Registry) Clear() { + r.mu.Lock() + defer r.mu.Unlock() + + r.relations.Clear() + r.bySubjectResource = make(map[string]*Relation) + r.lastModified = time.Now().UnixMilli() +} + +// LastModified 
returns the last modified timestamp +func (r *Registry) LastModified() int64 { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.lastModified +} + +// GetByOwner retrieves all relations where the subject is the owner +func (r *Registry) GetByOwner(ownerHID string) []*Relation { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.relations.FindByType(RelationOwner) +} + +// GetResourcesBySubject retrieves all resources that a subject has relations with +func (r *Registry) GetResourcesBySubject(subjectHID, subjectSID string) []*ResourceID { + r.mu.RLock() + defer r.mu.RUnlock() + + rels := r.relations.FindBySubject(subjectHID, subjectSID) + resources := make([]*ResourceID, 0, len(rels)) + seen := make(map[string]bool) + + for _, rel := range rels { + key := rel.Resource.String() + if !seen[key] { + seen[key] = true + resources = append(resources, rel.Resource) + } + } + + return resources +} + +// GetSubjectsByResource retrieves all subjects that have relations with a resource +func (r *Registry) GetSubjectsByResource(resource *ResourceID) []string { + r.mu.RLock() + defer r.mu.RUnlock() + + rels := r.relations.FindByResource(resource) + subjects := make([]string, 0, len(rels)) + seen := make(map[string]bool) + + for _, rel := range rels { + key := rel.SubjectHID + if rel.SubjectSID != "" { + key = rel.SubjectHID + "/" + rel.SubjectSID + } + if !seen[key] { + seen[key] = true + subjects = append(subjects, key) + } + } + + return subjects +} + +// Clone creates a deep copy of the registry +func (r *Registry) Clone() *Registry { + r.mu.RLock() + defer r.mu.RUnlock() + + clone := NewRegistry() + for _, rel := range r.relations.All() { + clone.Add(rel.Clone()) + } + + return clone +} + +// Merge merges another registry into this one +func (r *Registry) Merge(other *Registry) error { + if other == nil { + return nil + } + + r.mu.Lock() + defer r.mu.Unlock() + + for _, rel := range other.GetAll() { + key := r.relationKey(rel.SubjectHID, rel.SubjectSID, rel.Resource, 
rel.RelationType) + + // Skip if already exists + if _, exists := r.bySubjectResource[key]; exists { + continue + } + + r.relations.Add(rel.Clone()) + r.bySubjectResource[key] = rel + } + + r.lastModified = time.Now().UnixMilli() + return nil +} + +// relationKey creates a unique key for a subject-resource-relation tuple +func (r *Registry) relationKey(subjectHID, subjectSID string, resource *ResourceID, relType RelationType) string { + subject := subjectHID + if subjectSID != "" { + subject = subjectHID + "/" + subjectSID + } + return subject + ":" + resource.String() + ":" + string(relType) +} + +// findBySubjectResource finds a relation for the given subject-resource pair (any type) +func (r *Registry) findBySubjectResource(subjectHID, subjectSID string, resource *ResourceID) (*Relation, bool) { + for _, rel := range r.bySubjectResource { + if rel.SubjectHID == subjectHID && rel.SubjectSID == subjectSID && rel.Resource.Matches(resource) { + return rel, true + } + } + return nil, false +} + +// RelationFilter is used to filter relations +type RelationFilter struct { + SubjectHID string + SubjectSID string + ResourceType ResourceType + ResourceHID string + ResourceID string + RelationType RelationType +} + +// Apply applies the filter to the registry +func (r *Registry) Apply(filter *RelationFilter) []*Relation { + r.mu.RLock() + defer r.mu.RUnlock() + + all := r.relations.All() + results := make([]*Relation, 0) + + for _, rel := range all { + if filter.SubjectHID != "" && rel.SubjectHID != filter.SubjectHID { + continue + } + if filter.SubjectSID != "" && rel.SubjectSID != filter.SubjectSID { + continue + } + if filter.ResourceType != "" && filter.ResourceType != ResourceAny && rel.Resource.Type != filter.ResourceType { + continue + } + if filter.ResourceHID != "" && rel.Resource.HID != filter.ResourceHID { + continue + } + if filter.ResourceID != "" && rel.Resource.ID != filter.ResourceID { + continue + } + if filter.RelationType != "" && filter.RelationType != 
RelationAny && rel.RelationType != filter.RelationType { + continue + } + + results = append(results, rel) + } + + return results +} + +// Export exports the registry as a map +func (r *Registry) Export() map[string]interface{} { + r.mu.RLock() + defer r.mu.RUnlock() + + data := make(map[string]interface{}) + data["last_modified"] = r.lastModified + data["count"] = r.relations.Len() + + relations := make([]map[string]interface{}, 0, r.relations.Len()) + for _, rel := range r.relations.All() { + relData := map[string]interface{}{ + "subject_hid": rel.SubjectHID, + "subject_sid": rel.SubjectSID, + "resource_type": rel.Resource.Type, + "resource_hid": rel.Resource.HID, + "resource_id": rel.Resource.ID, + "resource_namespace": rel.Resource.Namespace, + "relation_type": rel.RelationType, + "attributes": rel.Attributes, + } + relations = append(relations, relData) + } + data["relations"] = relations + + return data +} + +// Import imports relations from a map +func (r *Registry) Import(data map[string]interface{}) error { + r.mu.Lock() + defer r.mu.Unlock() + + relations, ok := data["relations"].([]interface{}) + if !ok { + return fmt.Errorf("invalid import data: missing relations") + } + + for _, relData := range relations { + relMap, ok := relData.(map[string]interface{}) + if !ok { + continue + } + + subjectHID, _ := relMap["subject_hid"].(string) + subjectSID, _ := relMap["subject_sid"].(string) + resourceType, _ := relMap["resource_type"].(string) + resourceHID, _ := relMap["resource_hid"].(string) + resourceID, _ := relMap["resource_id"].(string) + resourceNamespace, _ := relMap["resource_namespace"].(string) + relationType, _ := relMap["relation_type"].(string) + + rel := NewRelationWithSID( + subjectHID, + subjectSID, + &ResourceID{ + Type: ResourceType(resourceType), + HID: resourceHID, + ID: resourceID, + Namespace: resourceNamespace, + }, + RelationType(relationType), + ) + + r.relations.Add(rel) + } + + r.lastModified = time.Now().UnixMilli() + return nil +} 
diff --git a/pkg/relation/relation.go b/pkg/relation/relation.go new file mode 100644 index 000000000..a8359c321 --- /dev/null +++ b/pkg/relation/relation.go @@ -0,0 +1,457 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package relation + +import ( + "fmt" +) + +// RelationType defines the type of relationship +type RelationType string + +const ( + // RelationOwner indicates ownership + RelationOwner RelationType = "owner" + + // RelationMember indicates membership + RelationMember RelationType = "member" + + // RelationReader indicates read access + RelationReader RelationType = "reader" + + // RelationWriter indicates write access + RelationWriter RelationType = "writer" + + // RelationAdmin indicates administrative access + RelationAdmin RelationType = "admin" + + // RelationExecutor indicates execution permission + RelationExecutor RelationType = "executor" + + // RelationDelegate indicates delegation permission + RelationDelegate RelationType = "delegate" + + // RelationAny is a wildcard for any relation type + RelationAny RelationType = "*" +) + +// Relation represents a relationship between an identity and a resource +type Relation struct { + // SubjectHID is the H-id of the subject (who has the relation) + SubjectHID string `json:"subject_hid"` + + // SubjectSID is the optional S-id of the subject + SubjectSID string `json:"subject_sid,omitempty"` + + // Resource is the resource being related to + Resource *ResourceID `json:"resource"` + + // RelationType is the type of relation + RelationType RelationType `json:"relation_type"` + + // Attributes contains additional relation attributes + Attributes map[string]string `json:"attributes,omitempty"` +} + +// NewRelation creates a new relation +func NewRelation(subjectHID string, resource *ResourceID, relType RelationType) *Relation { + return &Relation{ + SubjectHID: subjectHID, + Resource: resource, + RelationType: relType, + 
Attributes: make(map[string]string), + } +} + +// NewRelationWithSID creates a new relation with a specific S-id +func NewRelationWithSID(subjectHID, subjectSID string, resource *ResourceID, relType RelationType) *Relation { + return &Relation{ + SubjectHID: subjectHID, + SubjectSID: subjectSID, + Resource: resource, + RelationType: relType, + Attributes: make(map[string]string), + } +} + +// String returns the string representation of the relation +func (r *Relation) String() string { + subject := r.SubjectHID + if r.SubjectSID != "" { + subject = r.SubjectHID + "/" + r.SubjectSID + } + return fmt.Sprintf("%s %s %s", subject, r.RelationType, r.Resource.String()) +} + +// Matches checks if this relation matches the given criteria +func (r *Relation) Matches(subjectHID, subjectSID string, resource *ResourceID, relType RelationType) bool { + if r.SubjectHID != "" && subjectHID != "" && r.SubjectHID != subjectHID { + return false + } + if r.SubjectSID != "" && subjectSID != "" && r.SubjectSID != subjectSID { + return false + } + if r.RelationType != RelationAny && relType != RelationAny && r.RelationType != relType { + return false + } + if resource != nil && !r.Resource.Matches(resource) { + return false + } + return true +} + +// IsDirect returns true if the relation is for a specific S-id +func (r *Relation) IsDirect() bool { + return r.SubjectSID != "" +} + +// IsTenant returns true if the relation is for an entire H-id +func (r *Relation) IsTenant() bool { + return r.SubjectSID == "" +} + +// GetAttribute returns an attribute value +func (r *Relation) GetAttribute(key string) (string, bool) { + if r.Attributes == nil { + return "", false + } + v, ok := r.Attributes[key] + return v, ok +} + +// SetAttribute sets an attribute value +func (r *Relation) SetAttribute(key, value string) { + if r.Attributes == nil { + r.Attributes = make(map[string]string) + } + r.Attributes[key] = value +} + +// Clone returns a copy of the relation +func (r *Relation) Clone() *Relation 
{ + if r == nil { + return nil + } + + clone := &Relation{ + SubjectHID: r.SubjectHID, + SubjectSID: r.SubjectSID, + Resource: r.Resource.Clone(), + RelationType: r.RelationType, + } + + if r.Attributes != nil { + clone.Attributes = make(map[string]string, len(r.Attributes)) + for k, v := range r.Attributes { + clone.Attributes[k] = v + } + } + + return clone +} + +// RelationTypeHierarchy defines the hierarchy of relation types +// Higher values indicate more privilege +var RelationTypeHierarchy = map[RelationType]int{ + RelationReader: 1, + RelationMember: 2, + RelationWriter: 3, + RelationExecutor: 4, + RelationAdmin: 5, + RelationDelegate: 6, + RelationOwner: 7, +} + +// RelationActionPermissions defines which actions each relation type can perform +// This allows for non-linear permission models where e.g., Admin has read+delete but not write +var RelationActionPermissions = map[RelationType]map[Action]bool{ + RelationReader: { + ActionRead: true, + }, + RelationWriter: { + ActionRead: true, + ActionWrite: true, + }, + RelationAdmin: { + ActionRead: true, + ActionDelete: true, + ActionShare: true, + ActionAdmin: true, + }, + RelationExecutor: { + ActionRead: true, + ActionExecute: true, + }, + RelationDelegate: { + ActionRead: true, + ActionWrite: true, + ActionDelete: true, + ActionShare: true, + ActionAdmin: true, + }, + RelationOwner: { + ActionRead: true, + ActionWrite: true, + ActionDelete: true, + ActionShare: true, + ActionExecute: true, + ActionAdmin: true, + }, +} + +// HasPermission checks if a relation type has a specific action permission +func HasPermission(relType RelationType, action Action) bool { + if relType == RelationOwner { + return true // Owner has all permissions + } + perms, ok := RelationActionPermissions[relType] + if !ok { + return false + } + return perms[action] +} + +// HasPrivilege returns true if the given relation type has at least the required privilege level +func HasPrivilege(has, requires RelationType) bool { + if has == 
RelationOwner { + return true // Owner has all privileges + } + if requires == RelationOwner { + return false // Only owner has owner privilege + } + + hasLevel, ok := RelationTypeHierarchy[has] + if !ok { + return false + } + + requiredLevel, ok := RelationTypeHierarchy[requires] + if !ok { + return false + } + + return hasLevel >= requiredLevel +} + +// IsAtLeast returns true if this relation type is at least as privileged as the other +func (r RelationType) IsAtLeast(other RelationType) bool { + return HasPrivilege(r, other) +} + +// IsAtMost returns true if this relation type is at most as privileged as the other +func (r RelationType) IsAtMost(other RelationType) bool { + return HasPrivilege(other, r) +} + +// CommonRelationTypes returns relation types that typically imply read access +func ReadRelationTypes() []RelationType { + return []RelationType{ + RelationReader, + RelationWriter, + RelationAdmin, + RelationDelegate, + RelationOwner, + } +} + +// WriteRelationTypes returns relation types that typically imply write access +func WriteRelationTypes() []RelationType { + return []RelationType{ + RelationWriter, + RelationAdmin, + RelationDelegate, + RelationOwner, + } +} + +// AdminRelationTypes returns relation types with administrative privileges +func AdminRelationTypes() []RelationType { + return []RelationType{ + RelationAdmin, + RelationDelegate, + RelationOwner, + } +} + +// RelationSet is a collection of relations with efficient lookup +type RelationSet struct { + relations []*Relation + // indexes for fast lookup + bySubject map[string][]*Relation + byResource map[string][]*Relation + byType map[RelationType][]*Relation +} + +// NewRelationSet creates a new relation set +func NewRelationSet() *RelationSet { + return &RelationSet{ + relations: make([]*Relation, 0), + bySubject: make(map[string][]*Relation), + byResource: make(map[string][]*Relation), + byType: make(map[RelationType][]*Relation), + } +} + +// Add adds a relation to the set +func (s 
*RelationSet) Add(rel *Relation) { + s.relations = append(s.relations, rel) + + // Update indexes + subjectKey := rel.SubjectHID + if rel.SubjectSID != "" { + subjectKey = rel.SubjectHID + "/" + rel.SubjectSID + } + s.bySubject[subjectKey] = append(s.bySubject[subjectKey], rel) + + resourceKey := rel.Resource.String() + s.byResource[resourceKey] = append(s.byResource[resourceKey], rel) + + s.byType[rel.RelationType] = append(s.byType[rel.RelationType], rel) +} + +// Remove removes relations matching the given criteria +func (s *RelationSet) Remove(subjectHID, subjectSID string, resource *ResourceID, relType RelationType) int { + removed := 0 + newRelations := make([]*Relation, 0, len(s.relations)) + + // Clear indexes + s.bySubject = make(map[string][]*Relation) + s.byResource = make(map[string][]*Relation) + s.byType = make(map[RelationType][]*Relation) + + for _, rel := range s.relations { + if rel.Matches(subjectHID, subjectSID, resource, relType) { + removed++ + continue + } + newRelations = append(newRelations, rel) + + // Rebuild indexes + subjectKey := rel.SubjectHID + if rel.SubjectSID != "" { + subjectKey = rel.SubjectHID + "/" + rel.SubjectSID + } + s.bySubject[subjectKey] = append(s.bySubject[subjectKey], rel) + + resourceKey := rel.Resource.String() + s.byResource[resourceKey] = append(s.byResource[resourceKey], rel) + + s.byType[rel.RelationType] = append(s.byType[rel.RelationType], rel) + } + + s.relations = newRelations + return removed +} + +// FindBySubject finds all relations for a given subject +func (s *RelationSet) FindBySubject(subjectHID, subjectSID string) []*Relation { + key := subjectHID + if subjectSID != "" { + key = subjectHID + "/" + subjectSID + } + + if rels, ok := s.bySubject[key]; ok { + return rels + } + + return nil +} + +// FindByResource finds all relations for a given resource +func (s *RelationSet) FindByResource(resource *ResourceID) []*Relation { + if rels, ok := s.byResource[resource.String()]; ok { + return rels + } + 
return nil +} + +// FindByType finds all relations of a given type +func (s *RelationSet) FindByType(relType RelationType) []*Relation { + if rels, ok := s.byType[relType]; ok { + return rels + } + return nil +} + +// Find finds relations matching all given criteria +func (s *RelationSet) Find(subjectHID, subjectSID string, resource *ResourceID, relType RelationType) []*Relation { + results := make([]*Relation, 0) + + for _, rel := range s.relations { + if rel.Matches(subjectHID, subjectSID, resource, relType) { + results = append(results, rel) + } + } + + return results +} + +// HasAnyRelation returns true if the subject has any relation to the resource +func (s *RelationSet) HasAnyRelation(subjectHID string, resource *ResourceID) bool { + for _, rel := range s.relations { + if rel.SubjectHID == subjectHID && rel.Resource.Matches(resource) { + return true + } + } + return false +} + +// GetHighestPrivilege returns the highest privilege relation type for a subject on a resource +func (s *RelationSet) GetHighestPrivilege(subjectHID, subjectSID string, resource *ResourceID) RelationType { + highest := RelationType("") + + for _, rel := range s.relations { + if !rel.Matches(subjectHID, subjectSID, resource, RelationAny) { + continue + } + + if highest == "" || HasPrivilege(rel.RelationType, highest) { + highest = rel.RelationType + } + } + + return highest +} + +// All returns all relations in the set +func (s *RelationSet) All() []*Relation { + return s.relations +} + +// Len returns the number of relations in the set +func (s *RelationSet) Len() int { + return len(s.relations) +} + +// Clear removes all relations from the set +func (s *RelationSet) Clear() { + s.relations = make([]*Relation, 0) + s.bySubject = make(map[string][]*Relation) + s.byResource = make(map[string][]*Relation) + s.byType = make(map[RelationType][]*Relation) +} + +// Clone creates a deep copy of the relation set +func (s *RelationSet) Clone() *RelationSet { + clone := NewRelationSet() + + for 
_, rel := range s.relations { + clone.Add(rel.Clone()) + } + + return clone +} + +// Merge merges another relation set into this one +func (s *RelationSet) Merge(other *RelationSet) { + if other == nil { + return + } + + for _, rel := range other.All() { + s.Add(rel.Clone()) + } +} diff --git a/pkg/relation/relation_test.go b/pkg/relation/relation_test.go new file mode 100644 index 000000000..3bbaad8d1 --- /dev/null +++ b/pkg/relation/relation_test.go @@ -0,0 +1,760 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package relation + +import ( + "testing" +) + +func TestResourceID_String(t *testing.T) { + tests := []struct { + name string + id *ResourceID + expect string + }{ + { + name: "minimal", + id: NewResourceID(ResourceMemory, "mem-123"), + expect: "memory:mem-123", + }, + { + name: "with HID", + id: NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-123"), + expect: "memory:user-alice:mem-123", + }, + { + name: "with namespace", + id: &ResourceID{Type: ResourceMemory, HID: "user-alice", ID: "mem-123", Namespace: "default"}, + expect: "memory:user-alice:default:mem-123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.id.String(); got != tt.expect { + t.Errorf("ResourceID.String() = %v, want %v", got, tt.expect) + } + }) + } +} + +func TestParseResourceID(t *testing.T) { + tests := []struct { + name string + s string + wantType string + wantHID string + wantID string + wantErr bool + }{ + { + name: "minimal", + s: "memory:mem-123", + wantType: "memory", + wantHID: "", + wantID: "mem-123", + wantErr: false, + }, + { + name: "with HID", + s: "memory:user-alice:mem-123", + wantType: "memory", + wantHID: "user-alice", + wantID: "mem-123", + wantErr: false, + }, + { + name: "invalid", + s: "memory", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseResourceID(tt.s) + if 
tt.wantErr { + if err == nil { + t.Errorf("ParseResourceID() expected error, got nil") + } + return + } + if err != nil { + t.Errorf("ParseResourceID() error = %v", err) + return + } + if got.Type != ResourceType(tt.wantType) { + t.Errorf("Type = %v, want %v", got.Type, tt.wantType) + } + if got.HID != tt.wantHID { + t.Errorf("HID = %v, want %v", got.HID, tt.wantHID) + } + if got.ID != tt.wantID { + t.Errorf("ID = %v, want %v", got.ID, tt.wantID) + } + }) + } +} + +func TestResourceID_Matches(t *testing.T) { + wildcard := NewResourceID(ResourceAny, "*") + specific := NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-123") + + tests := []struct { + name string + a *ResourceID + b *ResourceID + expected bool + }{ + {"exact match", specific, specific, true}, + {"wildcard matches specific", wildcard, specific, true}, + {"specific matches wildcard", specific, wildcard, false}, + {"different type", specific, NewResourceID(ResourceNode, "node-01"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.a.Matches(tt.b); got != tt.expected { + t.Errorf("ResourceID.Matches() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestRelation_String(t *testing.T) { + rel := NewRelation( + "user-alice", + NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-123"), + RelationReader, + ) + + want := "user-alice reader memory:user-alice:mem-123" + if got := rel.String(); got != want { + t.Errorf("Relation.String() = %v, want %v", got, want) + } +} + +func TestRelation_Clone(t *testing.T) { + original := NewRelation( + "user-alice", + NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-123"), + RelationReader, + ) + original.SetAttribute("key1", "value1") + + clone := original.Clone() + + if clone.SubjectHID != original.SubjectHID { + t.Errorf("Clone.SubjectHID mismatch") + } + if clone.RelationType != original.RelationType { + t.Errorf("Clone.RelationType mismatch") + } + + // Verify attribute was copied + v, ok := 
clone.GetAttribute("key1") + if !ok || v != "value1" { + t.Errorf("Clone attributes not copied") + } + + // Modify clone and ensure original is unchanged + clone.SetAttribute("key2", "value2") + _, ok = original.GetAttribute("key2") + if ok { + t.Errorf("Modifying clone affected original") + } +} + +func TestRelationTypeHierarchy(t *testing.T) { + tests := []struct { + has RelationType + requires RelationType + expected bool + }{ + {RelationOwner, RelationReader, true}, + {RelationOwner, RelationWriter, true}, + {RelationOwner, RelationAdmin, true}, + {RelationAdmin, RelationReader, true}, + // Note: Admin has higher hierarchy level than Writer, but doesn't have write permission + // This is checked via ActionAllowedByRelation, not HasPrivilege + {RelationReader, RelationWriter, false}, + {RelationWriter, RelationAdmin, false}, + } + + for _, tt := range tests { + t.Run(string(tt.has)+"_has_"+string(tt.requires), func(t *testing.T) { + if got := HasPrivilege(tt.has, tt.requires); got != tt.expected { + t.Errorf("HasPrivilege(%s, %s) = %v, want %v", tt.has, tt.requires, got, tt.expected) + } + }) + } +} + +func TestNewRelationSet(t *testing.T) { + rs := NewRelationSet() + + if rs.Len() != 0 { + t.Errorf("NewRelationSet should be empty") + } + + rel := NewRelation("user-alice", NewResourceID(ResourceMemory, "mem-123"), RelationReader) + rs.Add(rel) + + if rs.Len() != 1 { + t.Errorf("Expected 1 relation, got %d", rs.Len()) + } +} + +func TestRelationSet_FindBySubject(t *testing.T) { + rs := NewRelationSet() + + rel1 := NewRelation("user-alice", NewResourceID(ResourceMemory, "mem-1"), RelationReader) + rel2 := NewRelation("user-alice", NewResourceID(ResourceMemory, "mem-2"), RelationWriter) + rel3 := NewRelation("user-bob", NewResourceID(ResourceMemory, "mem-3"), RelationReader) + + rs.Add(rel1) + rs.Add(rel2) + rs.Add(rel3) + + // Find user-alice's relations + relations := rs.FindBySubject("user-alice", "") + if len(relations) != 2 { + t.Errorf("Expected 2 relations 
for user-alice, got %d", len(relations)) + } + + // Find specific S-id relations + relations = rs.FindBySubject("user-alice", "node-01") + if len(relations) != 0 { + t.Errorf("Expected 0 relations for user-alice/node-01, got %d", len(relations)) + } +} + +func TestRelationSet_Remove(t *testing.T) { + rs := NewRelationSet() + + rel := NewRelation("user-alice", NewResourceID(ResourceMemory, "mem-123"), RelationReader) + rs.Add(rel) + + if rs.Len() != 1 { + t.Errorf("Expected 1 relation, got %d", rs.Len()) + } + + // Remove the relation + removed := rs.Remove("user-alice", "", rel.Resource, RelationReader) + if removed != 1 { + t.Errorf("Expected to remove 1 relation, got %d", removed) + } + + if rs.Len() != 0 { + t.Errorf("Expected 0 relations after removal, got %d", rs.Len()) + } +} + +func TestRelationSet_Merge(t *testing.T) { + rs1 := NewRelationSet() + rs2 := NewRelationSet() + + rel1 := NewRelation("user-alice", NewResourceID(ResourceMemory, "mem-1"), RelationReader) + rel2 := NewRelation("user-bob", NewResourceID(ResourceMemory, "mem-2"), RelationWriter) + + rs1.Add(rel1) + rs2.Add(rel2) + + rs1.Merge(rs2) + + if rs1.Len() != 2 { + t.Errorf("Expected 2 relations after merge, got %d", rs1.Len()) + } +} + +func TestRegistry_Add(t *testing.T) { + registry := NewRegistry() + + rel := NewRelation("user-alice", NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-123"), RelationReader) + + err := registry.Add(rel) + if err != nil { + t.Errorf("Add() error = %v", err) + } + + if registry.Count() != 1 { + t.Errorf("Expected 1 relation, got %d", registry.Count()) + } + + // Duplicate should fail + err = registry.Add(rel) + if err == nil { + t.Errorf("Expected error for duplicate relation") + } +} + +func TestRegistry_Update(t *testing.T) { + registry := NewRegistry() + + rel := NewRelation("user-alice", NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-123"), RelationReader) + registry.Add(rel) + + // Update to writer + rel.RelationType = RelationWriter + err 
:= registry.Update(rel) + if err != nil { + t.Errorf("Update() error = %v", err) + } + + // Check updated relation + privilege := registry.GetPrivilege("user-alice", "", rel.Resource) + if privilege != RelationWriter { + t.Errorf("Expected RelationWriter after update, got %s", privilege) + } +} + +func TestRegistry_GetPrivilege(t *testing.T) { + registry := NewRegistry() + + rel1 := NewRelation("user-alice", NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-1"), RelationReader) + rel2 := NewRelation("user-alice", NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-1"), RelationWriter) + + registry.Add(rel1) + registry.Add(rel2) + + // Get highest privilege + privilege := registry.GetPrivilege("user-alice", "", rel1.Resource) + if privilege != RelationWriter { + t.Errorf("Expected RelationWriter, got %s", privilege) + } +} + +func TestRegistry_HasPrivilege(t *testing.T) { + registry := NewRegistry() + + rel := NewRelation("user-alice", NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-1"), RelationWriter) + registry.Add(rel) + + // Check has privilege + if !registry.HasPrivilege("user-alice", "", rel.Resource, RelationReader) { + t.Errorf("Expected to have reader privilege") + } + + if !registry.HasPrivilege("user-alice", "", rel.Resource, RelationWriter) { + t.Errorf("Expected to have writer privilege") + } + + // Should not have admin privilege + if registry.HasPrivilege("user-alice", "", rel.Resource, RelationAdmin) { + t.Errorf("Expected to not have admin privilege") + } +} + +func TestRegistry_Clear(t *testing.T) { + registry := NewRegistry() + + rel := NewRelation("user-alice", NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-1"), RelationReader) + registry.Add(rel) + + if registry.Count() != 1 { + t.Errorf("Expected 1 relation, got %d", registry.Count()) + } + + registry.Clear() + + if registry.Count() != 0 { + t.Errorf("Expected 0 relations after clear, got %d", registry.Count()) + } +} + +func TestStatement_Matches(t *testing.T) { + 
resource := NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-123") + + tests := []struct { + name string + stmt *Statement + action Action + resource *ResourceID + expected bool + }{ + { + name: "exact match", + stmt: NewAllowStatement(ActionRead, resource), + action: ActionRead, + resource: resource, + expected: true, + }, + { + name: "action mismatch", + stmt: NewAllowStatement(ActionRead, resource), + action: ActionWrite, + resource: resource, + expected: false, + }, + { + name: "wildcard action", + stmt: &Statement{Effect: EffectAllow, Action: ActionAny, Resource: resource}, + action: ActionRead, + resource: resource, + expected: true, + }, + { + name: "wildcard resource", + stmt: &Statement{Effect: EffectAllow, Action: ActionRead, Resource: NewResourceID(ResourceAny, "*")}, + action: ActionRead, + resource: resource, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.stmt.Matches(tt.action, tt.resource); got != tt.expected { + t.Errorf("Statement.Matches() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestPolicy_Evaluate(t *testing.T) { + policy := NewPolicy("test-policy") + + resource := NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-123") + + // Add allow statement + policy.AddAllow(ActionRead, resource) + + effect := policy.Evaluate(ActionRead, resource) + if effect != EffectAllow { + t.Errorf("Expected allow for read, got %s", effect) + } + + effect = policy.Evaluate(ActionWrite, resource) + if effect != EffectDeny { + t.Errorf("Expected deny for write, got %s", effect) + } + + // Add deny statement + policy.AddDeny(ActionWrite, resource) + + effect = policy.Evaluate(ActionWrite, resource) + if effect != EffectDeny { + t.Errorf("Expected deny (explicit), got %s", effect) + } + + // Deny should take precedence over allow + policy.AddAllow(ActionWrite, resource) + effect = policy.Evaluate(ActionWrite, resource) + if effect != EffectDeny { + t.Errorf("Expected deny (deny 
takes precedence), got %s", effect) + } +} + +func TestPolicy_IsInScope(t *testing.T) { + policy := NewPolicy("test-policy") + policy.Scope = []string{"user-alice", "user-bob"} + + if !policy.IsInScope("user-alice") { + t.Errorf("Expected user-alice in scope") + } + + if policy.IsInScope("user-charlie") { + t.Errorf("Expected user-charlie not in scope") + } +} + +func TestPolicySet_Evaluate(t *testing.T) { + ps := NewPolicySet() + + policy1 := NewPolicy("policy-1") + policy1.AddAllow(ActionRead, NewResourceID(ResourceMemory, "mem-1")) + + policy2 := NewPolicy("policy-2") + policy2.AddDeny(ActionWrite, NewResourceID(ResourceMemory, "mem-1")) + + ps.Add(policy1) + ps.Add(policy2) + + resource := NewResourceID(ResourceMemory, "mem-1") + + effect := ps.Evaluate("user-alice", ActionRead, resource) + if effect != EffectAllow { + t.Errorf("Expected allow for read, got %s", effect) + } + + effect = ps.Evaluate("user-alice", ActionWrite, resource) + if effect != EffectDeny { + t.Errorf("Expected deny for write, got %s", effect) + } +} + +func TestAuthorizer_Authorize(t *testing.T) { + authz := NewAuthorizer() + + resource := NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-123") + + // Owner should be allowed + req := &AuthzRequest{ + SubjectHID: "user-alice", + Action: ActionRead, + Resource: resource, + } + + result := authz.Authorize(req) + if !result.Allowed { + t.Errorf("Owner should be allowed, reason: %s", result.Reason) + } + + // Non-owner should be denied by default + req.SubjectHID = "user-bob" + result = authz.Authorize(req) + if result.Allowed { + t.Errorf("Non-owner should be denied, reason: %s", result.Reason) + } + + // Grant relation to user-bob + err := authz.Grant("user-bob", "", resource, RelationReader) + if err != nil { + t.Errorf("Grant() error = %v", err) + } + + // Now user-bob should be allowed for read + req.SubjectHID = "user-bob" + req.Action = ActionRead + result = authz.Authorize(req) + if !result.Allowed { + t.Errorf("user-bob should 
be allowed as reader, reason: %s", result.Reason) + } + + // But not for write + req.Action = ActionWrite + result = authz.Authorize(req) + if result.Allowed { + t.Errorf("user-bob should not be allowed as writer, reason: %s", result.Reason) + } +} + +func TestAuthorizer_ConvenienceMethods(t *testing.T) { + authz := NewAuthorizer() + + resource := NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-123") + + // Grant reader relation + authz.Grant("user-bob", "", resource, RelationReader) + + tests := []struct { + name string + method string + subject string + allowed bool + }{ + {"owner read", "CanRead", "user-alice", true}, + {"owner write", "CanWrite", "user-alice", true}, + {"owner delete", "CanDelete", "user-alice", true}, + {"owner execute", "CanExecute", "user-alice", true}, + {"owner admin", "CanAdmin", "user-alice", true}, + {"owner share", "CanShare", "user-alice", true}, + {"reader read", "CanRead", "user-bob", true}, + {"reader write", "CanWrite", "user-bob", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var allowed bool + switch tt.method { + case "CanRead": + allowed = authz.CanRead(tt.subject, "", resource) + case "CanWrite": + allowed = authz.CanWrite(tt.subject, "", resource) + case "CanDelete": + allowed = authz.CanDelete(tt.subject, "", resource) + case "CanExecute": + allowed = authz.CanExecute(tt.subject, "", resource) + case "CanAdmin": + allowed = authz.CanAdmin(tt.subject, "", resource) + case "CanShare": + allowed = authz.CanShare(tt.subject, "", resource) + } + + if allowed != tt.allowed { + t.Errorf("%s() = %v, want %v", tt.method, allowed, tt.allowed) + } + }) + } +} + +func TestAuthorizer_FilterAuthorized(t *testing.T) { + authz := NewAuthorizer() + + resource1 := NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-1") + resource2 := NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-2") // Same owner + + // Grant read access to user-bob for resource1 + authz.Grant("user-bob", "", 
resource1, RelationReader) + + // Grant read access to user-bob for resource2 + authz.Grant("user-bob", "", resource2, RelationReader) + + resources := []*ResourceID{resource1, resource2} + authorized := authz.FilterAuthorized("user-bob", "", resources, ActionRead) + + if len(authorized) != 2 { + t.Errorf("Expected 2 authorized resources, got %d", len(authorized)) + } + + // Filter by write should return none (user-bob only has Reader, not owner) + authorized = authz.FilterAuthorized("user-bob", "", resources, ActionWrite) + if len(authorized) != 0 { + t.Errorf("Expected 0 authorized resources for write, got %d", len(authorized)) + } +} + +func TestAuthorizer_Stats(t *testing.T) { + authz := NewAuthorizer() + + stats := authz.Stats() + + if stats["relation_count"] != 0 { + t.Errorf("Expected 0 relations, got %v", stats["relation_count"]) + } + + resource := NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-1") + authz.Grant("user-bob", "", resource, RelationReader) + + stats = authz.Stats() + + if stats["relation_count"] != 1 { + t.Errorf("Expected 1 relation, got %v", stats["relation_count"]) + } +} + +func TestActionAllowedByRelation(t *testing.T) { + tests := []struct { + action Action + relType RelationType + expected bool + }{ + {ActionRead, RelationReader, true}, + {ActionRead, RelationWriter, true}, + {ActionRead, RelationAdmin, true}, + {ActionWrite, RelationReader, false}, + {ActionWrite, RelationWriter, true}, + {ActionDelete, RelationAdmin, true}, + {ActionDelete, RelationWriter, false}, + {ActionExecute, RelationExecutor, true}, + {ActionExecute, RelationWriter, false}, + } + + for _, tt := range tests { + t.Run(string(tt.action)+"_"+string(tt.relType), func(t *testing.T) { + if got := ActionAllowedByRelation(tt.action, tt.relType); got != tt.expected { + t.Errorf("ActionAllowedByRelation(%s, %s) = %v, want %v", tt.action, tt.relType, got, tt.expected) + } + }) + } +} + +func TestParseEffect(t *testing.T) { + tests := []struct { + input string + 
want Effect + }{ + {"allow", EffectAllow}, + {"ALLOW", EffectAllow}, + {"deny", EffectDeny}, + {"DENY", EffectDeny}, + {"invalid", EffectDeny}, // Returns deny on invalid + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got, err := ParseEffect(tt.input) + if tt.input == "invalid" && err == nil { + t.Errorf("Expected error for invalid input") + } + if tt.input != "invalid" && got != tt.want { + t.Errorf("ParseEffect(%s) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestNewRelationWithSID(t *testing.T) { + rel := NewRelationWithSID( + "user-alice", + "node-01", + NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-123"), + RelationReader, + ) + + if rel.SubjectHID != "user-alice" { + t.Errorf("SubjectHID = %v, want user-alice", rel.SubjectHID) + } + if rel.SubjectSID != "node-01" { + t.Errorf("SubjectSID = %v, want node-01", rel.SubjectSID) + } + + if !rel.IsDirect() { + t.Errorf("Expected IsDirect to be true") + } + if rel.IsTenant() { + t.Errorf("Expected IsTenant to be false") + } +} + +func TestNewRelation(t *testing.T) { + rel := NewRelation( + "user-alice", + NewResourceIDWithHID(ResourceMemory, "user-alice", "mem-123"), + RelationReader, + ) + + if rel.SubjectSID != "" { + t.Errorf("Expected empty SubjectSID") + } + + if rel.IsDirect() { + t.Errorf("Expected IsDirect to be false") + } + if !rel.IsTenant() { + t.Errorf("Expected IsTenant to be true") + } +} + +func TestResource_SetGetAttribute(t *testing.T) { + resource := NewResource(ResourceMemory, "user-alice", "mem-123") + + resource.SetAttribute("key1", "value1") + resource.SetAttribute("key2", "value2") + + v, ok := resource.GetAttribute("key1") + if !ok || v != "value1" { + t.Errorf("GetAttribute failed") + } + + _, ok = resource.GetAttribute("missing") + if ok { + t.Errorf("Expected false for missing key") + } +} + +func TestResource_Tags(t *testing.T) { + resource := NewResource(ResourceMemory, "user-alice", "mem-123") + + resource.AddTag("tag1") + 
resource.AddTag("tag2") + resource.AddTag("tag1") // Duplicate + + if !resource.HasTag("tag1") { + t.Errorf("Expected tag1") + } + if !resource.HasTag("tag2") { + t.Errorf("Expected tag2") + } + if len(resource.Tags) != 2 { + t.Errorf("Expected 2 tags, got %d", len(resource.Tags)) + } + + resource.RemoveTag("tag1") + if resource.HasTag("tag1") { + t.Errorf("tag1 should be removed") + } +} diff --git a/pkg/relation/resource.go b/pkg/relation/resource.go new file mode 100644 index 000000000..eea97721a --- /dev/null +++ b/pkg/relation/resource.go @@ -0,0 +1,273 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package relation + +import ( + "fmt" + "strings" +) + +// ResourceType defines the type of resource +type ResourceType string + +const ( + // ResourceMemory represents a memory item + ResourceMemory ResourceType = "memory" + + // ResourceNode represents a swarm node + ResourceNode ResourceType = "node" + + // ResourceTask represents a swarm task + ResourceTask ResourceType = "task" + + // ResourceChannel represents a communication channel + ResourceChannel ResourceType = "channel" + + // ResourceConfig represents a configuration item + ResourceConfig ResourceType = "config" + + // ResourceSkill represents a skill/ability + ResourceSkill ResourceType = "skill" + + // ResourceWorkflow represents a workflow definition + ResourceWorkflow ResourceType = "workflow" + + // ResourceAny is a wildcard for any resource type + ResourceAny ResourceType = "*" +) + +// ResourceID is a unique identifier for a resource +type ResourceID struct { + // Type is the type of resource + Type ResourceType `json:"type"` + + // HID is the tenant/owner H-id + HID string `json:"hid,omitempty"` + + // ID is the specific resource identifier + ID string `json:"id"` + + // Namespace is an optional namespace for the resource + Namespace string `json:"namespace,omitempty"` +} + +// NewResourceID creates a new resource ID 
+func NewResourceID(typ ResourceType, id string) *ResourceID { + return &ResourceID{ + Type: typ, + ID: id, + } +} + +// NewResourceIDWithHID creates a new resource ID with H-id +func NewResourceIDWithHID(typ ResourceType, hid, id string) *ResourceID { + return &ResourceID{ + Type: typ, + HID: hid, + ID: id, + } +} + +// String returns the string representation of the resource ID +func (r *ResourceID) String() string { + parts := []string{} + if r.Type != "" { + parts = append(parts, string(r.Type)) + } + if r.HID != "" { + parts = append(parts, r.HID) + } + if r.Namespace != "" { + parts = append(parts, r.Namespace) + } + if r.ID != "" { + parts = append(parts, r.ID) + } + return strings.Join(parts, ":") +} + +// ParseResourceID parses a resource ID from a string +// Format: type:hid:namespace:id or type:hid:id or type:id +func ParseResourceID(s string) (*ResourceID, error) { + parts := strings.Split(s, ":") + if len(parts) < 2 { + return nil, fmt.Errorf("invalid resource ID format: %s", s) + } + + r := &ResourceID{} + + switch len(parts) { + case 2: + r.Type = ResourceType(parts[0]) + r.ID = parts[1] + case 3: + r.Type = ResourceType(parts[0]) + // Could be hid:id or namespace:id + // Assume it's hid:id for backward compatibility + r.HID = parts[1] + r.ID = parts[2] + case 4: + r.Type = ResourceType(parts[0]) + r.HID = parts[1] + r.Namespace = parts[2] + r.ID = parts[3] + default: + return nil, fmt.Errorf("invalid resource ID format: %s", s) + } + + return r, nil +} + +// IsWildcard returns true if this resource ID is a wildcard +func (r *ResourceID) IsWildcard() bool { + return r.Type == ResourceAny || r.ID == "*" +} + +// Matches returns true if this resource ID matches the target +func (r *ResourceID) Matches(target *ResourceID) bool { + if r == nil || target == nil { + return false + } + + // Type match (with wildcard support) + if r.Type != ResourceAny && target.Type != ResourceAny && r.Type != target.Type { + return false + } + + // HID match (with wildcard 
support) + if r.HID != "" && r.HID != "*" && target.HID != "" && r.HID != target.HID { + return false + } + + // Namespace match + if r.Namespace != "" && target.Namespace != "" && r.Namespace != target.Namespace { + return false + } + + // ID match (with wildcard support) + if r.ID != "" && r.ID != "*" && target.ID != "" && r.ID != target.ID { + return false + } + + return true +} + +// Clone returns a copy of the resource ID +func (r *ResourceID) Clone() *ResourceID { + if r == nil { + return nil + } + return &ResourceID{ + Type: r.Type, + HID: r.HID, + ID: r.ID, + Namespace: r.Namespace, + } +} + +// Resource represents a securable resource in the system +type Resource struct { + // ID is the unique identifier for this resource + ID *ResourceID `json:"id"` + + // OwnerHID is the H-id that owns this resource + OwnerHID string `json:"owner_hid"` + + // OwnerSID is the S-id that created this resource + OwnerSID string `json:"owner_sid,omitempty"` + + // Attributes contains additional resource attributes + Attributes map[string]string `json:"attributes,omitempty"` + + // Tags for resource categorization + Tags []string `json:"tags,omitempty"` +} + +// NewResource creates a new resource +func NewResource(typ ResourceType, hid, id string) *Resource { + return &Resource{ + ID: NewResourceIDWithHID(typ, hid, id), + OwnerHID: hid, + Attributes: make(map[string]string), + Tags: make([]string, 0), + } +} + +// GetAttribute returns an attribute value +func (r *Resource) GetAttribute(key string) (string, bool) { + if r.Attributes == nil { + return "", false + } + v, ok := r.Attributes[key] + return v, ok +} + +// SetAttribute sets an attribute value +func (r *Resource) SetAttribute(key, value string) { + if r.Attributes == nil { + r.Attributes = make(map[string]string) + } + r.Attributes[key] = value +} + +// HasTag checks if the resource has a specific tag +func (r *Resource) HasTag(tag string) bool { + for _, t := range r.Tags { + if t == tag { + return true + } + } + 
return false +} + +// AddTag adds a tag to the resource +func (r *Resource) AddTag(tag string) { + for _, t := range r.Tags { + if t == tag { + return + } + } + r.Tags = append(r.Tags, tag) +} + +// RemoveTag removes a tag from the resource +func (r *Resource) RemoveTag(tag string) { + for i, t := range r.Tags { + if t == tag { + r.Tags = append(r.Tags[:i], r.Tags[i+1:]...) + return + } + } +} + +// Matches checks if this resource matches the given filter +type ResourceFilter struct { + Type ResourceType + HID string + Namespace string + ID string + Tag string +} + +// Matches checks if a resource matches the filter +func (f *ResourceFilter) Matches(r *Resource) bool { + if f.Type != "" && f.Type != ResourceAny && r.ID.Type != f.Type { + return false + } + if f.HID != "" && r.OwnerHID != f.HID { + return false + } + if f.Namespace != "" && r.ID.Namespace != f.Namespace { + return false + } + if f.ID != "" && r.ID.ID != f.ID { + return false + } + if f.Tag != "" && !r.HasTag(f.Tag) { + return false + } + return true +} diff --git a/pkg/routing/agent_id.go b/pkg/routing/agent_id.go new file mode 100644 index 000000000..bcf2f0dc0 --- /dev/null +++ b/pkg/routing/agent_id.go @@ -0,0 +1,66 @@ +package routing + +import ( + "regexp" + "strings" +) + +const ( + DefaultAgentID = "main" + DefaultMainKey = "main" + DefaultAccountID = "default" + MaxAgentIDLength = 64 +) + +var ( + validIDRe = regexp.MustCompile(`^[a-z0-9][a-z0-9_-]{0,63}$`) + invalidCharsRe = regexp.MustCompile(`[^a-z0-9_-]+`) + leadingDashRe = regexp.MustCompile(`^-+`) + trailingDashRe = regexp.MustCompile(`-+$`) +) + +// NormalizeAgentID sanitizes an agent ID to [a-z0-9][a-z0-9_-]{0,63}. +// Invalid characters are collapsed to "-". Leading/trailing dashes stripped. +// Empty input returns DefaultAgentID ("main"). 
+func NormalizeAgentID(id string) string { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + return DefaultAgentID + } + lower := strings.ToLower(trimmed) + if validIDRe.MatchString(lower) { + return lower + } + result := invalidCharsRe.ReplaceAllString(lower, "-") + result = leadingDashRe.ReplaceAllString(result, "") + result = trailingDashRe.ReplaceAllString(result, "") + if len(result) > MaxAgentIDLength { + result = result[:MaxAgentIDLength] + } + if result == "" { + return DefaultAgentID + } + return result +} + +// NormalizeAccountID sanitizes an account ID. Empty returns DefaultAccountID. +func NormalizeAccountID(id string) string { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + return DefaultAccountID + } + lower := strings.ToLower(trimmed) + if validIDRe.MatchString(lower) { + return lower + } + result := invalidCharsRe.ReplaceAllString(lower, "-") + result = leadingDashRe.ReplaceAllString(result, "") + result = trailingDashRe.ReplaceAllString(result, "") + if len(result) > MaxAgentIDLength { + result = result[:MaxAgentIDLength] + } + if result == "" { + return DefaultAccountID + } + return result +} diff --git a/pkg/routing/agent_id_test.go b/pkg/routing/agent_id_test.go new file mode 100644 index 000000000..050fe0645 --- /dev/null +++ b/pkg/routing/agent_id_test.go @@ -0,0 +1,86 @@ +package routing + +import "testing" + +func TestNormalizeAgentID_Empty(t *testing.T) { + if got := NormalizeAgentID(""); got != DefaultAgentID { + t.Errorf("NormalizeAgentID('') = %q, want %q", got, DefaultAgentID) + } +} + +func TestNormalizeAgentID_Whitespace(t *testing.T) { + if got := NormalizeAgentID(" "); got != DefaultAgentID { + t.Errorf("NormalizeAgentID(' ') = %q, want %q", got, DefaultAgentID) + } +} + +func TestNormalizeAgentID_Valid(t *testing.T) { + tests := []struct { + input, want string + }{ + {"main", "main"}, + {"Main", "main"}, + {"SALES", "sales"}, + {"support-bot", "support-bot"}, + {"agent_1", "agent_1"}, + {"a", "a"}, + {"0test", 
"0test"}, + } + for _, tt := range tests { + if got := NormalizeAgentID(tt.input); got != tt.want { + t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestNormalizeAgentID_InvalidChars(t *testing.T) { + tests := []struct { + input, want string + }{ + {"Hello World", "hello-world"}, + {"agent@123", "agent-123"}, + {"foo.bar.baz", "foo-bar-baz"}, + {"--leading", "leading"}, + {"--both--", "both"}, + } + for _, tt := range tests { + if got := NormalizeAgentID(tt.input); got != tt.want { + t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestNormalizeAgentID_AllInvalid(t *testing.T) { + if got := NormalizeAgentID("@@@"); got != DefaultAgentID { + t.Errorf("NormalizeAgentID('@@@') = %q, want %q", got, DefaultAgentID) + } +} + +func TestNormalizeAgentID_TruncatesAt64(t *testing.T) { + long := "" + for i := 0; i < 100; i++ { + long += "a" + } + got := NormalizeAgentID(long) + if len(got) > MaxAgentIDLength { + t.Errorf("length = %d, want <= %d", len(got), MaxAgentIDLength) + } +} + +func TestNormalizeAccountID_Empty(t *testing.T) { + if got := NormalizeAccountID(""); got != DefaultAccountID { + t.Errorf("NormalizeAccountID('') = %q, want %q", got, DefaultAccountID) + } +} + +func TestNormalizeAccountID_Valid(t *testing.T) { + if got := NormalizeAccountID("MyBot"); got != "mybot" { + t.Errorf("NormalizeAccountID('MyBot') = %q, want 'mybot'", got) + } +} + +func TestNormalizeAccountID_InvalidChars(t *testing.T) { + if got := NormalizeAccountID("bot@home"); got != "bot-home" { + t.Errorf("NormalizeAccountID('bot@home') = %q, want 'bot-home'", got) + } +} diff --git a/pkg/routing/route.go b/pkg/routing/route.go new file mode 100644 index 000000000..9eb060c53 --- /dev/null +++ b/pkg/routing/route.go @@ -0,0 +1,252 @@ +package routing + +import ( + "strings" + + "github.com/sipeed/picoclaw/pkg/config" +) + +// RouteInput contains the routing context from an inbound message. 
+type RouteInput struct { + Channel string + AccountID string + Peer *RoutePeer + ParentPeer *RoutePeer + GuildID string + TeamID string +} + +// ResolvedRoute is the result of agent routing. +type ResolvedRoute struct { + AgentID string + Channel string + AccountID string + SessionKey string + MainSessionKey string + MatchedBy string // "binding.peer", "binding.peer.parent", "binding.guild", "binding.team", "binding.account", "binding.channel", "default" +} + +// RouteResolver determines which agent handles a message based on config bindings. +type RouteResolver struct { + cfg *config.Config +} + +// NewRouteResolver creates a new route resolver. +func NewRouteResolver(cfg *config.Config) *RouteResolver { + return &RouteResolver{cfg: cfg} +} + +// ResolveRoute determines which agent handles the message and constructs session keys. +// Implements the 7-level priority cascade: +// peer > parent_peer > guild > team > account > channel_wildcard > default +func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute { + channel := strings.ToLower(strings.TrimSpace(input.Channel)) + accountID := NormalizeAccountID(input.AccountID) + peer := input.Peer + + dmScope := DMScope(r.cfg.Session.DMScope) + if dmScope == "" { + dmScope = DMScopeMain + } + identityLinks := r.cfg.Session.IdentityLinks + + bindings := r.filterBindings(channel, accountID) + + choose := func(agentID string, matchedBy string) ResolvedRoute { + resolvedAgentID := r.pickAgentID(agentID) + sessionKey := strings.ToLower(BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: resolvedAgentID, + Channel: channel, + AccountID: accountID, + Peer: peer, + DMScope: dmScope, + IdentityLinks: identityLinks, + })) + mainSessionKey := strings.ToLower(BuildAgentMainSessionKey(resolvedAgentID)) + return ResolvedRoute{ + AgentID: resolvedAgentID, + Channel: channel, + AccountID: accountID, + SessionKey: sessionKey, + MainSessionKey: mainSessionKey, + MatchedBy: matchedBy, + } + } + + // Priority 1: Peer binding 
+ if peer != nil && strings.TrimSpace(peer.ID) != "" { + if match := r.findPeerMatch(bindings, peer); match != nil { + return choose(match.AgentID, "binding.peer") + } + } + + // Priority 2: Parent peer binding + parentPeer := input.ParentPeer + if parentPeer != nil && strings.TrimSpace(parentPeer.ID) != "" { + if match := r.findPeerMatch(bindings, parentPeer); match != nil { + return choose(match.AgentID, "binding.peer.parent") + } + } + + // Priority 3: Guild binding + guildID := strings.TrimSpace(input.GuildID) + if guildID != "" { + if match := r.findGuildMatch(bindings, guildID); match != nil { + return choose(match.AgentID, "binding.guild") + } + } + + // Priority 4: Team binding + teamID := strings.TrimSpace(input.TeamID) + if teamID != "" { + if match := r.findTeamMatch(bindings, teamID); match != nil { + return choose(match.AgentID, "binding.team") + } + } + + // Priority 5: Account binding + if match := r.findAccountMatch(bindings); match != nil { + return choose(match.AgentID, "binding.account") + } + + // Priority 6: Channel wildcard binding + if match := r.findChannelWildcardMatch(bindings); match != nil { + return choose(match.AgentID, "binding.channel") + } + + // Priority 7: Default agent + return choose(r.resolveDefaultAgentID(), "default") +} + +func (r *RouteResolver) filterBindings(channel, accountID string) []config.AgentBinding { + var filtered []config.AgentBinding + for _, b := range r.cfg.Bindings { + matchChannel := strings.ToLower(strings.TrimSpace(b.Match.Channel)) + if matchChannel == "" || matchChannel != channel { + continue + } + if !matchesAccountID(b.Match.AccountID, accountID) { + continue + } + filtered = append(filtered, b) + } + return filtered +} + +func matchesAccountID(matchAccountID, actual string) bool { + trimmed := strings.TrimSpace(matchAccountID) + if trimmed == "" { + return actual == DefaultAccountID + } + if trimmed == "*" { + return true + } + return strings.ToLower(trimmed) == strings.ToLower(actual) +} + +func (r 
*RouteResolver) findPeerMatch(bindings []config.AgentBinding, peer *RoutePeer) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + if b.Match.Peer == nil { + continue + } + peerKind := strings.ToLower(strings.TrimSpace(b.Match.Peer.Kind)) + peerID := strings.TrimSpace(b.Match.Peer.ID) + if peerKind == "" || peerID == "" { + continue + } + if peerKind == strings.ToLower(peer.Kind) && peerID == peer.ID { + return b + } + } + return nil +} + +func (r *RouteResolver) findGuildMatch(bindings []config.AgentBinding, guildID string) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + matchGuild := strings.TrimSpace(b.Match.GuildID) + if matchGuild != "" && matchGuild == guildID { + return &bindings[i] + } + } + return nil +} + +func (r *RouteResolver) findTeamMatch(bindings []config.AgentBinding, teamID string) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + matchTeam := strings.TrimSpace(b.Match.TeamID) + if matchTeam != "" && matchTeam == teamID { + return &bindings[i] + } + } + return nil +} + +func (r *RouteResolver) findAccountMatch(bindings []config.AgentBinding) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + accountID := strings.TrimSpace(b.Match.AccountID) + if accountID == "*" { + continue + } + if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" { + continue + } + return &bindings[i] + } + return nil +} + +func (r *RouteResolver) findChannelWildcardMatch(bindings []config.AgentBinding) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + accountID := strings.TrimSpace(b.Match.AccountID) + if accountID != "*" { + continue + } + if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" { + continue + } + return &bindings[i] + } + return nil +} + +func (r *RouteResolver) pickAgentID(agentID string) string { + trimmed := strings.TrimSpace(agentID) + if trimmed == "" { + return NormalizeAgentID(r.resolveDefaultAgentID()) + 
} + normalized := NormalizeAgentID(trimmed) + agents := r.cfg.Agents.List + if len(agents) == 0 { + return normalized + } + for _, a := range agents { + if NormalizeAgentID(a.ID) == normalized { + return normalized + } + } + return NormalizeAgentID(r.resolveDefaultAgentID()) +} + +func (r *RouteResolver) resolveDefaultAgentID() string { + agents := r.cfg.Agents.List + if len(agents) == 0 { + return DefaultAgentID + } + for _, a := range agents { + if a.Default { + id := strings.TrimSpace(a.ID) + if id != "" { + return NormalizeAgentID(id) + } + } + } + if id := strings.TrimSpace(agents[0].ID); id != "" { + return NormalizeAgentID(id) + } + return DefaultAgentID +} diff --git a/pkg/routing/route_test.go b/pkg/routing/route_test.go new file mode 100644 index 000000000..8255db5f9 --- /dev/null +++ b/pkg/routing/route_test.go @@ -0,0 +1,297 @@ +package routing + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *config.Config { + return &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: "/tmp/picoclaw-test", + Model: "gpt-4", + }, + List: agents, + }, + Bindings: bindings, + Session: config.SessionConfig{ + DMScope: "per-peer", + }, + } +} + +func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) { + cfg := testConfig(nil, nil) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + }) + + if route.AgentID != DefaultAgentID { + t.Errorf("AgentID = %q, want %q", route.AgentID, DefaultAgentID) + } + if route.MatchedBy != "default" { + t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy) + } +} + +func TestResolveRoute_PeerBinding(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "sales", Default: true}, + {ID: "support"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "support", + Match: config.BindingMatch{ + Channel: 
"telegram", + AccountID: "*", + Peer: &config.PeerMatch{Kind: "direct", ID: "user123"}, + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + }) + + if route.AgentID != "support" { + t.Errorf("AgentID = %q, want 'support'", route.AgentID) + } + if route.MatchedBy != "binding.peer" { + t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy) + } +} + +func TestResolveRoute_GuildBinding(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "general", Default: true}, + {ID: "gaming"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "gaming", + Match: config.BindingMatch{ + Channel: "discord", + AccountID: "*", + GuildID: "guild-abc", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "discord", + GuildID: "guild-abc", + Peer: &RoutePeer{Kind: "channel", ID: "ch1"}, + }) + + if route.AgentID != "gaming" { + t.Errorf("AgentID = %q, want 'gaming'", route.AgentID) + } + if route.MatchedBy != "binding.guild" { + t.Errorf("MatchedBy = %q, want 'binding.guild'", route.MatchedBy) + } +} + +func TestResolveRoute_TeamBinding(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "general", Default: true}, + {ID: "work"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "work", + Match: config.BindingMatch{ + Channel: "slack", + AccountID: "*", + TeamID: "T12345", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "slack", + TeamID: "T12345", + Peer: &RoutePeer{Kind: "channel", ID: "C001"}, + }) + + if route.AgentID != "work" { + t.Errorf("AgentID = %q, want 'work'", route.AgentID) + } + if route.MatchedBy != "binding.team" { + t.Errorf("MatchedBy = %q, want 'binding.team'", route.MatchedBy) + } +} + +func TestResolveRoute_AccountBinding(t *testing.T) 
{ + agents := []config.AgentConfig{ + {ID: "default-agent", Default: true}, + {ID: "premium"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "premium", + Match: config.BindingMatch{ + Channel: "telegram", + AccountID: "bot2", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + AccountID: "bot2", + Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + }) + + if route.AgentID != "premium" { + t.Errorf("AgentID = %q, want 'premium'", route.AgentID) + } + if route.MatchedBy != "binding.account" { + t.Errorf("MatchedBy = %q, want 'binding.account'", route.MatchedBy) + } +} + +func TestResolveRoute_ChannelWildcard(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "main", Default: true}, + {ID: "telegram-bot"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "telegram-bot", + Match: config.BindingMatch{ + Channel: "telegram", + AccountID: "*", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + }) + + if route.AgentID != "telegram-bot" { + t.Errorf("AgentID = %q, want 'telegram-bot'", route.AgentID) + } + if route.MatchedBy != "binding.channel" { + t.Errorf("MatchedBy = %q, want 'binding.channel'", route.MatchedBy) + } +} + +func TestResolveRoute_PriorityOrder_PeerBeatsGuild(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "general", Default: true}, + {ID: "vip"}, + {ID: "gaming"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "vip", + Match: config.BindingMatch{ + Channel: "discord", + AccountID: "*", + Peer: &config.PeerMatch{Kind: "direct", ID: "user-vip"}, + }, + }, + { + AgentID: "gaming", + Match: config.BindingMatch{ + Channel: "discord", + AccountID: "*", + GuildID: "guild-1", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := 
r.ResolveRoute(RouteInput{ + Channel: "discord", + GuildID: "guild-1", + Peer: &RoutePeer{Kind: "direct", ID: "user-vip"}, + }) + + if route.AgentID != "vip" { + t.Errorf("AgentID = %q, want 'vip' (peer should beat guild)", route.AgentID) + } + if route.MatchedBy != "binding.peer" { + t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy) + } +} + +func TestResolveRoute_InvalidAgentFallsToDefault(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "main", Default: true}, + } + bindings := []config.AgentBinding{ + { + AgentID: "nonexistent", + Match: config.BindingMatch{ + Channel: "telegram", + AccountID: "*", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + }) + + if route.AgentID != "main" { + t.Errorf("AgentID = %q, want 'main' (invalid agent should fall to default)", route.AgentID) + } +} + +func TestResolveRoute_DefaultAgentSelection(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "alpha"}, + {ID: "beta", Default: true}, + {ID: "gamma"}, + } + cfg := testConfig(agents, nil) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "cli", + }) + + if route.AgentID != "beta" { + t.Errorf("AgentID = %q, want 'beta' (marked as default)", route.AgentID) + } +} + +func TestResolveRoute_NoDefaultUsesFirst(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "alpha"}, + {ID: "beta"}, + } + cfg := testConfig(agents, nil) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "cli", + }) + + if route.AgentID != "alpha" { + t.Errorf("AgentID = %q, want 'alpha' (first in list)", route.AgentID) + } +} diff --git a/pkg/routing/session_key.go b/pkg/routing/session_key.go new file mode 100644 index 000000000..e12f0d1d8 --- /dev/null +++ b/pkg/routing/session_key.go @@ -0,0 +1,183 @@ +package routing + +import ( + "fmt" + "strings" +) + +// DMScope controls DM session isolation granularity. 
+type DMScope string + +const ( + DMScopeMain DMScope = "main" + DMScopePerPeer DMScope = "per-peer" + DMScopePerChannelPeer DMScope = "per-channel-peer" + DMScopePerAccountChannelPeer DMScope = "per-account-channel-peer" +) + +// RoutePeer represents a chat peer with kind and ID. +type RoutePeer struct { + Kind string // "direct", "group", "channel" + ID string +} + +// SessionKeyParams holds all inputs for session key construction. +type SessionKeyParams struct { + AgentID string + Channel string + AccountID string + Peer *RoutePeer + DMScope DMScope + IdentityLinks map[string][]string +} + +// ParsedSessionKey is the result of parsing an agent-scoped session key. +type ParsedSessionKey struct { + AgentID string + Rest string +} + +// BuildAgentMainSessionKey returns "agent::main". +func BuildAgentMainSessionKey(agentID string) string { + return fmt.Sprintf("agent:%s:%s", NormalizeAgentID(agentID), DefaultMainKey) +} + +// BuildAgentPeerSessionKey constructs a session key based on agent, channel, peer, and DM scope. 
+func BuildAgentPeerSessionKey(params SessionKeyParams) string { + agentID := NormalizeAgentID(params.AgentID) + + peer := params.Peer + if peer == nil { + peer = &RoutePeer{Kind: "direct"} + } + peerKind := strings.TrimSpace(peer.Kind) + if peerKind == "" { + peerKind = "direct" + } + + if peerKind == "direct" { + dmScope := params.DMScope + if dmScope == "" { + dmScope = DMScopeMain + } + peerID := strings.TrimSpace(peer.ID) + + // Resolve identity links (cross-platform collapse) + if dmScope != DMScopeMain && peerID != "" { + if linked := resolveLinkedPeerID(params.IdentityLinks, params.Channel, peerID); linked != "" { + peerID = linked + } + } + peerID = strings.ToLower(peerID) + + switch dmScope { + case DMScopePerAccountChannelPeer: + if peerID != "" { + channel := normalizeChannel(params.Channel) + accountID := NormalizeAccountID(params.AccountID) + return fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, accountID, peerID) + } + case DMScopePerChannelPeer: + if peerID != "" { + channel := normalizeChannel(params.Channel) + return fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID) + } + case DMScopePerPeer: + if peerID != "" { + return fmt.Sprintf("agent:%s:direct:%s", agentID, peerID) + } + } + return BuildAgentMainSessionKey(agentID) + } + + // Group/channel peers always get per-peer sessions + channel := normalizeChannel(params.Channel) + peerID := strings.ToLower(strings.TrimSpace(peer.ID)) + if peerID == "" { + peerID = "unknown" + } + return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID) +} + +// ParseAgentSessionKey extracts agentId and rest from "agent::". 
+func ParseAgentSessionKey(sessionKey string) *ParsedSessionKey { + raw := strings.TrimSpace(sessionKey) + if raw == "" { + return nil + } + parts := strings.SplitN(raw, ":", 3) + if len(parts) < 3 { + return nil + } + if parts[0] != "agent" { + return nil + } + agentID := strings.TrimSpace(parts[1]) + rest := parts[2] + if agentID == "" || rest == "" { + return nil + } + return &ParsedSessionKey{AgentID: agentID, Rest: rest} +} + +// IsSubagentSessionKey returns true if the session key represents a subagent. +func IsSubagentSessionKey(sessionKey string) bool { + raw := strings.TrimSpace(sessionKey) + if raw == "" { + return false + } + if strings.HasPrefix(strings.ToLower(raw), "subagent:") { + return true + } + parsed := ParseAgentSessionKey(raw) + if parsed == nil { + return false + } + return strings.HasPrefix(strings.ToLower(parsed.Rest), "subagent:") +} + +func normalizeChannel(channel string) string { + c := strings.TrimSpace(strings.ToLower(channel)) + if c == "" { + return "unknown" + } + return c +} + +func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string { + if len(identityLinks) == 0 { + return "" + } + peerID = strings.TrimSpace(peerID) + if peerID == "" { + return "" + } + + candidates := make(map[string]bool) + rawCandidate := strings.ToLower(peerID) + if rawCandidate != "" { + candidates[rawCandidate] = true + } + channel = strings.ToLower(strings.TrimSpace(channel)) + if channel != "" { + scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID)) + candidates[scopedCandidate] = true + } + if len(candidates) == 0 { + return "" + } + + for canonical, ids := range identityLinks { + canonicalName := strings.TrimSpace(canonical) + if canonicalName == "" { + continue + } + for _, id := range ids { + normalized := strings.ToLower(strings.TrimSpace(id)) + if normalized != "" && candidates[normalized] { + return canonicalName + } + } + } + return "" +} diff --git a/pkg/routing/session_key_test.go 
b/pkg/routing/session_key_test.go new file mode 100644 index 000000000..81e4ce018 --- /dev/null +++ b/pkg/routing/session_key_test.go @@ -0,0 +1,162 @@ +package routing + +import "testing" + +func TestBuildAgentMainSessionKey(t *testing.T) { + got := BuildAgentMainSessionKey("sales") + want := "agent:sales:main" + if got != want { + t.Errorf("BuildAgentMainSessionKey('sales') = %q, want %q", got, want) + } +} + +func TestBuildAgentMainSessionKey_Normalizes(t *testing.T) { + got := BuildAgentMainSessionKey("Sales Bot") + want := "agent:sales-bot:main" + if got != want { + t.Errorf("BuildAgentMainSessionKey('Sales Bot') = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopeMain(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopeMain, + }) + want := "agent:main:main" + if got != want { + t.Errorf("DMScopeMain = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopePerPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopePerPeer, + }) + want := "agent:main:direct:user123" + if got != want { + t.Errorf("DMScopePerPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopePerChannelPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopePerChannelPeer, + }) + want := "agent:main:telegram:direct:user123" + if got != want { + t.Errorf("DMScopePerChannelPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopePerAccountChannelPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + AccountID: "bot1", + Peer: &RoutePeer{Kind: "direct", ID: 
"User123"}, + DMScope: DMScopePerAccountChannelPeer, + }) + want := "agent:main:telegram:bot1:direct:user123" + if got != want { + t.Errorf("DMScopePerAccountChannelPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_GroupPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "group", ID: "chat456"}, + DMScope: DMScopePerPeer, + }) + want := "agent:main:telegram:group:chat456" + if got != want { + t.Errorf("GroupPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_NilPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: nil, + DMScope: DMScopePerPeer, + }) + // nil peer defaults to direct with empty ID, falls to main + want := "agent:main:main" + if got != want { + t.Errorf("NilPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) { + links := map[string][]string{ + "john": {"telegram:user123", "discord:john#1234"}, + } + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopePerPeer, + IdentityLinks: links, + }) + want := "agent:main:direct:john" + if got != want { + t.Errorf("IdentityLink = %q, want %q", got, want) + } +} + +func TestParseAgentSessionKey_Valid(t *testing.T) { + parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123") + if parsed == nil { + t.Fatal("expected non-nil result") + } + if parsed.AgentID != "sales" { + t.Errorf("AgentID = %q, want 'sales'", parsed.AgentID) + } + if parsed.Rest != "telegram:direct:user123" { + t.Errorf("Rest = %q, want 'telegram:direct:user123'", parsed.Rest) + } +} + +func TestParseAgentSessionKey_Invalid(t *testing.T) { + tests := []string{ + "", + "foo:bar", + "notprefix:sales:main", + "agent::main", + "agent:sales:", + } + for _, input := range tests { + 
if got := ParseAgentSessionKey(input); got != nil { + t.Errorf("ParseAgentSessionKey(%q) = %+v, want nil", input, got) + } + } +} + +func TestIsSubagentSessionKey(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"subagent:task-1", true}, + {"agent:main:subagent:task-1", true}, + {"agent:main:main", false}, + {"agent:main:telegram:direct:user123", false}, + {"", false}, + } + for _, tt := range tests { + if got := IsSubagentSessionKey(tt.input); got != tt.want { + t.Errorf("IsSubagentSessionKey(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} diff --git a/pkg/session/manager.go b/pkg/session/manager.go index b4b825764..08f0b0ad2 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -4,6 +4,7 @@ import ( "encoding/json" "os" "path/filepath" + "strings" "sync" "time" @@ -31,7 +32,7 @@ func NewSessionManager(storage string) *SessionManager { } if storage != "" { - os.MkdirAll(storage, 0755) + os.MkdirAll(storage, 0o755) sm.loadSessions() } @@ -39,21 +40,21 @@ func NewSessionManager(storage string) *SessionManager { } func (sm *SessionManager) GetOrCreate(key string) *Session { - sm.mu.RLock() + sm.mu.Lock() + defer sm.mu.Unlock() + session, ok := sm.sessions[key] - sm.mu.RUnlock() + if ok { + return session + } - if !ok { - sm.mu.Lock() - session = &Session{ - Key: key, - Messages: []providers.Message{}, - Created: time.Now(), - Updated: time.Now(), - } - sm.sessions[key] = session - sm.mu.Unlock() + session = &Session{ + Key: key, + Messages: []providers.Message{}, + Created: time.Now(), + Updated: time.Now(), } + sm.sessions[key] = session return session } @@ -130,6 +131,12 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) { return } + if keepLast <= 0 { + session.Messages = []providers.Message{} + session.Updated = time.Now() + return + } + if len(session.Messages) <= keepLast { return } @@ -138,22 +145,92 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) { session.Updated = 
time.Now() } -func (sm *SessionManager) Save(session *Session) error { +// sanitizeFilename converts a session key into a cross-platform safe filename. +// Session keys use "channel:chatID" (e.g. "telegram:123456") but ':' is the +// volume separator on Windows, so filepath.Base would misinterpret the key. +// We replace it with '_'. The original key is preserved inside the JSON file, +// so loadSessions still maps back to the right in-memory key. +func sanitizeFilename(key string) string { + return strings.ReplaceAll(key, ":", "_") +} + +func (sm *SessionManager) Save(key string) error { if sm.storage == "" { return nil } - sm.mu.Lock() - defer sm.mu.Unlock() + filename := sanitizeFilename(key) + + // filepath.IsLocal rejects empty names, "..", absolute paths, and + // OS-reserved device names (NUL, COM1 … on Windows). + // The extra checks reject "." and any directory separators so that + // the session file is always written directly inside sm.storage. + if filename == "." || !filepath.IsLocal(filename) || strings.ContainsAny(filename, `/\`) { + return os.ErrInvalid + } - sessionPath := filepath.Join(sm.storage, session.Key+".json") + // Snapshot under read lock, then perform slow file I/O after unlock. 
+ sm.mu.RLock() + stored, ok := sm.sessions[key] + if !ok { + sm.mu.RUnlock() + return nil + } - data, err := json.MarshalIndent(session, "", " ") + snapshot := Session{ + Key: stored.Key, + Summary: stored.Summary, + Created: stored.Created, + Updated: stored.Updated, + } + if len(stored.Messages) > 0 { + snapshot.Messages = make([]providers.Message, len(stored.Messages)) + copy(snapshot.Messages, stored.Messages) + } else { + snapshot.Messages = []providers.Message{} + } + sm.mu.RUnlock() + + data, err := json.MarshalIndent(snapshot, "", " ") if err != nil { return err } - return os.WriteFile(sessionPath, data, 0644) + sessionPath := filepath.Join(sm.storage, filename+".json") + tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp") + if err != nil { + return err + } + + tmpPath := tmpFile.Name() + cleanup := true + defer func() { + if cleanup { + _ = os.Remove(tmpPath) + } + }() + + if _, err := tmpFile.Write(data); err != nil { + _ = tmpFile.Close() + return err + } + if err := tmpFile.Chmod(0o644); err != nil { + _ = tmpFile.Close() + return err + } + if err := tmpFile.Sync(); err != nil { + _ = tmpFile.Close() + return err + } + if err := tmpFile.Close(); err != nil { + return err + } + + if err := os.Rename(tmpPath, sessionPath); err != nil { + return err + } + cleanup = false + return nil } func (sm *SessionManager) loadSessions() error { @@ -187,3 +264,19 @@ func (sm *SessionManager) loadSessions() error { return nil } + +// SetHistory updates the messages of a session. +func (sm *SessionManager) SetHistory(key string, history []providers.Message) { + sm.mu.Lock() + defer sm.mu.Unlock() + + session, ok := sm.sessions[key] + if ok { + // Create a deep copy to strictly isolate internal state + // from the caller's slice. 
+ msgs := make([]providers.Message, len(history)) + copy(msgs, history) + session.Messages = msgs + session.Updated = time.Now() + } +} diff --git a/pkg/session/manager_test.go b/pkg/session/manager_test.go new file mode 100644 index 000000000..5ef5f4349 --- /dev/null +++ b/pkg/session/manager_test.go @@ -0,0 +1,74 @@ +package session + +import ( + "os" + "path/filepath" + "testing" +) + +func TestSanitizeFilename(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"simple", "simple"}, + {"telegram:123456", "telegram_123456"}, + {"discord:987654321", "discord_987654321"}, + {"slack:C01234", "slack_C01234"}, + {"no-colons-here", "no-colons-here"}, + {"multiple:colons:here", "multiple_colons_here"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := sanitizeFilename(tt.input) + if got != tt.expected { + t.Errorf("sanitizeFilename(%q) = %q, want %q", tt.input, got, tt.expected) + } + }) + } +} + +func TestSave_WithColonInKey(t *testing.T) { + tmpDir := t.TempDir() + sm := NewSessionManager(tmpDir) + + // Create a session with a key containing colon (typical channel session key). + key := "telegram:123456" + sm.GetOrCreate(key) + sm.AddMessage(key, "user", "hello") + + // Save should succeed even though the key contains ':' + if err := sm.Save(key); err != nil { + t.Fatalf("Save(%q) failed: %v", key, err) + } + + // The file on disk should use sanitized name. + expectedFile := filepath.Join(tmpDir, "telegram_123456.json") + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + t.Fatalf("expected session file %s to exist", expectedFile) + } + + // Load into a fresh manager and verify the session round-trips. 
+ sm2 := NewSessionManager(tmpDir) + history := sm2.GetHistory(key) + if len(history) != 1 { + t.Fatalf("expected 1 message after reload, got %d", len(history)) + } + if history[0].Content != "hello" { + t.Errorf("expected message content %q, got %q", "hello", history[0].Content) + } +} + +func TestSave_RejectsPathTraversal(t *testing.T) { + tmpDir := t.TempDir() + sm := NewSessionManager(tmpDir) + + badKeys := []string{"", ".", "..", "foo/bar", "foo\\bar"} + for _, key := range badKeys { + sm.GetOrCreate(key) + if err := sm.Save(key); err == nil { + t.Errorf("Save(%q) should have failed but didn't", key) + } + } +} diff --git a/pkg/skills/clawhub_registry.go b/pkg/skills/clawhub_registry.go new file mode 100644 index 000000000..f78197bbe --- /dev/null +++ b/pkg/skills/clawhub_registry.go @@ -0,0 +1,314 @@ +package skills + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "time" + + "github.com/sipeed/picoclaw/pkg/utils" +) + +const ( + defaultClawHubTimeout = 30 * time.Second + defaultMaxZipSize = 50 * 1024 * 1024 // 50 MB + defaultMaxResponseSize = 2 * 1024 * 1024 // 2 MB +) + +// ClawHubRegistry implements SkillRegistry for the ClawHub platform. +type ClawHubRegistry struct { + baseURL string + authToken string // Optional - for elevated rate limits + searchPath string // Search API + skillsPath string // For retrieving skill metadata + downloadPath string // For fetching ZIP files for download + maxZipSize int + maxResponseSize int + client *http.Client +} + +// NewClawHubRegistry creates a new ClawHub registry client from config. 
+func NewClawHubRegistry(cfg ClawHubConfig) *ClawHubRegistry { + baseURL := cfg.BaseURL + if baseURL == "" { + baseURL = "https://clawhub.ai" + } + searchPath := cfg.SearchPath + if searchPath == "" { + searchPath = "/api/v1/search" + } + skillsPath := cfg.SkillsPath + if skillsPath == "" { + skillsPath = "/api/v1/skills" + } + downloadPath := cfg.DownloadPath + if downloadPath == "" { + downloadPath = "/api/v1/download" + } + + timeout := defaultClawHubTimeout + if cfg.Timeout > 0 { + timeout = time.Duration(cfg.Timeout) * time.Second + } + + maxZip := defaultMaxZipSize + if cfg.MaxZipSize > 0 { + maxZip = cfg.MaxZipSize + } + + maxResp := defaultMaxResponseSize + if cfg.MaxResponseSize > 0 { + maxResp = cfg.MaxResponseSize + } + + return &ClawHubRegistry{ + baseURL: baseURL, + authToken: cfg.AuthToken, + searchPath: searchPath, + skillsPath: skillsPath, + downloadPath: downloadPath, + maxZipSize: maxZip, + maxResponseSize: maxResp, + client: &http.Client{ + Timeout: timeout, + Transport: &http.Transport{ + MaxIdleConns: 5, + IdleConnTimeout: 30 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + }, + }, + } +} + +func (c *ClawHubRegistry) Name() string { + return "clawhub" +} + +// --- Search --- + +type clawhubSearchResponse struct { + Results []clawhubSearchResult `json:"results"` +} + +type clawhubSearchResult struct { + Score float64 `json:"score"` + Slug *string `json:"slug"` + DisplayName *string `json:"displayName"` + Summary *string `json:"summary"` + Version *string `json:"version"` +} + +func (c *ClawHubRegistry) Search(ctx context.Context, query string, limit int) ([]SearchResult, error) { + u, err := url.Parse(c.baseURL + c.searchPath) + if err != nil { + return nil, fmt.Errorf("invalid base URL: %w", err) + } + + q := u.Query() + q.Set("q", query) + if limit > 0 { + q.Set("limit", fmt.Sprintf("%d", limit)) + } + u.RawQuery = q.Encode() + + body, err := c.doGet(ctx, u.String()) + if err != nil { + return nil, fmt.Errorf("search request failed: 
%w", err) + } + + var resp clawhubSearchResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("failed to parse search response: %w", err) + } + + results := make([]SearchResult, 0, len(resp.Results)) + for _, r := range resp.Results { + slug := utils.DerefStr(r.Slug, "") + if slug == "" { + continue + } + + summary := utils.DerefStr(r.Summary, "") + if summary == "" { + continue + } + + displayName := utils.DerefStr(r.DisplayName, "") + if displayName == "" { + displayName = slug + } + + results = append(results, SearchResult{ + Score: r.Score, + Slug: slug, + DisplayName: displayName, + Summary: summary, + Version: utils.DerefStr(r.Version, ""), + RegistryName: c.Name(), + }) + } + + return results, nil +} + +// --- GetSkillMeta --- + +type clawhubSkillResponse struct { + Slug string `json:"slug"` + DisplayName string `json:"displayName"` + Summary string `json:"summary"` + LatestVersion *clawhubVersionInfo `json:"latestVersion"` + Moderation *clawhubModerationInfo `json:"moderation"` +} + +type clawhubVersionInfo struct { + Version string `json:"version"` +} + +type clawhubModerationInfo struct { + IsMalwareBlocked bool `json:"isMalwareBlocked"` + IsSuspicious bool `json:"isSuspicious"` +} + +func (c *ClawHubRegistry) GetSkillMeta(ctx context.Context, slug string) (*SkillMeta, error) { + if err := utils.ValidateSkillIdentifier(slug); err != nil { + return nil, fmt.Errorf("invalid slug %q: error: %s", slug, err.Error()) + } + + u := c.baseURL + c.skillsPath + "/" + url.PathEscape(slug) + + body, err := c.doGet(ctx, u) + if err != nil { + return nil, fmt.Errorf("skill metadata request failed: %w", err) + } + + var resp clawhubSkillResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("failed to parse skill metadata: %w", err) + } + + meta := &SkillMeta{ + Slug: resp.Slug, + DisplayName: resp.DisplayName, + Summary: resp.Summary, + RegistryName: c.Name(), + } + + if resp.LatestVersion != nil { + 
meta.LatestVersion = resp.LatestVersion.Version + } + if resp.Moderation != nil { + meta.IsMalwareBlocked = resp.Moderation.IsMalwareBlocked + meta.IsSuspicious = resp.Moderation.IsSuspicious + } + + return meta, nil +} + +// --- DownloadAndInstall --- + +// DownloadAndInstall fetches metadata (with fallback), resolves version, +// downloads the skill ZIP, and extracts it to targetDir. +// Returns an InstallResult for the caller to use for moderation decisions. +func (c *ClawHubRegistry) DownloadAndInstall( + ctx context.Context, + slug, version, targetDir string, +) (*InstallResult, error) { + if err := utils.ValidateSkillIdentifier(slug); err != nil { + return nil, fmt.Errorf("invalid slug %q: error: %s", slug, err.Error()) + } + + // Step 1: Fetch metadata (with fallback). + result := &InstallResult{} + meta, err := c.GetSkillMeta(ctx, slug) + if err != nil { + // Fallback: proceed without metadata. + meta = nil + } + + if meta != nil { + result.IsMalwareBlocked = meta.IsMalwareBlocked + result.IsSuspicious = meta.IsSuspicious + result.Summary = meta.Summary + } + + // Step 2: Resolve version. + installVersion := version + if installVersion == "" && meta != nil { + installVersion = meta.LatestVersion + } + if installVersion == "" { + installVersion = "latest" + } + result.Version = installVersion + + // Step 3: Download ZIP to temp file (streams in ~32KB chunks). 
+ u, err := url.Parse(c.baseURL + c.downloadPath) + if err != nil { + return nil, fmt.Errorf("invalid base URL: %w", err) + } + + q := u.Query() + q.Set("slug", slug) + if installVersion != "latest" { + q.Set("version", installVersion) + } + u.RawQuery = q.Encode() + + req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + if c.authToken != "" { + req.Header.Set("Authorization", "Bearer "+c.authToken) + } + + tmpPath, err := utils.DownloadToFile(ctx, c.client, req, int64(c.maxZipSize)) + if err != nil { + return nil, fmt.Errorf("download failed: %w", err) + } + defer os.Remove(tmpPath) + + // Step 4: Extract from file on disk. + if err := utils.ExtractZipFile(tmpPath, targetDir); err != nil { + return nil, err + } + + return result, nil +} + +// --- HTTP helper --- + +func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + if c.authToken != "" { + req.Header.Set("Authorization", "Bearer "+c.authToken) + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + // Limit response body read to prevent memory issues. 
+ body, err := io.ReadAll(io.LimitReader(resp.Body, int64(c.maxResponseSize))) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + return body, nil +} diff --git a/pkg/skills/clawhub_registry_test.go b/pkg/skills/clawhub_registry_test.go new file mode 100644 index 000000000..65ee638da --- /dev/null +++ b/pkg/skills/clawhub_registry_test.go @@ -0,0 +1,257 @@ +package skills + +import ( + "archive/zip" + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/sipeed/picoclaw/pkg/utils" +) + +func newTestRegistry(serverURL, authToken string) *ClawHubRegistry { + return NewClawHubRegistry(ClawHubConfig{ + Enabled: true, + BaseURL: serverURL, + AuthToken: authToken, + }) +} + +func TestClawHubRegistrySearch(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/v1/search", r.URL.Path) + assert.Equal(t, "github", r.URL.Query().Get("q")) + + slug := "github" + name := "GitHub Integration" + summary := "Interact with GitHub repos" + version := "1.0.0" + + json.NewEncoder(w).Encode(clawhubSearchResponse{ + Results: []clawhubSearchResult{ + {Score: 0.95, Slug: &slug, DisplayName: &name, Summary: &summary, Version: &version}, + }, + }) + })) + defer srv.Close() + + reg := newTestRegistry(srv.URL, "") + results, err := reg.Search(context.Background(), "github", 5) + + require.NoError(t, err) + require.Len(t, results, 1) + assert.Equal(t, "github", results[0].Slug) + assert.Equal(t, "GitHub Integration", results[0].DisplayName) + assert.InDelta(t, 0.95, results[0].Score, 0.001) + assert.Equal(t, "clawhub", results[0].RegistryName) +} + +func TestClawHubRegistryGetSkillMeta(t *testing.T) { + 
	// Fake ClawHub server: asserts the metadata path embeds the slug and
	// returns a fixed document with version and moderation flags set.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/api/v1/skills/github", r.URL.Path)

		json.NewEncoder(w).Encode(clawhubSkillResponse{
			Slug:        "github",
			DisplayName: "GitHub Integration",
			Summary:     "Full GitHub API integration",
			LatestVersion: &clawhubVersionInfo{
				Version: "2.1.0",
			},
			Moderation: &clawhubModerationInfo{
				IsMalwareBlocked: false,
				IsSuspicious:     true,
			},
		})
	}))
	defer srv.Close()

	reg := newTestRegistry(srv.URL, "")
	meta, err := reg.GetSkillMeta(context.Background(), "github")

	// Version and moderation flags must be copied through from the response.
	require.NoError(t, err)
	assert.Equal(t, "github", meta.Slug)
	assert.Equal(t, "2.1.0", meta.LatestVersion)
	assert.False(t, meta.IsMalwareBlocked)
	assert.True(t, meta.IsSuspicious)
}

// TestClawHubRegistryGetSkillMetaUnsafeSlug verifies that a path-traversal
// slug is rejected client-side, before any network request is attempted.
func TestClawHubRegistryGetSkillMetaUnsafeSlug(t *testing.T) {
	reg := newTestRegistry("https://example.com", "")
	_, err := reg.GetSkillMeta(context.Background(), "../etc/passwd")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "invalid slug")
}

// TestClawHubRegistryDownloadAndInstall runs the happy path end-to-end:
// metadata lookup, ZIP download and extraction into the target directory.
func TestClawHubRegistryDownloadAndInstall(t *testing.T) {
	// Create a valid ZIP in memory.
	zipBuf := createTestZip(t, map[string]string{
		"SKILL.md":  "---\nname: test-skill\ndescription: A test\n---\nHello skill",
		"README.md": "# Test Skill\n",
	})

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/api/v1/skills/test-skill":
			// Metadata endpoint.
			json.NewEncoder(w).Encode(clawhubSkillResponse{
				Slug:          "test-skill",
				DisplayName:   "Test Skill",
				Summary:       "A test skill",
				LatestVersion: &clawhubVersionInfo{Version: "1.0.0"},
			})
		case "/api/v1/download":
			// Download endpoint: slug must be passed as a query parameter.
			assert.Equal(t, "test-skill", r.URL.Query().Get("slug"))
			w.Header().Set("Content-Type", "application/zip")
			w.Write(zipBuf)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()

	tmpDir := t.TempDir()
	targetDir := filepath.Join(tmpDir, "test-skill")

	reg := newTestRegistry(srv.URL, "")
	result, err := reg.DownloadAndInstall(context.Background(), "test-skill", "1.0.0", targetDir)

	require.NoError(t, err)
	assert.Equal(t, "1.0.0", result.Version)
	assert.False(t, result.IsMalwareBlocked)

	// Verify extracted files.
	skillContent, err := os.ReadFile(filepath.Join(targetDir, "SKILL.md"))
	require.NoError(t, err)
	assert.Contains(t, string(skillContent), "Hello skill")

	readmeContent, err := os.ReadFile(filepath.Join(targetDir, "README.md"))
	require.NoError(t, err)
	assert.Contains(t, string(readmeContent), "# Test Skill")
}

// TestClawHubRegistryAuthToken asserts the configured token is sent as a
// Bearer Authorization header; the assertion runs inside the handler.
func TestClawHubRegistryAuthToken(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		authHeader := r.Header.Get("Authorization")
		assert.Equal(t, "Bearer test-token-123", authHeader)
		json.NewEncoder(w).Encode(clawhubSearchResponse{Results: nil})
	}))
	defer srv.Close()

	reg := newTestRegistry(srv.URL, "test-token-123")
	_, _ = reg.Search(context.Background(), "test", 5)
}

// TestExtractZipPathTraversal ensures a ZIP entry escaping the target
// directory ("../../...") is rejected with an "unsafe path" error.
func TestExtractZipPathTraversal(t *testing.T) {
	// Create a ZIP with a path traversal entry.
	var buf bytes.Buffer
	zw := zip.NewWriter(&buf)

	// Malicious entry trying to escape directory.
	w, err := zw.Create("../../etc/passwd")
	require.NoError(t, err)
	w.Write([]byte("malicious"))

	zw.Close()

	// Write to temp file for extractZipFile.
	tmpZip := filepath.Join(t.TempDir(), "bad.zip")
	require.NoError(t, os.WriteFile(tmpZip, buf.Bytes(), 0o644))

	tmpDir := t.TempDir()
	err = utils.ExtractZipFile(tmpZip, tmpDir)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "unsafe path")
}

// TestExtractZipWithSubdirectories verifies nested directories inside the
// archive are recreated under the target directory.
func TestExtractZipWithSubdirectories(t *testing.T) {
	zipBuf := createTestZip(t, map[string]string{
		"SKILL.md":           "root file",
		"scripts/helper.sh":  "#!/bin/bash\necho hello",
		"examples/demo.yaml": "key: value",
	})

	// Write to temp file for extractZipFile.
	tmpZip := filepath.Join(t.TempDir(), "test.zip")
	require.NoError(t, os.WriteFile(tmpZip, zipBuf, 0o644))

	tmpDir := t.TempDir()
	targetDir := filepath.Join(tmpDir, "my-skill")

	err := utils.ExtractZipFile(tmpZip, targetDir)
	require.NoError(t, err)

	// Verify nested file.
	data, err := os.ReadFile(filepath.Join(targetDir, "scripts", "helper.sh"))
	require.NoError(t, err)
	assert.Contains(t, string(data), "#!/bin/bash")
}

// TestClawHubRegistrySearchHTTPError checks that a non-2xx response is
// surfaced as an error containing the status code.
func TestClawHubRegistrySearchHTTPError(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte("Internal Server Error"))
	}))
	defer srv.Close()

	reg := newTestRegistry(srv.URL, "")
	_, err := reg.Search(context.Background(), "test", 5)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "500")
}

// TestClawHubRegistrySearchNullableFields exercises Search's filtering of
// results whose nullable JSON fields are absent.
func TestClawHubRegistrySearchNullableFields(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		validSlug := "valid-slug"
		validSummary := "valid summary"

		// Return results with various null/empty fields
		json.NewEncoder(w).Encode(clawhubSearchResponse{
			Results: []clawhubSearchResult{
				// Case 1: Null Slug -> Skip
				{Score: 0.1, Slug: nil, DisplayName: nil, Summary: nil, Version: nil},
				// Case 2: Valid Slug, Null Summary -> Skip
				{Score: 0.2, Slug: &validSlug, DisplayName: nil, Summary: nil, Version: nil},
				// Case 3: Valid Slug, Valid Summary, Null Name -> Keep, Name=Slug
				{Score: 0.8, Slug: &validSlug, DisplayName: nil, Summary: &validSummary, Version: nil},
			},
		})
	}))
	defer srv.Close()

	reg := newTestRegistry(srv.URL, "")
	results, err := reg.Search(context.Background(), "test", 5)

	require.NoError(t, err)
	require.Len(t, results, 1, "should only return 1 valid result")

	r := results[0]
	assert.Equal(t, "valid-slug", r.Slug)
	assert.Equal(t, "valid-slug", r.DisplayName, "should fallback name to slug")
	assert.Equal(t, "valid summary", r.Summary)
}

// --- helpers ---

// createTestZip builds an in-memory ZIP archive from a name->content map
// and returns the raw archive bytes.
func createTestZip(t *testing.T, files map[string]string) []byte {
	t.Helper()
	var buf bytes.Buffer
	zw := zip.NewWriter(&buf)

	for name, content := range files {
		w, err := zw.Create(name)
		require.NoError(t, err)
		_, err = w.Write([]byte(content))
		require.NoError(t, err)
	}

	require.NoError(t, zw.Close())
	return buf.Bytes()
}
fmt.Errorf("failed to write skill file: %w", err) } @@ -123,49 +116,3 @@ func (si *SkillInstaller) ListAvailableSkills(ctx context.Context) ([]AvailableS return skills, nil } - -func (si *SkillInstaller) ListBuiltinSkills() []BuiltinSkill { - builtinSkillsDir := filepath.Join(filepath.Dir(si.workspace), "picoclaw", "skills") - - entries, err := os.ReadDir(builtinSkillsDir) - if err != nil { - return nil - } - - var skills []BuiltinSkill - for _, entry := range entries { - if entry.IsDir() { - _ = entry - skillName := entry.Name() - skillFile := filepath.Join(builtinSkillsDir, skillName, "SKILL.md") - - data, err := os.ReadFile(skillFile) - description := "" - if err == nil { - content := string(data) - if idx := strings.Index(content, "\n"); idx > 0 { - firstLine := content[:idx] - if strings.Contains(firstLine, "description:") { - descLine := strings.Index(content[idx:], "\n") - if descLine > 0 { - description = strings.TrimSpace(content[idx+descLine : idx+descLine]) - } - } - } - } - - // skill := BuiltinSkill{ - // Name: skillName, - // Path: description, - // Enabled: true, - // } - - status := "✓" - fmt.Printf(" %s %s\n", status, entry.Name()) - if description != "" { - fmt.Printf(" %s\n", description) - } - } - } - return skills -} diff --git a/pkg/skills/loader.go b/pkg/skills/loader.go index 1f952c1f5..eb0d5f322 100644 --- a/pkg/skills/loader.go +++ b/pkg/skills/loader.go @@ -2,11 +2,22 @@ package skills import ( "encoding/json" + "errors" "fmt" + "log/slog" "os" "path/filepath" "regexp" "strings" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`) + +const ( + MaxNameLength = 64 + MaxDescriptionLength = 1024 ) type SkillMetadata struct { @@ -21,6 +32,27 @@ type SkillInfo struct { Description string `json:"description"` } +func (info SkillInfo) validate() error { + var errs error + if info.Name == "" { + errs = errors.Join(errs, errors.New("name is required")) + } else { + if 
len(info.Name) > MaxNameLength { + errs = errors.Join(errs, fmt.Errorf("name exceeds %d characters", MaxNameLength)) + } + if !namePattern.MatchString(info.Name) { + errs = errors.Join(errs, errors.New("name must be alphanumeric with hyphens")) + } + } + + if info.Description == "" { + errs = errors.Join(errs, errors.New("description is required")) + } else if len(info.Description) > MaxDescriptionLength { + errs = errors.Join(errs, fmt.Errorf("description exceeds %d character", MaxDescriptionLength)) + } + return errs +} + type SkillsLoader struct { workspace string workspaceSkills string // workspace skills (项目级别) @@ -54,6 +86,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo { metadata := sl.getSkillMetadata(skillFile) if metadata != nil { info.Description = metadata.Description + info.Name = metadata.Name + } + if err := info.validate(); err != nil { + slog.Warn("invalid skill from workspace", "name", info.Name, "error", err) + continue } skills = append(skills, info) } @@ -89,6 +126,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo { metadata := sl.getSkillMetadata(skillFile) if metadata != nil { info.Description = metadata.Description + info.Name = metadata.Name + } + if err := info.validate(); err != nil { + slog.Warn("invalid skill from global", "name", info.Name, "error", err) + continue } skills = append(skills, info) } @@ -123,6 +165,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo { metadata := sl.getSkillMetadata(skillFile) if metadata != nil { info.Description = metadata.Description + info.Name = metadata.Name + } + if err := info.validate(); err != nil { + slog.Warn("invalid skill from builtin", "name", info.Name, "error", err) + continue } skills = append(skills, info) } @@ -206,6 +253,11 @@ func (sl *SkillsLoader) BuildSkillsSummary() string { func (sl *SkillsLoader) getSkillMetadata(skillPath string) *SkillMetadata { content, err := os.ReadFile(skillPath) if err != nil { + logger.WarnCF("skills", "Failed to read skill metadata", + 
map[string]any{ + "skill_path": skillPath, + "error": err.Error(), + }) return nil } @@ -238,10 +290,15 @@ func (sl *SkillsLoader) getSkillMetadata(skillPath string) *SkillMetadata { // parseSimpleYAML parses simple key: value YAML format // Example: name: github\n description: "..." +// Normalizes line endings to handle \n (Unix), \r\n (Windows), and \r (classic Mac) func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string { result := make(map[string]string) - for _, line := range strings.Split(content, "\n") { + // Normalize line endings: convert \r\n and \r to \n + normalized := strings.ReplaceAll(content, "\r\n", "\n") + normalized = strings.ReplaceAll(normalized, "\r", "\n") + + for _, line := range strings.Split(normalized, "\n") { line = strings.TrimSpace(line) if line == "" || strings.HasPrefix(line, "#") { continue @@ -261,9 +318,10 @@ func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string { } func (sl *SkillsLoader) extractFrontmatter(content string) string { - // (?s) enables DOTALL mode so . matches newlines - // Match first ---, capture everything until next --- on its own line - re := regexp.MustCompile(`(?s)^---\n(.*)\n---`) + // Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks + // (?s) enables DOTALL so . matches newlines; + // ^--- at start, then ... --- at start of line, honoring all three line ending types + re := regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---`) match := re.FindStringSubmatch(content) if len(match) > 1 { return match[1] @@ -272,7 +330,11 @@ func (sl *SkillsLoader) extractFrontmatter(content string) string { } func (sl *SkillsLoader) stripFrontmatter(content string) string { - re := regexp.MustCompile(`^---\n.*?\n---\n`) + // Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks + // (?s) enables DOTALL so . matches newlines; + // ^--- at start, then ... 
--- at start of line, honoring all three line ending types + // Match zero or more trailing line endings after closing --- (handles both with and without blank lines) + re := regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`) return re.ReplaceAllString(content, "") } diff --git a/pkg/skills/loader_test.go b/pkg/skills/loader_test.go new file mode 100644 index 000000000..aca901d33 --- /dev/null +++ b/pkg/skills/loader_test.go @@ -0,0 +1,197 @@ +package skills + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSkillsInfoValidate(t *testing.T) { + testcases := []struct { + name string + skillName string + description string + wantErr bool + errContains []string + }{ + { + name: "valid-skill", + skillName: "valid-skill", + description: "a valid skill description", + wantErr: false, + }, + { + name: "empty-name", + skillName: "", + description: "description without name", + wantErr: true, + errContains: []string{"name is required"}, + }, + { + name: "empty-description", + skillName: "skill-without-description", + description: "", + wantErr: true, + errContains: []string{"description is required"}, + }, + { + name: "empty-both", + skillName: "", + description: "", + wantErr: true, + errContains: []string{"name is required", "description is required"}, + }, + { + name: "name-with-spaces", + skillName: "skill with spaces", + description: "invalid name with spaces", + wantErr: true, + errContains: []string{"name must be alphanumeric with hyphens"}, + }, + { + name: "name-with-underscore", + skillName: "skill_underscore", + description: "invalid name with underscore", + wantErr: true, + errContains: []string{"name must be alphanumeric with hyphens"}, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + info := SkillInfo{ + Name: tc.skillName, + Description: tc.description, + } + err := info.validate() + if tc.wantErr { + assert.Error(t, err) + for _, msg := range tc.errContains { + 
					assert.ErrorContains(t, err, msg)
				}
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

// TestExtractFrontmatter verifies that frontmatter extraction and simple
// YAML parsing work for all three line-ending conventions.
func TestExtractFrontmatter(t *testing.T) {
	sl := &SkillsLoader{}

	testcases := []struct {
		name           string
		content        string
		expectedName   string
		expectedDesc   string
		lineEndingType string
	}{
		{
			name:           "unix-line-endings",
			lineEndingType: "Unix (\\n)",
			content:        "---\nname: test-skill\ndescription: A test skill\n---\n\n# Skill Content",
			expectedName:   "test-skill",
			expectedDesc:   "A test skill",
		},
		{
			name:           "windows-line-endings",
			lineEndingType: "Windows (\\r\\n)",
			content:        "---\r\nname: test-skill\r\ndescription: A test skill\r\n---\r\n\r\n# Skill Content",
			expectedName:   "test-skill",
			expectedDesc:   "A test skill",
		},
		{
			name:           "classic-mac-line-endings",
			lineEndingType: "Classic Mac (\\r)",
			content:        "---\rname: test-skill\rdescription: A test skill\r---\r\r# Skill Content",
			expectedName:   "test-skill",
			expectedDesc:   "A test skill",
		},
	}

	for _, tc := range testcases {
		t.Run(tc.name, func(t *testing.T) {
			// Extract frontmatter
			frontmatter := sl.extractFrontmatter(tc.content)
			assert.NotEmpty(t, frontmatter, "Frontmatter should be extracted for %s line endings", tc.lineEndingType)

			// Parse YAML to get name and description (parseSimpleYAML now handles all line ending types)
			yamlMeta := sl.parseSimpleYAML(frontmatter)
			assert.Equal(
				t,
				tc.expectedName,
				yamlMeta["name"],
				"Name should be correctly parsed from frontmatter with %s line endings",
				tc.lineEndingType,
			)
			assert.Equal(
				t,
				tc.expectedDesc,
				yamlMeta["description"],
				"Description should be correctly parsed from frontmatter with %s line endings",
				tc.lineEndingType,
			)
		})
	}
}

// TestStripFrontmatter verifies frontmatter removal across line-ending
// conventions, with and without a trailing blank line, and that content
// without frontmatter is left untouched.
func TestStripFrontmatter(t *testing.T) {
	sl := &SkillsLoader{}

	testcases := []struct {
		name            string
		content         string
		expectedContent string
		lineEndingType  string
	}{
		{
			name:            "unix-line-endings",
			lineEndingType:  "Unix (\\n)",
			content:         "---\nname: test-skill\ndescription: A test skill\n---\n\n# Skill Content",
			expectedContent: "# Skill Content",
		},
		{
			name:            "windows-line-endings",
			lineEndingType:  "Windows (\\r\\n)",
			content:         "---\r\nname: test-skill\r\ndescription: A test skill\r\n---\r\n\r\n# Skill Content",
			expectedContent: "# Skill Content",
		},
		{
			name:            "classic-mac-line-endings",
			lineEndingType:  "Classic Mac (\\r)",
			content:         "---\rname: test-skill\rdescription: A test skill\r---\r\r# Skill Content",
			expectedContent: "# Skill Content",
		},
		{
			name:            "unix-line-endings-without-trailing-newline",
			lineEndingType:  "Unix (\\n) without trailing newline",
			content:         "---\nname: test-skill\ndescription: A test skill\n---\n# Skill Content",
			expectedContent: "# Skill Content",
		},
		{
			name:            "windows-line-endings-without-trailing-newline",
			lineEndingType:  "Windows (\\r\\n) without trailing newline",
			content:         "---\r\nname: test-skill\r\ndescription: A test skill\r\n---\r\n# Skill Content",
			expectedContent: "# Skill Content",
		},
		{
			name:            "no-frontmatter",
			lineEndingType:  "No frontmatter",
			content:         "# Skill Content\n\nSome content here.",
			expectedContent: "# Skill Content\n\nSome content here.",
		},
	}

	for _, tc := range testcases {
		t.Run(tc.name, func(t *testing.T) {
			result := sl.stripFrontmatter(tc.content)
			assert.Equal(
				t,
				tc.expectedContent,
				result,
				"Frontmatter should be stripped correctly for %s",
				tc.lineEndingType,
			)
		})
	}
}
// SearchResult represents a single result from a skill registry search.
type SearchResult struct {
	Score        float64 `json:"score"`         // relevance score; higher is better
	Slug         string  `json:"slug"`          // stable identifier used for install
	DisplayName  string  `json:"display_name"`  // human-readable name; may equal Slug
	Summary      string  `json:"summary"`       // one-line description
	Version      string  `json:"version"`       // may be empty if the registry omits it
	RegistryName string  `json:"registry_name"` // which registry produced this result
}

// SkillMeta holds metadata about a skill from a registry.
type SkillMeta struct {
	Slug             string `json:"slug"`
	DisplayName      string `json:"display_name"`
	Summary          string `json:"summary"`
	LatestVersion    string `json:"latest_version"`
	IsMalwareBlocked bool   `json:"is_malware_blocked"` // registry-side moderation flag
	IsSuspicious     bool   `json:"is_suspicious"`      // registry-side moderation flag
	RegistryName     string `json:"registry_name"`
}

// InstallResult is returned by DownloadAndInstall to carry metadata
// back to the caller for moderation and user messaging.
type InstallResult struct {
	Version          string // the version actually installed (may be "latest")
	IsMalwareBlocked bool   // copied from metadata when available
	IsSuspicious     bool   // copied from metadata when available
	Summary          string // copied from metadata when available
}

// SkillRegistry is the interface that all skill registries must implement.
// Each registry represents a different source of skills (e.g., clawhub.ai)
type SkillRegistry interface {
	// Name returns the unique name of this registry (e.g., "clawhub").
	Name() string
	// Search searches the registry for skills matching the query.
	Search(ctx context.Context, query string, limit int) ([]SearchResult, error)
	// GetSkillMeta retrieves metadata for a specific skill by slug.
	GetSkillMeta(ctx context.Context, slug string) (*SkillMeta, error)
	// DownloadAndInstall fetches metadata, resolves the version, downloads and
	// installs the skill to targetDir. Returns an InstallResult with metadata
	// for the caller to use for moderation and user messaging.
	DownloadAndInstall(ctx context.Context, slug, version, targetDir string) (*InstallResult, error)
}

// RegistryConfig holds configuration for all skill registries.
// This is the input to NewRegistryManagerFromConfig.
type RegistryConfig struct {
	ClawHub               ClawHubConfig
	MaxConcurrentSearches int // 0 = default (defaultMaxConcurrentSearches)
}

// ClawHubConfig configures the ClawHub registry.
type ClawHubConfig struct {
	Enabled         bool
	BaseURL         string
	AuthToken       string
	SearchPath      string // e.g. "/api/v1/search"
	SkillsPath      string // e.g. "/api/v1/skills"
	DownloadPath    string // e.g. "/api/v1/download"
	Timeout         int    // seconds, 0 = default (30s)
	MaxZipSize      int    // bytes, 0 = default (50MB)
	MaxResponseSize int    // bytes, 0 = default (2MB)
}

// RegistryManager coordinates multiple skill registries.
// It fans out search requests and routes installs to the correct registry.
// The registries slice is guarded by mu, so Add/Get/SearchAll are safe for
// concurrent use.
type RegistryManager struct {
	registries    []SkillRegistry
	maxConcurrent int // cap on simultaneous registry searches in SearchAll
	mu            sync.RWMutex
}

// NewRegistryManager creates an empty RegistryManager.
func NewRegistryManager() *RegistryManager {
	return &RegistryManager{
		registries:    make([]SkillRegistry, 0),
		maxConcurrent: defaultMaxConcurrentSearches,
	}
}

// NewRegistryManagerFromConfig builds a RegistryManager from config,
// instantiating only the enabled registries.
func NewRegistryManagerFromConfig(cfg RegistryConfig) *RegistryManager {
	rm := NewRegistryManager()
	if cfg.MaxConcurrentSearches > 0 {
		rm.maxConcurrent = cfg.MaxConcurrentSearches
	}
	if cfg.ClawHub.Enabled {
		rm.AddRegistry(NewClawHubRegistry(cfg.ClawHub))
	}
	return rm
}

// AddRegistry adds a registry to the manager.
func (rm *RegistryManager) AddRegistry(r SkillRegistry) {
	rm.mu.Lock()
	defer rm.mu.Unlock()
	rm.registries = append(rm.registries, r)
}

// GetRegistry returns a registry by name, or nil if not found.
func (rm *RegistryManager) GetRegistry(name string) SkillRegistry {
	rm.mu.RLock()
	defer rm.mu.RUnlock()
	for _, r := range rm.registries {
		if r.Name() == name {
			return r
		}
	}
	return nil
}

// SearchAll fans out the query to all registries concurrently
// and merges results sorted by score descending.
// SearchAll queries every configured registry concurrently (bounded by
// maxConcurrent), merges the results and sorts them by score descending.
//
// Failure semantics: a single registry failing is tolerated — its error is
// logged and the remaining results are returned. An error is returned only
// when no registries are configured, or when every registry failed (the
// last error is wrapped). Each registry search gets its own 1-minute
// timeout derived from ctx.
func (rm *RegistryManager) SearchAll(ctx context.Context, query string, limit int) ([]SearchResult, error) {
	// Snapshot the registry list under the read lock so the fan-out below
	// does not hold the lock while doing network I/O.
	rm.mu.RLock()
	regs := make([]SkillRegistry, len(rm.registries))
	copy(regs, rm.registries)
	rm.mu.RUnlock()

	if len(regs) == 0 {
		return nil, fmt.Errorf("no registries configured")
	}

	type regResult struct {
		results []SearchResult
		err     error
	}

	// Semaphore: limit concurrency.
	sem := make(chan struct{}, rm.maxConcurrent)
	// Buffered to len(regs) so every goroutine can send without blocking,
	// even if the reader has already stopped.
	resultsCh := make(chan regResult, len(regs))

	var wg sync.WaitGroup
	for _, reg := range regs {
		wg.Add(1)
		go func(r SkillRegistry) {
			defer wg.Done()

			// Acquire semaphore slot; give up if the caller's context ends first.
			select {
			case sem <- struct{}{}:
				defer func() { <-sem }()
			case <-ctx.Done():
				resultsCh <- regResult{err: ctx.Err()}
				return
			}

			searchCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
			defer cancel()

			results, err := r.Search(searchCtx, query, limit)
			if err != nil {
				slog.Warn("registry search failed", "registry", r.Name(), "error", err)
				resultsCh <- regResult{err: err}
				return
			}
			resultsCh <- regResult{results: results}
		}(reg)
	}

	// Close results channel after all goroutines complete.
	go func() {
		wg.Wait()
		close(resultsCh)
	}()

	var merged []SearchResult
	var lastErr error

	var anyRegistrySucceeded bool
	for rr := range resultsCh {
		if rr.err != nil {
			lastErr = rr.err
			continue
		}
		anyRegistrySucceeded = true
		merged = append(merged, rr.results...)
	}

	// If all registries failed, return the last error.
	if !anyRegistrySucceeded && lastErr != nil {
		return nil, fmt.Errorf("all registries failed: %w", lastErr)
	}

	// Sort by score descending.
	sortByScoreDesc(merged)

	// Clamp to limit.
	if limit > 0 && len(merged) > limit {
		merged = merged[:limit]
	}

	return merged, nil
}

// sortByScoreDesc sorts SearchResults by Score in descending order (insertion sort — small slices).
+func sortByScoreDesc(results []SearchResult) { + for i := 1; i < len(results); i++ { + key := results[i] + j := i - 1 + for j >= 0 && results[j].Score < key.Score { + results[j+1] = results[j] + j-- + } + results[j+1] = key + } +} diff --git a/pkg/skills/registry_test.go b/pkg/skills/registry_test.go new file mode 100644 index 000000000..a4694bd43 --- /dev/null +++ b/pkg/skills/registry_test.go @@ -0,0 +1,180 @@ +package skills + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/sipeed/picoclaw/pkg/utils" +) + +// mockRegistry is a test double implementing SkillRegistry. +type mockRegistry struct { + name string + searchResults []SearchResult + searchErr error + meta *SkillMeta + metaErr error + installResult *InstallResult + installErr error +} + +func (m *mockRegistry) Name() string { return m.name } + +func (m *mockRegistry) Search(_ context.Context, _ string, _ int) ([]SearchResult, error) { + return m.searchResults, m.searchErr +} + +func (m *mockRegistry) GetSkillMeta(_ context.Context, _ string) (*SkillMeta, error) { + return m.meta, m.metaErr +} + +func (m *mockRegistry) DownloadAndInstall(_ context.Context, _, _, _ string) (*InstallResult, error) { + return m.installResult, m.installErr +} + +func TestRegistryManagerSearchAllSingle(t *testing.T) { + mgr := NewRegistryManager() + mgr.AddRegistry(&mockRegistry{ + name: "test", + searchResults: []SearchResult{ + {Slug: "skill-a", Score: 0.9, RegistryName: "test"}, + {Slug: "skill-b", Score: 0.5, RegistryName: "test"}, + }, + }) + + results, err := mgr.SearchAll(context.Background(), "test query", 10) + assert.NoError(t, err) + assert.Len(t, results, 2) + assert.Equal(t, "skill-a", results[0].Slug) +} + +func TestRegistryManagerSearchAllMultiple(t *testing.T) { + mgr := NewRegistryManager() + mgr.AddRegistry(&mockRegistry{ + name: "alpha", + searchResults: []SearchResult{ + {Slug: "skill-a", Score: 0.8, RegistryName: "alpha"}, + }, + }) + 
mgr.AddRegistry(&mockRegistry{ + name: "beta", + searchResults: []SearchResult{ + {Slug: "skill-b", Score: 0.95, RegistryName: "beta"}, + }, + }) + + results, err := mgr.SearchAll(context.Background(), "test query", 10) + assert.NoError(t, err) + assert.Len(t, results, 2) + // Should be sorted by score descending + assert.Equal(t, "skill-b", results[0].Slug) + assert.Equal(t, "skill-a", results[1].Slug) +} + +func TestRegistryManagerSearchAllOneFailsGracefully(t *testing.T) { + mgr := NewRegistryManager() + mgr.AddRegistry(&mockRegistry{ + name: "failing", + searchErr: fmt.Errorf("network error"), + }) + mgr.AddRegistry(&mockRegistry{ + name: "working", + searchResults: []SearchResult{ + {Slug: "skill-a", Score: 0.8, RegistryName: "working"}, + }, + }) + + results, err := mgr.SearchAll(context.Background(), "test query", 10) + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Equal(t, "skill-a", results[0].Slug) +} + +func TestRegistryManagerSearchAllAllFail(t *testing.T) { + mgr := NewRegistryManager() + mgr.AddRegistry(&mockRegistry{ + name: "fail-1", + searchErr: fmt.Errorf("error 1"), + }) + + _, err := mgr.SearchAll(context.Background(), "test query", 10) + assert.Error(t, err) +} + +func TestRegistryManagerSearchAllNoRegistries(t *testing.T) { + mgr := NewRegistryManager() + _, err := mgr.SearchAll(context.Background(), "test query", 10) + assert.Error(t, err) +} + +func TestRegistryManagerGetRegistry(t *testing.T) { + mgr := NewRegistryManager() + mock := &mockRegistry{name: "clawhub"} + mgr.AddRegistry(mock) + + got := mgr.GetRegistry("clawhub") + assert.NotNil(t, got) + assert.Equal(t, "clawhub", got.Name()) + + got = mgr.GetRegistry("nonexistent") + assert.Nil(t, got) +} + +func TestRegistryManagerSearchAllRespectLimit(t *testing.T) { + mgr := NewRegistryManager() + results := make([]SearchResult, 20) + for i := range results { + results[i] = SearchResult{Slug: fmt.Sprintf("skill-%d", i), Score: float64(20 - i)} + } + 
mgr.AddRegistry(&mockRegistry{ + name: "test", + searchResults: results, + }) + + got, err := mgr.SearchAll(context.Background(), "test", 5) + assert.NoError(t, err) + assert.Len(t, got, 5) + // Top scores first + assert.Equal(t, "skill-0", got[0].Slug) +} + +func TestRegistryManagerSearchAllTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + time.Sleep(5 * time.Millisecond) // Let context expire. + + mgr := NewRegistryManager() + mgr.AddRegistry(&mockRegistry{ + name: "slow", + searchErr: fmt.Errorf("context deadline exceeded"), + }) + + _, err := mgr.SearchAll(ctx, "test", 5) + assert.Error(t, err) +} + +func TestSortByScoreDesc(t *testing.T) { + results := []SearchResult{ + {Slug: "c", Score: 0.3}, + {Slug: "a", Score: 0.9}, + {Slug: "b", Score: 0.5}, + } + sortByScoreDesc(results) + assert.Equal(t, "a", results[0].Slug) + assert.Equal(t, "b", results[1].Slug) + assert.Equal(t, "c", results[2].Slug) +} + +func TestIsSafeSlug(t *testing.T) { + assert.NoError(t, utils.ValidateSkillIdentifier("github")) + assert.NoError(t, utils.ValidateSkillIdentifier("docker-compose")) + assert.Error(t, utils.ValidateSkillIdentifier("")) + assert.Error(t, utils.ValidateSkillIdentifier("../etc/passwd")) + assert.Error(t, utils.ValidateSkillIdentifier("path/traversal")) + assert.Error(t, utils.ValidateSkillIdentifier("path\\traversal")) +} diff --git a/pkg/skills/search_cache.go b/pkg/skills/search_cache.go new file mode 100644 index 000000000..5d7d2797e --- /dev/null +++ b/pkg/skills/search_cache.go @@ -0,0 +1,229 @@ +package skills + +import ( + "sort" + "strings" + "sync" + "time" +) + +// SearchCache provides lightweight caching for search results. +// It uses trigram-based similarity to match similar queries to cached results, +// avoiding redundant API calls. Thread-safe for concurrent access. 
+type SearchCache struct { + mu sync.RWMutex + entries map[string]*cacheEntry + order []string // LRU order: oldest first. + maxEntries int + ttl time.Duration +} + +type cacheEntry struct { + query string + trigrams []uint32 + results []SearchResult + createdAt time.Time +} + +// similarityThreshold is the minimum trigram Jaccard similarity for a cache hit. +const similarityThreshold = 0.7 + +// NewSearchCache creates a new search cache. +// maxEntries is the maximum number of cached queries (excess evicts LRU). +// ttl is how long each entry lives before expiration. +func NewSearchCache(maxEntries int, ttl time.Duration) *SearchCache { + if maxEntries <= 0 { + maxEntries = 50 + } + if ttl <= 0 { + ttl = 5 * time.Minute + } + return &SearchCache{ + entries: make(map[string]*cacheEntry), + order: make([]string, 0), + maxEntries: maxEntries, + ttl: ttl, + } +} + +// Get looks up results for a query. Returns cached results and true if found +// (either exact or similar match above threshold). Returns nil, false on miss. +func (sc *SearchCache) Get(query string) ([]SearchResult, bool) { + normalized := normalizeQuery(query) + if normalized == "" { + return nil, false + } + + sc.mu.Lock() + defer sc.mu.Unlock() + + // Exact match first. + if entry, ok := sc.entries[normalized]; ok { + if time.Since(entry.createdAt) < sc.ttl { + sc.moveToEndLocked(normalized) + return copyResults(entry.results), true + } + } + + // Similarity match. + queryTrigrams := buildTrigrams(normalized) + var bestEntry *cacheEntry + var bestSim float64 + + for _, entry := range sc.entries { + if time.Since(entry.createdAt) >= sc.ttl { + continue // Skip expired. + } + sim := jaccardSimilarity(queryTrigrams, entry.trigrams) + if sim > bestSim { + bestSim = sim + bestEntry = entry + } + } + + if bestSim >= similarityThreshold && bestEntry != nil { + sc.moveToEndLocked(bestEntry.query) + return copyResults(bestEntry.results), true + } + + return nil, false +} + +// Put stores results for a query. 
Evicts the oldest entry if at capacity. +func (sc *SearchCache) Put(query string, results []SearchResult) { + normalized := normalizeQuery(query) + if normalized == "" { + return + } + + sc.mu.Lock() + defer sc.mu.Unlock() + + // Evict expired entries first. + sc.evictExpiredLocked() + + // If already exists, update. + if _, ok := sc.entries[normalized]; ok { + sc.entries[normalized] = &cacheEntry{ + query: normalized, + trigrams: buildTrigrams(normalized), + results: copyResults(results), + createdAt: time.Now(), + } + // Move to end of LRU order. + sc.moveToEndLocked(normalized) + return + } + + // Evict LRU if at capacity. + for len(sc.entries) >= sc.maxEntries && len(sc.order) > 0 { + oldest := sc.order[0] + sc.order = sc.order[1:] + delete(sc.entries, oldest) + } + + // Insert new entry. + sc.entries[normalized] = &cacheEntry{ + query: normalized, + trigrams: buildTrigrams(normalized), + results: copyResults(results), + createdAt: time.Now(), + } + sc.order = append(sc.order, normalized) +} + +// Len returns the number of entries (for testing). +func (sc *SearchCache) Len() int { + sc.mu.RLock() + defer sc.mu.RUnlock() + return len(sc.entries) +} + +// --- internal --- + +func (sc *SearchCache) evictExpiredLocked() { + now := time.Now() + newOrder := make([]string, 0, len(sc.order)) + for _, key := range sc.order { + entry, ok := sc.entries[key] + if !ok || now.Sub(entry.createdAt) >= sc.ttl { + delete(sc.entries, key) + continue + } + newOrder = append(newOrder, key) + } + sc.order = newOrder +} + +func (sc *SearchCache) moveToEndLocked(key string) { + for i, k := range sc.order { + if k == key { + sc.order = append(sc.order[:i], sc.order[i+1:]...) + break + } + } + sc.order = append(sc.order, key) +} + +func normalizeQuery(q string) string { + return strings.ToLower(strings.TrimSpace(q)) +} + +// buildTrigrams generates hash of trigrams from a string. 
+// Example: "hello" → {"hel", "ell", "llo"} +// "hel" -> 0x0068656c -> 4 bytes; compared to 16 bytes of a string +func buildTrigrams(s string) []uint32 { + if len(s) < 3 { + return nil + } + + trigrams := make([]uint32, 0, len(s)-2) + for i := 0; i <= len(s)-3; i++ { + trigrams = append(trigrams, uint32(s[i])<<16|uint32(s[i+1])<<8|uint32(s[i+2])) + } + + // Sort and Deduplication + sort.Slice(trigrams, func(i, j int) bool { return trigrams[i] < trigrams[j] }) + n := 1 + for i := 1; i < len(trigrams); i++ { + if trigrams[i] != trigrams[i-1] { + trigrams[n] = trigrams[i] + n++ + } + } + + return trigrams[:n] +} + +// jaccardSimilarity computes |A ∩ B| / |A ∪ B|. +func jaccardSimilarity(a, b []uint32) float64 { + if len(a) == 0 && len(b) == 0 { + return 1 + } + i, j := 0, 0 + intersection := 0 + + for i < len(a) && j < len(b) { + if a[i] == b[j] { + intersection++ + i++ + j++ + } else if a[i] < b[j] { + i++ + } else { + j++ + } + } + + union := len(a) + len(b) - intersection + return float64(intersection) / float64(union) +} + +func copyResults(results []SearchResult) []SearchResult { + if results == nil { + return nil + } + cp := make([]SearchResult, len(results)) + copy(cp, results) + return cp +} diff --git a/pkg/skills/search_cache_test.go b/pkg/skills/search_cache_test.go new file mode 100644 index 000000000..816bdfb93 --- /dev/null +++ b/pkg/skills/search_cache_test.go @@ -0,0 +1,200 @@ +package skills + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSearchCacheExactHit(t *testing.T) { + cache := NewSearchCache(10, 5*time.Minute) + + results := []SearchResult{ + {Slug: "github", Score: 0.9, RegistryName: "clawhub"}, + {Slug: "docker", Score: 0.7, RegistryName: "clawhub"}, + } + cache.Put("github integration", results) + + got, hit := cache.Get("github integration") + assert.True(t, hit) + assert.Len(t, got, 2) + assert.Equal(t, "github", got[0].Slug) +} + +func TestSearchCacheExactHitCaseInsensitive(t *testing.T) { + 
cache := NewSearchCache(10, 5*time.Minute) + + results := []SearchResult{{Slug: "github", Score: 0.9}} + cache.Put("GitHub Integration", results) + + got, hit := cache.Get("github integration") + assert.True(t, hit) + assert.Len(t, got, 1) +} + +func TestSearchCacheSimilarHit(t *testing.T) { + cache := NewSearchCache(10, 5*time.Minute) + + results := []SearchResult{{Slug: "github", Score: 0.9}} + cache.Put("github integration tool", results) + + // "github integration" is very similar to "github integration tool" + got, hit := cache.Get("github integration") + assert.True(t, hit) + assert.Len(t, got, 1) +} + +func TestSearchCacheDissimilarMiss(t *testing.T) { + cache := NewSearchCache(10, 5*time.Minute) + + results := []SearchResult{{Slug: "github", Score: 0.9}} + cache.Put("github integration", results) + + // Completely unrelated query + _, hit := cache.Get("database management") + assert.False(t, hit) +} + +func TestSearchCacheTTLExpiration(t *testing.T) { + cache := NewSearchCache(10, 50*time.Millisecond) + + results := []SearchResult{{Slug: "github", Score: 0.9}} + cache.Put("github integration", results) + + // Immediately should hit + _, hit := cache.Get("github integration") + assert.True(t, hit) + + // Wait for expiration + time.Sleep(100 * time.Millisecond) + + _, hit = cache.Get("github integration") + assert.False(t, hit) +} + +func TestSearchCacheLRUEviction(t *testing.T) { + cache := NewSearchCache(3, 5*time.Minute) + + cache.Put("query-1", []SearchResult{{Slug: "a"}}) + cache.Put("query-2", []SearchResult{{Slug: "b"}}) + cache.Put("query-3", []SearchResult{{Slug: "c"}}) + + assert.Equal(t, 3, cache.Len()) + + // Adding a 4th should evict query-1 (oldest) + cache.Put("query-4", []SearchResult{{Slug: "d"}}) + assert.Equal(t, 3, cache.Len()) + + _, hit := cache.Get("query-1") + assert.False(t, hit, "oldest entry should be evicted") + + got, hit := cache.Get("query-4") + assert.True(t, hit) + assert.Equal(t, "d", got[0].Slug) +} + +func 
TestSearchCacheEmptyQuery(t *testing.T) { + cache := NewSearchCache(10, 5*time.Minute) + + _, hit := cache.Get("") + assert.False(t, hit) + + _, hit = cache.Get(" ") + assert.False(t, hit) +} + +func TestSearchCacheResultsCopied(t *testing.T) { + cache := NewSearchCache(10, 5*time.Minute) + + original := []SearchResult{{Slug: "github", Score: 0.9}} + cache.Put("test", original) + + // Mutate original after putting + original[0].Slug = "mutated" + + got, hit := cache.Get("test") + assert.True(t, hit) + assert.Equal(t, "github", got[0].Slug, "cache should hold a copy, not a reference") +} + +func TestBuildTrigrams(t *testing.T) { + trigrams := buildTrigrams("hello") + assert.Contains(t, trigrams, uint32('h')<<16|uint32('e')<<8|uint32('l')) + assert.Contains(t, trigrams, uint32('e')<<16|uint32('l')<<8|uint32('l')) + assert.Contains(t, trigrams, uint32('l')<<16|uint32('l')<<8|uint32('o')) + assert.Len(t, trigrams, 3) +} + +func TestJaccardSimilarity(t *testing.T) { + a := buildTrigrams("github integration") + b := buildTrigrams("github integration tool") + + sim := jaccardSimilarity(a, b) + assert.Greater(t, sim, 0.5, "similar strings should have high sim") + + c := buildTrigrams("completely different query about databases") + sim2 := jaccardSimilarity(a, c) + assert.Less(t, sim2, 0.3, "dissimilar strings should have low sim") +} + +func TestJaccardSimilarityEdgeCases(t *testing.T) { + empty := buildTrigrams("") + nonempty := buildTrigrams("hello") + + assert.Equal(t, 1.0, jaccardSimilarity(empty, empty)) + assert.Equal(t, 0.0, jaccardSimilarity(empty, nonempty)) + assert.Equal(t, 0.0, jaccardSimilarity(nonempty, empty)) +} + +func TestSearchCacheConcurrency(t *testing.T) { + cache := NewSearchCache(50, 5*time.Minute) + done := make(chan struct{}) + + // Concurrent writes + go func() { + for i := 0; i < 100; i++ { + cache.Put("query-write-"+string(rune('a'+i%26)), []SearchResult{{Slug: "x"}}) + } + done <- struct{}{} + }() + + // Concurrent reads + go func() { + for i 
:= 0; i < 100; i++ { + cache.Get("query-write-a") + } + done <- struct{}{} + }() + + <-done +} + +func TestSearchCacheLRUUpdateOnGet(t *testing.T) { + // Capacity 3 + cache := NewSearchCache(3, time.Hour) + + // Fill cache: query-A, query-B, query-C + // Use longer strings to ensure trigrams are generated and avoid false positive similarity + cache.Put("query-A", []SearchResult{{Slug: "A"}}) + cache.Put("query-B", []SearchResult{{Slug: "B"}}) + cache.Put("query-C", []SearchResult{{Slug: "C"}}) + + // Access query-A (should make it most recently used) + if _, found := cache.Get("query-A"); !found { + t.Fatal("query-A should be in cache") + } + + // Add query-D. Should evict query-B (LRU) instead of query-A (which was refreshed) + cache.Put("query-D", []SearchResult{{Slug: "D"}}) + + // Check if query-A is still there + if _, found := cache.Get("query-A"); !found { + t.Fatalf("query-A was evicted! valid LRU should have kept query-A and evicted query-B.") + } + + // Check if query-B is evicted + if _, found := cache.Get("query-B"); found { + t.Fatal("query-B should have been evicted") + } +} diff --git a/pkg/state/state.go b/pkg/state/state.go index 0bb9cd497..1a92f82ed 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -38,7 +38,7 @@ func NewManager(workspace string) *Manager { oldStateFile := filepath.Join(workspace, "state.json") // Create state directory if it doesn't exist - os.MkdirAll(stateDir, 0755) + os.MkdirAll(stateDir, 0o755) sm := &Manager{ workspace: workspace, @@ -139,7 +139,7 @@ func (sm *Manager) saveAtomic() error { } // Write to temp file - if err := os.WriteFile(tempFile, data, 0644); err != nil { + if err := os.WriteFile(tempFile, data, 0o644); err != nil { return fmt.Errorf("failed to write temp file: %w", err) } diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index ce3dd7215..f717a5bb4 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -98,7 +98,7 @@ func TestAtomicity_NoCorruptionOnInterrupt(t 
*testing.T) { // Simulate a crash scenario by manually creating a corrupted temp file tempFile := filepath.Join(tmpDir, "state", "state.json.tmp") - err = os.WriteFile(tempFile, []byte("corrupted data"), 0644) + err = os.WriteFile(tempFile, []byte("corrupted data"), 0o644) if err != nil { t.Fatalf("Failed to create temp file: %v", err) } diff --git a/pkg/swarm/activities.go b/pkg/swarm/activities.go new file mode 100644 index 000000000..48d0a0b66 --- /dev/null +++ b/pkg/swarm/activities.go @@ -0,0 +1,376 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// Activities bridges Temporal workflows with LLM and agent functionality +type Activities struct { + provider providers.LLMProvider + agentLoop *agent.AgentLoop + cfg *config.SwarmConfig + nodeInfo *NodeInfo +} + +// NewActivities creates a new Activities instance +func NewActivities(provider providers.LLMProvider, agentLoop *agent.AgentLoop, cfg *config.SwarmConfig, nodeInfo *NodeInfo) *Activities { + return &Activities{ + provider: provider, + agentLoop: agentLoop, + cfg: cfg, + nodeInfo: nodeInfo, + } +} + +// DecomposeTaskActivity breaks down a complex task into subtasks using LLM analysis +func (a *Activities) DecomposeTaskActivity(ctx context.Context, task *SwarmTask) ([]*SwarmTask, error) { + logger.InfoCF("swarm", "Decomposing task", map[string]interface{}{ + "task_id": task.ID, + "prompt": truncateString(task.Prompt, 100), + }) + + // Test mode: force decomposition for test tasks + if strings.HasPrefix(task.Prompt, "PARALLEL") || strings.HasPrefix(task.Prompt, "TEST PARALLEL") { + logger.InfoCF("swarm", "Test mode: forcing decomposition", map[string]interface{}{ + "task_id": 
task.ID, + }) + // Extract the actual task from PARALLEL: prefix + actualTask := strings.TrimPrefix(task.Prompt, "PARALLEL:") + actualTask = strings.TrimPrefix(actualTask, "TEST PARALLEL:") + actualTask = strings.TrimSpace(actualTask) + + // If no specific task given, use a default one + if actualTask == "" { + actualTask = "list current directory structure" + } + + // Create test subtasks - each worker will report its own directory + return []*SwarmTask{ + { + ID: task.ID + "-sub-1", + ParentID: task.ID, + Type: TaskTypeDirect, + Priority: task.Priority, + Capability: "general", + Prompt: fmt.Sprintf(`IMPORTANT: First identify yourself by calling swarm_info tool to find your node ID and role. Then: +1. Use 'pwd' to show your current working directory +2. Use 'ls' to list your current directory contents +3. Report in format: "I am [node-id] ([role]), current directory: [path], contents: [directory listing]" + +Task: %s`, actualTask), + Context: task.Context, + Status: TaskPending, + CreatedAt: time.Now().UnixMilli(), + Timeout: task.Timeout, + }, + { + ID: task.ID + "-sub-2", + ParentID: task.ID, + Type: TaskTypeDirect, + Priority: task.Priority, + Capability: "general", + Prompt: fmt.Sprintf(`IMPORTANT: First identify yourself by calling swarm_info tool to find your node ID and role. Then: +1. Use 'pwd' to show your current working directory +2. Use 'ls' to list your current directory contents +3. Report in format: "I am [node-id] ([role]), current directory: [path], contents: [directory listing]" + +Task: %s`, actualTask), + Context: task.Context, + Status: TaskPending, + CreatedAt: time.Now().UnixMilli(), + Timeout: task.Timeout, + }, + }, nil + } + + // Build decomposition prompt + decomposePrompt := fmt.Sprintf(`You are a task decomposition expert. Analyze the following task and determine if it should be decomposed into parallel subtasks. + +TASK: %s + +CAPABILITY REQUIRED: %s + +IMPORTANT DECOMPOSITION RULES: +1. 
If the task mentions multiple files, multiple operations, or explicitly asks for parallel execution - ALWAYS DECOMPOSE +2. If the task contains the words "parallel", "concurrent", "together", or "simultaneously" - ALWAYS DECOMPOSE +3. Simple single-file operations can be executed directly + +Respond with a JSON object. If the task is simple and can be executed directly, return: +{"decompose": false, "reason": "explanation"} + +If the task should be decomposed, return: +{ + "decompose": true, + "reason": "explanation of why decomposition helps", + "subtasks": [ + {"id": "subtask-1", "prompt": "specific instruction", "capability": "capability_needed"}, + {"id": "subtask-2", "prompt": "specific instruction", "capability": "capability_needed"} + ] +} + +Keep subtasks focused and independently executable. Each subtask should produce a partial result that can be synthesized later.`, + task.Prompt, task.Capability) + + // Call LLM for decomposition decision + messages := []providers.Message{ + {Role: "user", Content: decomposePrompt}, + } + + // Get model from config, provider default, or fallback + model := a.getModel() + + response, err := a.provider.Chat(ctx, messages, nil, model, map[string]interface{}{ + "max_tokens": 2048, + "temperature": 0.3, + }) + if err != nil { + logger.WarnCF("swarm", "LLM decomposition failed, executing directly", map[string]interface{}{ + "error": err.Error(), + }) + return nil, nil // Fall back to direct execution + } + + // Parse LLM response + var result struct { + Decompose bool `json:"decompose"` + Reason string `json:"reason"` + Subtasks []SubtaskSpec `json:"subtasks"` + } + + if err := json.Unmarshal([]byte(response.Content), &result); err != nil { + logger.WarnCF("swarm", "Failed to parse decomposition response, executing directly", map[string]interface{}{ + "error": err.Error(), + "response": response.Content, + }) + return nil, nil + } + + if !result.Decompose { + logger.InfoCF("swarm", "Task marked for direct execution", 
map[string]interface{}{ + "reason": result.Reason, + }) + return nil, nil + } + + // Create subtask structures + subtasks := make([]*SwarmTask, len(result.Subtasks)) + for i, spec := range result.Subtasks { + subtasks[i] = &SwarmTask{ + ID: fmt.Sprintf("%s-sub-%d", task.ID, i+1), + ParentID: task.ID, + Type: TaskTypeDirect, + Priority: task.Priority, + Capability: spec.Capability, + Prompt: spec.Prompt, + Context: task.Context, + Status: TaskPending, + CreatedAt: time.Now().UnixMilli(), + Timeout: task.Timeout, + } + } + + logger.InfoCF("swarm", "Task decomposed into subtasks", map[string]interface{}{ + "task_id": task.ID, + "subtasks": len(subtasks), + "reason": result.Reason, + }) + + return subtasks, nil +} + +// ExecuteDirectActivity executes a task directly on the local agent +func (a *Activities) ExecuteDirectActivity(ctx context.Context, task *SwarmTask) (string, error) { + logger.InfoCF("swarm", "Executing task directly", map[string]interface{}{ + "task_id": task.ID, + "prompt": truncateString(task.Prompt, 100), + }) + + // Check if agentLoop is available + if a.agentLoop == nil { + return "", fmt.Errorf("agentLoop is not configured") + } + + // Use agent loop's ProcessDirect for execution + result, err := a.agentLoop.ProcessDirect(ctx, task.Prompt, "swarm:"+task.ID) + if err != nil { + logger.WarnCF("swarm", "Direct execution failed", map[string]interface{}{ + "task_id": task.ID, + "error": err.Error(), + }) + return "", err + } + + logger.InfoCF("swarm", "Direct execution completed", map[string]interface{}{ + "task_id": task.ID, + "result_length": len(result), + }) + + return result, nil +} + +// ExecuteSubtaskActivity executes a subtask, potentially dispatching to a specialist worker +func (a *Activities) ExecuteSubtaskActivity(ctx context.Context, task *SwarmTask) (string, error) { + logger.InfoCF("swarm", "Executing subtask", map[string]interface{}{ + "task_id": task.ID, + "parent": task.ParentID, + "prompt": truncateString(task.Prompt, 100), + 
"capability": task.Capability, + }) + + // For now, execute locally using the agent loop + // In a full implementation, this would check for specialist workers + // and dispatch to the appropriate node based on capability + result, err := a.agentLoop.ProcessDirect(ctx, task.Prompt, "swarm:subtask:"+task.ID) + if err != nil { + logger.WarnCF("swarm", "Subtask execution failed", map[string]interface{}{ + "task_id": task.ID, + "error": err.Error(), + }) + return "", err + } + + logger.InfoCF("swarm", "Subtask completed", map[string]interface{}{ + "task_id": task.ID, + "result_length": len(result), + }) + + // Prefix result with node identifier for synthesis + nodeName := a.getNodeName() + result = fmt.Sprintf("=== %s ===\n%s", nodeName, result) + + return result, nil +} + +// getNodeName returns a human-readable name for this node +func (a *Activities) getNodeName() string { + if a.nodeInfo == nil { + return "unknown-node" + } + + // Try to use SID (service ID) as the primary identifier + if sid := a.nodeInfo.Metadata["sid"]; sid != "" { + return sid + } + + // Fall back to node ID (shortened) + nodeID := a.nodeInfo.ID + if len(nodeID) > 12 { + nodeID = nodeID[:12] + } + return nodeID +} + +// SynthesizeResultsActivity combines subtask results into a coherent final response +func (a *Activities) SynthesizeResultsActivity(ctx context.Context, task *SwarmTask, results []string) (string, error) { + logger.InfoCF("swarm", "Synthesizing results", map[string]interface{}{ + "task_id": task.ID, + "results": len(results), + }) + + // Build synthesis prompt + var resultsBlock strings.Builder + resultsBlock.WriteString(fmt.Sprintf("ORIGINAL TASK: %s\n\n", task.Prompt)) + resultsBlock.WriteString("SUBTASK RESULTS:\n\n") + + for i, result := range results { + // Skip failed results + if strings.Contains(result, "[FAILED]") { + resultsBlock.WriteString(fmt.Sprintf("[Result %d - FAILED]\n%s\n\n", i+1, result)) + continue + } + // Truncate very long results for the synthesis prompt + 
truncated := result + if len(result) > 2000 { + truncated = result[:2000] + "\n...[truncated]" + } + resultsBlock.WriteString(fmt.Sprintf("[Result %d]\n%s\n\n", i+1, truncated)) + } + + synthesisPrompt := fmt.Sprintf(`You are synthesizing results from parallel task execution. + +%s + +Your job: +1. Analyze all subtask results +2. Identify key findings and insights +3. Create a coherent, unified response that addresses the original task +4. If any results failed or contain errors, acknowledge them appropriately +5. Present the final answer in a clear, well-structured format + +Provide a comprehensive synthesis that directly addresses the original task.`, resultsBlock.String()) + + messages := []providers.Message{ + {Role: "user", Content: synthesisPrompt}, + } + + // Get model from config, provider default, or fallback + model := a.getModel() + + response, err := a.provider.Chat(ctx, messages, nil, model, map[string]interface{}{ + "max_tokens": 4096, + "temperature": 0.5, + }) + if err != nil { + logger.WarnCF("swarm", "LLM synthesis failed, returning error for Temporal retry", map[string]interface{}{ + "error": err.Error(), + }) + // Return error to trigger Temporal retry + // Temporal will retry up to MaximumAttempts (3) before giving up + return "", fmt.Errorf("LLM synthesis failed: %w", err) + } + + logger.InfoCF("swarm", "Synthesis completed", map[string]interface{}{ + "task_id": task.ID, + "final_length": len(response.Content), + }) + + return response.Content, nil +} + +// SubtaskSpec defines a subtask from decomposition +type SubtaskSpec struct { + ID string `json:"id"` + Prompt string `json:"prompt"` + Capability string `json:"capability"` +} + +// truncateString limits string length for logging +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." 
+} + +// getModel returns the LLM model to use for swarm tasks +// Priority: config > provider default > fallback +func (a *Activities) getModel() string { + // 1. Check if model is configured in swarm config + if a.cfg != nil && a.cfg.Temporal.Model != "" { + return a.cfg.Temporal.Model + } + + // 2. Fall back to provider's default model + if a.provider != nil { + if model := a.provider.GetDefaultModel(); model != "" { + return model + } + } + + // 3. Final fallback + return "gpt-4" +} diff --git a/pkg/swarm/activities_test.go b/pkg/swarm/activities_test.go new file mode 100644 index 000000000..2e32e534f --- /dev/null +++ b/pkg/swarm/activities_test.go @@ -0,0 +1,147 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "testing" + + "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// MockLLMProvider is a mock implementation of LLMProvider for testing +type MockLLMProvider struct { + mock.Mock +} + +func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + args := m.Called(ctx, messages, tools, model, options) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*providers.LLMResponse), args.Error(1) +} + +func (m *MockLLMProvider) GetDefaultModel() string { + args := m.Called() + return args.String(0) +} + +// AgentLoopInterface defines the methods we need from AgentLoop for testing +type AgentLoopInterface interface { + ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) +} + +// MockAgentLoop is a mock implementation of AgentLoop interface for testing +type 
MockAgentLoop struct { + mock.Mock +} + +func (m *MockAgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) { + args := m.Called(ctx, content, sessionKey) + return args.String(0), args.Error(1) +} + +// mockAgentLoopWrapper wraps MockAgentLoop to conform to the *agent.AgentLoop type expectation +// In practice, you would create a proper interface or use dependency injection +// For now, we'll modify the activities to use an interface +func mockAgentLoopAsPtr(mock *MockAgentLoop) *agent.AgentLoop { + // This is a workaround for testing - in production, use proper interfaces + // For now, we just pass nil and test without agentLoop + return nil +} + +func TestNewActivities(t *testing.T) { + cfg := &config.SwarmConfig{} + + activities := NewActivities(nil, nil, cfg, nil) + assert.NotNil(t, activities) + assert.Equal(t, cfg, activities.cfg) +} + +func TestActivities_DecomposeTaskActivity_SimpleTask(t *testing.T) { + mockProvider := new(MockLLMProvider) + cfg := &config.SwarmConfig{} + + // Create activities with nil agentLoop for this test + activities := NewActivities(mockProvider, nil, cfg, nil) + + task := &SwarmTask{ + ID: "test-task-1", + Type: TaskTypeWorkflow, + Prompt: "Simple task that doesn't need decomposition", + Capability: "general", + } + + // Mock LLM to return non-decompose response + mockProvider.On("Chat", mock.Anything, mock.Anything, mock.Anything, "gpt-4", mock.Anything). 
+ Return(&providers.LLMResponse{ + Content: `{"decompose": false, "reason": "Task is simple enough to execute directly"}`, + }, nil) + + ctx := context.Background() + subtasks, err := activities.DecomposeTaskActivity(ctx, task) + + require.NoError(t, err) + assert.Nil(t, subtasks, "Simple tasks should not be decomposed") + mockProvider.AssertExpectations(t) +} + +func TestActivities_ExecuteDirectActivity(t *testing.T) { + mockProvider := new(MockLLMProvider) + cfg := &config.SwarmConfig{} + + // Use nil for agentLoop in this test + activities := NewActivities(mockProvider, nil, cfg, nil) + + task := &SwarmTask{ + ID: "test-task-2", + Prompt: "Execute this directly", + Capability: "general", + } + + ctx := context.Background() + _, err := activities.ExecuteDirectActivity(ctx, task) + + // Should fail because agentLoop is nil + assert.Error(t, err, "Should error when agentLoop is nil") +} + +func TestActivities_SynthesizeResultsActivity(t *testing.T) { + mockProvider := new(MockLLMProvider) + cfg := &config.SwarmConfig{} + + activities := NewActivities(mockProvider, nil, cfg, nil) + + task := &SwarmTask{ + ID: "test-task-3", + Prompt: "Original task", + Capability: "general", + } + + results := []string{ + "Result 1: Data analysis complete", + "Result 2: Report generated", + } + + // Mock LLM to return synthesis result + mockProvider.On("Chat", mock.Anything, mock.Anything, mock.Anything, "gpt-4", mock.Anything). 
+ Return(&providers.LLMResponse{ + Content: "Synthesized final result combining all findings", + }, nil) + + ctx := context.Background() + synthesized, err := activities.SynthesizeResultsActivity(ctx, task, results) + + require.NoError(t, err) + assert.Contains(t, synthesized, "Synthesized final result") + mockProvider.AssertExpectations(t) +} diff --git a/pkg/swarm/checkpoint.go b/pkg/swarm/checkpoint.go new file mode 100644 index 000000000..295d3f74f --- /dev/null +++ b/pkg/swarm/checkpoint.go @@ -0,0 +1,317 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/nats-io/nats.go" + "github.com/sipeed/picoclaw/pkg/logger" +) + +const ( + // CheckpointBucketName is the KV bucket for task checkpoints + CheckpointBucketName = "PICOCLAW_CHECKPOINTS" +) + +// CheckpointStore manages task checkpoints using JetStream KV +type CheckpointStore struct { + bucket nats.KeyValue +} + +// NewCheckpointStore creates a new checkpoint store +func NewCheckpointStore(js nats.JetStreamContext) (*CheckpointStore, error) { + // Create or get KV bucket for checkpoints + bucket, err := js.KeyValue(CheckpointBucketName) + if err != nil { + // Bucket doesn't exist, create it + bucket, err = js.CreateKeyValue(&nats.KeyValueConfig{ + Bucket: CheckpointBucketName, + Description: "Task checkpoints for PicoClaw swarm failover", + MaxBytes: 1024 * 1024 * 100, // 100MB + TTL: 24 * time.Hour * 7, // 7 day retention + Storage: nats.FileStorage, + Replicas: 1, + }) + if err != nil { + return nil, fmt.Errorf("failed to create checkpoint bucket: %w", err) + } + logger.InfoC("swarm", fmt.Sprintf("Created checkpoint bucket: %s", CheckpointBucketName)) + } + + return &CheckpointStore{bucket: bucket}, nil +} + +// SaveCheckpoint persists a checkpoint to the KV store +func (s *CheckpointStore) SaveCheckpoint(ctx context.Context, cp *TaskCheckpoint) 
error { + if cp.CheckpointID == "" { + cp.CheckpointID = fmt.Sprintf("cp-%d", time.Now().UnixNano()) + } + cp.Timestamp = time.Now().UnixMilli() + + // Marshal checkpoint to JSON + data, err := json.Marshal(cp) + if err != nil { + return fmt.Errorf("failed to marshal checkpoint: %w", err) + } + + // Key format: {task_id}_{checkpoint_id} + key := fmt.Sprintf("%s_%s", cp.TaskID, cp.CheckpointID) + + // Use Update for atomic operation with version checking + _, err = s.bucket.Put(key, data) + if err != nil { + return fmt.Errorf("failed to save checkpoint: %w", err) + } + + logger.DebugCF("swarm", "Saved checkpoint", map[string]interface{}{ + "task_id": cp.TaskID, + "checkpoint_id": cp.CheckpointID, + "type": string(cp.Type), + "progress": cp.Progress, + }) + + return nil +} + +// SaveCheckpointAtomic saves a checkpoint with atomic compare-and-swap semantics +// This prevents race conditions during concurrent checkpoint saves +func (s *CheckpointStore) SaveCheckpointAtomic(ctx context.Context, cp *TaskCheckpoint, lastRevision uint64) error { + if cp.CheckpointID == "" { + cp.CheckpointID = fmt.Sprintf("cp-%d", time.Now().UnixNano()) + } + cp.Timestamp = time.Now().UnixMilli() + + // Marshal checkpoint to JSON + data, err := json.Marshal(cp) + if err != nil { + return fmt.Errorf("failed to marshal checkpoint: %w", err) + } + + // Key format: {task_id}_{checkpoint_id} + key := fmt.Sprintf("%s_%s", cp.TaskID, cp.CheckpointID) + + // Atomic update with version check + _, err = s.bucket.Update(key, data, lastRevision) + if err != nil { + return fmt.Errorf("atomic checkpoint save failed: %w", err) + } + + logger.DebugCF("swarm", "Saved checkpoint atomically", map[string]interface{}{ + "task_id": cp.TaskID, + "checkpoint_id": cp.CheckpointID, + "revision": lastRevision, + }) + + return nil +} + +// LoadCheckpoint retrieves the latest checkpoint for a task +func (s *CheckpointStore) LoadCheckpoint(ctx context.Context, taskID string) (*TaskCheckpoint, error) { + // List all 
checkpoints for this task + checkpoints, err := s.ListCheckpoints(ctx, taskID) + if err != nil { + return nil, err + } + + if len(checkpoints) == 0 { + return nil, fmt.Errorf("no checkpoints found for task %s", taskID) + } + + // Return the most recent checkpoint + return checkpoints[0], nil +} + +// LoadCheckpointByID retrieves a specific checkpoint by ID +func (s *CheckpointStore) LoadCheckpointByID(ctx context.Context, taskID, checkpointID string) (*TaskCheckpoint, error) { + key := fmt.Sprintf("%s_%s", taskID, checkpointID) + + entry, err := s.bucket.Get(key) + if err != nil { + if err == nats.ErrKeyNotFound { + return nil, fmt.Errorf("checkpoint not found: %s", checkpointID) + } + return nil, fmt.Errorf("failed to load checkpoint: %w", err) + } + + var cp TaskCheckpoint + if err := json.Unmarshal(entry.Value(), &cp); err != nil { + return nil, fmt.Errorf("failed to unmarshal checkpoint: %w", err) + } + + return &cp, nil +} + +// ListCheckpoints retrieves all checkpoints for a task, ordered by timestamp (newest first) +func (s *CheckpointStore) ListCheckpoints(ctx context.Context, taskID string) ([]*TaskCheckpoint, error) { + // List all keys with the task prefix + watcher, err := s.bucket.WatchAll(nats.Context(ctx)) + if err != nil { + return nil, fmt.Errorf("failed to watch checkpoints: %w", err) + } + defer watcher.Stop() + + checkpoints := make([]*TaskCheckpoint, 0) + prefix := taskID + "_" + + for { + select { + case entry := <-watcher.Updates(): + if entry == nil { + // No more entries + goto done + } + + if len(entry.Key()) <= len(prefix) { + continue + } + + // Check if key matches our task prefix + if entry.Key()[:len(prefix)] == prefix { + var cp TaskCheckpoint + if err := json.Unmarshal(entry.Value(), &cp); err != nil { + logger.WarnCF("swarm", "Failed to unmarshal checkpoint", map[string]interface{}{ + "key": entry.Key(), + "error": err.Error(), + }) + continue + } + checkpoints = append(checkpoints, &cp) + } + + case <-ctx.Done(): + return nil, 
ctx.Err() + } + } + +done: + // Sort by timestamp descending (newest first) + sortCheckpoints(checkpoints) + + return checkpoints, nil +} + +// DeleteCheckpoint removes a specific checkpoint +func (s *CheckpointStore) DeleteCheckpoint(ctx context.Context, taskID, checkpointID string) error { + key := fmt.Sprintf("%s_%s", taskID, checkpointID) + + err := s.bucket.Delete(key) + if err != nil { + if err == nats.ErrKeyNotFound { + return fmt.Errorf("checkpoint not found: %s", checkpointID) + } + return fmt.Errorf("failed to delete checkpoint: %w", err) + } + + logger.DebugCF("swarm", "Deleted checkpoint", map[string]interface{}{ + "task_id": taskID, + "checkpoint_id": checkpointID, + }) + + return nil +} + +// DeleteAllCheckpoints removes all checkpoints for a task +func (s *CheckpointStore) DeleteAllCheckpoints(ctx context.Context, taskID string) error { + checkpoints, err := s.ListCheckpoints(ctx, taskID) + if err != nil { + return err + } + + for _, cp := range checkpoints { + if err := s.DeleteCheckpoint(ctx, taskID, cp.CheckpointID); err != nil { + logger.WarnCF("swarm", "Failed to delete checkpoint", map[string]interface{}{ + "task_id": taskID, + "checkpoint_id": cp.CheckpointID, + "error": err.Error(), + }) + } + } + + return nil +} + +// GetCheckpointRevision returns the current revision of a checkpoint for CAS operations +func (s *CheckpointStore) GetCheckpointRevision(ctx context.Context, taskID, checkpointID string) (uint64, error) { + key := fmt.Sprintf("%s_%s", taskID, checkpointID) + + entry, err := s.bucket.Get(key) + if err != nil { + return 0, fmt.Errorf("failed to get checkpoint revision: %w", err) + } + + return entry.Revision(), nil +} + +// PruneOldCheckpoints removes old checkpoints keeping only the most recent N +func (s *CheckpointStore) PruneOldCheckpoints(ctx context.Context, taskID string, keep int) error { + checkpoints, err := s.ListCheckpoints(ctx, taskID) + if err != nil { + return err + } + + if len(checkpoints) <= keep { + return nil + 
} + + // Delete oldest checkpoints (starting from index 'keep') + for i := keep; i < len(checkpoints); i++ { + cp := checkpoints[i] + if err := s.DeleteCheckpoint(ctx, taskID, cp.CheckpointID); err != nil { + logger.WarnCF("swarm", "Failed to prune old checkpoint", map[string]interface{}{ + "task_id": taskID, + "checkpoint_id": cp.CheckpointID, + "error": err.Error(), + }) + } + } + + return nil +} + +// sortCheckpoints sorts checkpoints by timestamp descending (newest first) +func sortCheckpoints(checkpoints []*TaskCheckpoint) { + // Simple bubble sort for small lists + n := len(checkpoints) + for i := 0; i < n-1; i++ { + for j := 0; j < n-i-1; j++ { + if checkpoints[j].Timestamp < checkpoints[j+1].Timestamp { + checkpoints[j], checkpoints[j+1] = checkpoints[j+1], checkpoints[j] + } + } + } +} + +// CreateProgressCheckpoint creates a progress checkpoint from current execution state +func CreateProgressCheckpoint(taskID, nodeID string, progress float64, partialResult string, state map[string]interface{}) *TaskCheckpoint { + return &TaskCheckpoint{ + CheckpointID: fmt.Sprintf("prog-%d", time.Now().UnixNano()), + TaskID: taskID, + Type: CheckpointTypeProgress, + NodeID: nodeID, + Progress: progress, + PartialResult: partialResult, + State: state, + Timestamp: time.Now().UnixMilli(), + } +} + +// CreateMilestoneCheckpoint creates a milestone checkpoint +func CreateMilestoneCheckpoint(taskID, nodeID string, progress float64, result string, metadata map[string]string) *TaskCheckpoint { + return &TaskCheckpoint{ + CheckpointID: fmt.Sprintf("milestone-%d", time.Now().UnixNano()), + TaskID: taskID, + Type: CheckpointTypeMilestone, + NodeID: nodeID, + Progress: progress, + PartialResult: result, + Metadata: metadata, + Timestamp: time.Now().UnixMilli(), + } +} diff --git a/pkg/swarm/checkpoint_test.go b/pkg/swarm/checkpoint_test.go new file mode 100644 index 000000000..ebf8356fe --- /dev/null +++ b/pkg/swarm/checkpoint_test.go @@ -0,0 +1,261 @@ +// PicoClaw - 
Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewCheckpointStore(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + store, err := NewCheckpointStore(tn.JS()) + require.NoError(t, err) + assert.NotNil(t, store) + }) +} + +func TestCheckpointStore_SaveAndLoad(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + store, err := NewCheckpointStore(tn.JS()) + require.NoError(t, err) + + // Create a test checkpoint + cp := &TaskCheckpoint{ + CheckpointID: "cp-test-1", + TaskID: "task-1", + Type: CheckpointTypeProgress, + Timestamp: time.Now().UnixMilli(), + NodeID: "node-1", + Progress: 0.5, + State: map[string]interface{}{"step": 2}, + PartialResult: "Partial result", + } + + // Save checkpoint + err = store.SaveCheckpoint(ctx, cp) + require.NoError(t, err) + + // Load checkpoint + loaded, err := store.LoadCheckpoint(ctx, "task-1") + require.NoError(t, err) + assert.Equal(t, cp.CheckpointID, loaded.CheckpointID) + assert.Equal(t, cp.TaskID, loaded.TaskID) + assert.Equal(t, cp.Type, loaded.Type) + assert.Equal(t, cp.NodeID, loaded.NodeID) + assert.Equal(t, cp.Progress, loaded.Progress) + assert.Equal(t, cp.PartialResult, loaded.PartialResult) + }) +} + +func TestCheckpointStore_LoadByID(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + store, err := NewCheckpointStore(tn.JS()) + require.NoError(t, err) + + cp := &TaskCheckpoint{ + CheckpointID: "cp-test-2", + TaskID: "task-2", + Type: CheckpointTypeMilestone, + Timestamp: time.Now().UnixMilli(), + NodeID: "node-2", + Progress: 1.0, + PartialResult: "Complete", + } + + err = store.SaveCheckpoint(ctx, cp) + require.NoError(t, err) + + // Load by checkpoint ID + loaded, err := store.LoadCheckpointByID(ctx, "task-2", 
cp.CheckpointID) + require.NoError(t, err) + assert.NotNil(t, loaded) + assert.Equal(t, cp.CheckpointID, loaded.CheckpointID) + assert.Equal(t, cp.TaskID, loaded.TaskID) + }) +} + +func TestCheckpointStore_ListCheckpoints(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + store, err := NewCheckpointStore(tn.JS()) + require.NoError(t, err) + + // Create multiple checkpoints for the same task + for i := 0; i < 3; i++ { + cp := &TaskCheckpoint{ + CheckpointID: fmt.Sprintf("cp-list-%d", i), + TaskID: "task-list", + Type: CheckpointTypeProgress, + Timestamp: time.Now().Add(time.Duration(i) * time.Second).UnixMilli(), + NodeID: "node-1", + Progress: float64(i) * 0.3, + } + err = store.SaveCheckpoint(ctx, cp) + require.NoError(t, err) + } + + // List checkpoints + checkpoints, err := store.ListCheckpoints(ctx, "task-list") + require.NoError(t, err) + assert.Len(t, checkpoints, 3) + }) +} + +func TestCheckpointStore_PruneOldCheckpoints(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + store, err := NewCheckpointStore(tn.JS()) + require.NoError(t, err) + + // Create multiple checkpoints + for i := 0; i < 3; i++ { + cp := &TaskCheckpoint{ + CheckpointID: fmt.Sprintf("cp-prune-%d", i), + TaskID: "task-prune", + Type: CheckpointTypeProgress, + Timestamp: time.Now().Add(time.Duration(i) * time.Second).UnixMilli(), + NodeID: "node-1", + Progress: float64(i) * 0.3, + } + err = store.SaveCheckpoint(ctx, cp) + require.NoError(t, err) + } + + // Prune to keep only 1 most recent checkpoint + err = store.PruneOldCheckpoints(ctx, "task-prune", 1) + require.NoError(t, err) + + // Verify only one checkpoint remains + checkpoints, err := store.ListCheckpoints(ctx, "task-prune") + require.NoError(t, err) + assert.Len(t, checkpoints, 1) + }) +} + +func TestCheckpointStore_DeleteCheckpoint(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + store, err := 
NewCheckpointStore(tn.JS()) + require.NoError(t, err) + + cp := &TaskCheckpoint{ + CheckpointID: "cp-delete", + TaskID: "task-delete", + Type: CheckpointTypeProgress, + Timestamp: time.Now().UnixMilli(), + NodeID: "node-1", + Progress: 0.5, + } + + err = store.SaveCheckpoint(ctx, cp) + require.NoError(t, err) + + // Verify it exists + loaded, err := store.LoadCheckpoint(ctx, "task-delete") + require.NoError(t, err) + assert.NotNil(t, loaded) + + // Delete checkpoint + err = store.DeleteCheckpoint(ctx, "task-delete", cp.CheckpointID) + require.NoError(t, err) + + // Verify it's gone - LoadCheckpoint should return error or empty result + _, err = store.LoadCheckpoint(ctx, "task-delete") + // After deleting the only checkpoint, the result should be empty or error + assert.Error(t, err) + }) +} + +func TestCreateProgressCheckpoint(t *testing.T) { + taskID := "test-task-1" + nodeID := "node-1" + progress := 0.5 + partialResult := "Partial work done" + state := map[string]interface{}{ + "key1": "value1", + "key2": 42, + } + + cp := CreateProgressCheckpoint(taskID, nodeID, progress, partialResult, state) + + assert.NotEmpty(t, cp.CheckpointID) + assert.Equal(t, taskID, cp.TaskID) + assert.Equal(t, CheckpointTypeProgress, cp.Type) + assert.Equal(t, nodeID, cp.NodeID) + assert.Equal(t, progress, cp.Progress) + assert.Equal(t, partialResult, cp.PartialResult) + assert.Equal(t, state, cp.State) + assert.NotZero(t, cp.Timestamp) +} + +func TestCreateMilestoneCheckpoint(t *testing.T) { + taskID := "test-task-2" + nodeID := "node-2" + progress := 1.0 + result := "Work completed" + metadata := map[string]string{ + "milestone": "first", + "category": "test", + } + + cp := CreateMilestoneCheckpoint(taskID, nodeID, progress, result, metadata) + + assert.NotEmpty(t, cp.CheckpointID) + assert.Contains(t, cp.CheckpointID, "milestone-") + assert.Equal(t, taskID, cp.TaskID) + assert.Equal(t, CheckpointTypeMilestone, cp.Type) + assert.Equal(t, nodeID, cp.NodeID) + assert.Equal(t, 
progress, cp.Progress) + assert.Equal(t, result, cp.PartialResult) + assert.Equal(t, metadata, cp.Metadata) + assert.NotZero(t, cp.Timestamp) +} + +func TestCheckpointTypes(t *testing.T) { + types := []CheckpointType{ + CheckpointTypeProgress, + CheckpointTypeMilestone, + CheckpointTypePreFailover, + CheckpointTypeUserCheckpointType, + } + + for _, ct := range types { + assert.NotEmpty(t, string(ct), "Checkpoint type should not be empty") + } +} + +func TestTaskCheckpoint(t *testing.T) { + cp := &TaskCheckpoint{ + CheckpointID: "cp-test-1", + TaskID: "task-1", + Type: CheckpointTypeProgress, + Timestamp: time.Now().UnixMilli(), + NodeID: "node-1", + Progress: 0.75, + State: map[string]interface{}{"step": 3}, + PartialResult: "Partial result", + Context: map[string]interface{}{"messages": []string{"msg1", "msg2"}}, + Metadata: map[string]string{"key": "value"}, + } + + assert.Equal(t, "cp-test-1", cp.CheckpointID) + assert.Equal(t, "task-1", cp.TaskID) + assert.Equal(t, CheckpointTypeProgress, cp.Type) + assert.Equal(t, "node-1", cp.NodeID) + assert.Equal(t, 0.75, cp.Progress) + assert.Equal(t, "Partial result", cp.PartialResult) + assert.NotNil(t, cp.State) + assert.NotNil(t, cp.Context) + assert.NotNil(t, cp.Metadata) +} diff --git a/pkg/swarm/context_pool.go b/pkg/swarm/context_pool.go new file mode 100644 index 000000000..b1441b790 --- /dev/null +++ b/pkg/swarm/context_pool.go @@ -0,0 +1,558 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/nats-io/nats.go" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// ContextPool manages shared context for swarm tasks +// It uses NATS JetStream KV store for distributed context sharing +type ContextPool struct { + js nats.JetStreamContext + kv nats.KeyValue + bucket string + nodeID string + hid string // Human ID (tenant/cluster identity) + sid string // 
Service ID (instance identity) + mu sync.RWMutex + running bool +} + +// ContextEntry represents a single context entry +type ContextEntry struct { + Key string `json:"key"` + Value interface{} `json:"value"` + Type string `json:"type"` // "string", "number", "boolean", "object", "array" + Timestamp int64 `json:"timestamp"` + NodeID string `json:"node_id"` + ExpiresAt int64 `json:"expires_at,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// TaskContext represents all context for a specific task +type TaskContext struct { + TaskID string `json:"task_id"` + WorkflowID string `json:"workflow_id,omitempty"` + ParentTaskID string `json:"parent_task_id,omitempty"` + Entries map[string]*ContextEntry `json:"entries"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + CreatedBy string `json:"created_by"` + Permissions map[string]ContextPermission `json:"permissions,omitempty"` // H-id -> permission level +} + +// ContextPermission defines access level for context +type ContextPermission string + +const ( + PermRead ContextPermission = "read" + PermWrite ContextPermission = "write" + PermAdmin ContextPermission = "admin" +) + +const ( + // Default context bucket name + contextBucketName = "PICOCLAW_CONTEXT" + + // Default TTL for context entries (24 hours) + defaultContextTTL = 24 * time.Hour + + // Key prefix for task context + taskContextPrefix = "task:" +) + +// NewContextPool creates a new shared context pool +func NewContextPool(js nats.JetStreamContext, nodeID, hid, sid string) *ContextPool { + return &ContextPool{ + js: js, + nodeID: nodeID, + hid: hid, + sid: sid, + bucket: contextBucketName, + } +} + +// Start initializes the context pool KV store +func (cp *ContextPool) Start(ctx context.Context) error { + cp.mu.Lock() + defer cp.mu.Unlock() + + if cp.running { + return nil + } + + // Create or get KV bucket for context storage + kv, err := cp.js.CreateKeyValue(&nats.KeyValueConfig{ + Bucket: cp.bucket, + 
Description: "PicoClaw swarm shared context storage", + TTL: defaultContextTTL, + MaxBytes: 100 * 1024 * 1024, // 100MB default + Storage: nats.FileStorage, + Replicas: 1, + }) + if err != nil { + // Try to get existing bucket + kv, err = cp.js.KeyValue(cp.bucket) + if err != nil { + return fmt.Errorf("failed to create/get context KV store: %w", err) + } + } + + cp.kv = kv + cp.running = true + + logger.InfoCF("swarm", "Context pool started", map[string]interface{}{ + "bucket": cp.bucket, + "node_id": cp.nodeID, + }) + + return nil +} + +// Stop gracefully stops the context pool +func (cp *ContextPool) Stop() error { + cp.mu.Lock() + defer cp.mu.Unlock() + + cp.running = false + return nil +} + +// CreateTaskContext creates a new context for a task +func (cp *ContextPool) CreateTaskContext(taskID, workflowID, parentTaskID string) (*TaskContext, error) { + cp.mu.RLock() + defer cp.mu.RUnlock() + + if !cp.running { + return nil, fmt.Errorf("context pool not running") + } + + now := time.Now().UnixMilli() + + taskCtx := &TaskContext{ + TaskID: taskID, + WorkflowID: workflowID, + ParentTaskID: parentTaskID, + Entries: make(map[string]*ContextEntry), + CreatedAt: now, + UpdatedAt: now, + CreatedBy: cp.nodeID, + Permissions: make(map[string]ContextPermission), + } + + // Grant creator admin permissions + taskCtx.Permissions[cp.hid] = PermAdmin + + // Save to KV store + if err := cp.saveTaskContext(taskCtx); err != nil { + return nil, err + } + + logger.InfoCF("swarm", "Created task context", map[string]interface{}{ + "task_id": taskID, + "workflow_id": workflowID, + "created_by": cp.nodeID, + }) + + return taskCtx, nil +} + +// GetTaskContext retrieves context for a task +func (cp *ContextPool) GetTaskContext(taskID string) (*TaskContext, error) { + cp.mu.RLock() + defer cp.mu.RUnlock() + + if !cp.running { + return nil, fmt.Errorf("context pool not running") + } + + key := cp.taskContextKey(taskID) + entry, err := cp.kv.Get(key) + if err != nil { + if err == 
nats.ErrKeyNotFound { + // Return empty context instead of error + return &TaskContext{ + TaskID: taskID, + Entries: make(map[string]*ContextEntry), + CreatedAt: time.Now().UnixMilli(), + UpdatedAt: time.Now().UnixMilli(), + CreatedBy: cp.nodeID, + Permissions: make(map[string]ContextPermission), + }, nil + } + return nil, fmt.Errorf("failed to get task context: %w", err) + } + + var taskCtx TaskContext + if err := json.Unmarshal(entry.Value(), &taskCtx); err != nil { + return nil, fmt.Errorf("failed to unmarshal task context: %w", err) + } + + return &taskCtx, nil +} + +// SetEntry sets a context entry for a task +func (cp *ContextPool) SetEntry(taskID, key string, value interface{}) error { + cp.mu.RLock() + defer cp.mu.RUnlock() + + if !cp.running { + return fmt.Errorf("context pool not running") + } + + // Get existing context + taskCtx, err := cp.GetTaskContext(taskID) + if err != nil { + return err + } + + // Check write permission + if perm, ok := taskCtx.Permissions[cp.hid]; !ok || perm == PermRead { + return fmt.Errorf("no write permission for context %s", taskID) + } + + // Determine value type + var valueType string + switch value.(type) { + case string: + valueType = "string" + case int, int32, int64, float32, float64: + valueType = "number" + case bool: + valueType = "boolean" + case map[string]interface{}, []byte: + valueType = "object" + case []interface{}: + valueType = "array" + default: + valueType = "unknown" + } + + // Create/update entry + taskCtx.Entries[key] = &ContextEntry{ + Key: key, + Value: value, + Type: valueType, + Timestamp: time.Now().UnixMilli(), + NodeID: cp.nodeID, + } + taskCtx.UpdatedAt = time.Now().UnixMilli() + + // Save to KV store + return cp.saveTaskContext(taskCtx) +} + +// GetEntry gets a specific context entry for a task +func (cp *ContextPool) GetEntry(taskID, key string) (*ContextEntry, error) { + taskCtx, err := cp.GetTaskContext(taskID) + if err != nil { + return nil, err + } + + entry, ok := taskCtx.Entries[key] + 
if !ok { + return nil, fmt.Errorf("entry not found: %s", key) + } + + return entry, nil +} + +// GetAllEntries retrieves all entries for a task +func (cp *ContextPool) GetAllEntries(taskID string) (map[string]*ContextEntry, error) { + taskCtx, err := cp.GetTaskContext(taskID) + if err != nil { + return nil, err + } + + return taskCtx.Entries, nil +} + +// DeleteEntry deletes a context entry +func (cp *ContextPool) DeleteEntry(taskID, key string) error { + taskCtx, err := cp.GetTaskContext(taskID) + if err != nil { + return err + } + + // Check write permission + if perm, ok := taskCtx.Permissions[cp.hid]; !ok || perm == PermRead { + return fmt.Errorf("no write permission for context %s", taskID) + } + + delete(taskCtx.Entries, key) + taskCtx.UpdatedAt = time.Now().UnixMilli() + + return cp.saveTaskContext(taskCtx) +} + +// GrantPermission grants permission to another H-id +func (cp *ContextPool) GrantPermission(taskID, targetHID string, perm ContextPermission) error { + taskCtx, err := cp.GetTaskContext(taskID) + if err != nil { + return err + } + + // Only admin can grant permissions + if existingPerm, ok := taskCtx.Permissions[cp.hid]; !ok || existingPerm != PermAdmin { + return fmt.Errorf("no admin permission for context %s", taskID) + } + + taskCtx.Permissions[targetHID] = perm + taskCtx.UpdatedAt = time.Now().UnixMilli() + + return cp.saveTaskContext(taskCtx) +} + +// RevokePermission revokes permission from an H-id +func (cp *ContextPool) RevokePermission(taskID, targetHID string) error { + taskCtx, err := cp.GetTaskContext(taskID) + if err != nil { + return err + } + + // Only admin can revoke permissions + if existingPerm, ok := taskCtx.Permissions[cp.hid]; !ok || existingPerm != PermAdmin { + return fmt.Errorf("no admin permission for context %s", taskID) + } + + delete(taskCtx.Permissions, targetHID) + taskCtx.UpdatedAt = time.Now().UnixMilli() + + return cp.saveTaskContext(taskCtx) +} + +// MergeContext merges entries from parent task context +func (cp 
*ContextPool) MergeContext(taskID, parentTaskID string) error { + parentCtx, err := cp.GetTaskContext(parentTaskID) + if err != nil { + return err + } + + taskCtx, err := cp.GetTaskContext(taskID) + if err != nil { + return err + } + + // Merge entries from parent + for key, entry := range parentCtx.Entries { + // Only add if not already present + if _, exists := taskCtx.Entries[key]; !exists { + // Copy entry + copiedEntry := *entry + taskCtx.Entries[key] = &copiedEntry + } + } + + // Inherit permissions from parent + for hid, perm := range parentCtx.Permissions { + if _, exists := taskCtx.Permissions[hid]; !exists { + taskCtx.Permissions[hid] = perm + } + } + + taskCtx.ParentTaskID = parentTaskID + taskCtx.UpdatedAt = time.Now().UnixMilli() + + return cp.saveTaskContext(taskCtx) +} + +// DeleteTaskContext removes context for a task +func (cp *ContextPool) DeleteTaskContext(taskID string) error { + cp.mu.RLock() + defer cp.mu.RUnlock() + + if !cp.running { + return fmt.Errorf("context pool not running") + } + + key := cp.taskContextKey(taskID) + return cp.kv.Delete(key) +} + +// ListTaskContexts lists all task contexts (with optional filtering) +func (cp *ContextPool) ListTaskContexts(filter string) ([]*TaskContext, error) { + cp.mu.RLock() + defer cp.mu.RUnlock() + + if !cp.running { + return nil, fmt.Errorf("context pool not running") + } + + watcher, err := cp.kv.WatchAll() + if err != nil { + return nil, fmt.Errorf("failed to create watcher: %w", err) + } + defer watcher.Stop() + + var contexts []*TaskContext + + for entry := range watcher.Updates() { + if entry == nil { + break + } + + // Filter by task prefix + if len(entry.Key()) <= len(taskContextPrefix) || entry.Key()[:len(taskContextPrefix)+1] != taskContextPrefix { + continue + } + + // Apply additional filter if provided + if filter != "" && entry.Key() != cp.taskContextKey(filter) && entry.Key() != taskContextPrefix+filter { + continue + } + + var taskCtx TaskContext + if err := 
json.Unmarshal(entry.Value(), &taskCtx); err != nil { + continue + } + + contexts = append(contexts, &taskCtx) + } + + return contexts, nil +} + +// SetEntryWithTTL sets a context entry with TTL +func (cp *ContextPool) SetEntryWithTTL(taskID, key string, value interface{}, ttl time.Duration) error { + taskCtx, err := cp.GetTaskContext(taskID) + if err != nil { + return err + } + + // Check write permission + if perm, ok := taskCtx.Permissions[cp.hid]; !ok || perm == PermRead { + return fmt.Errorf("no write permission for context %s", taskID) + } + + // Determine value type + var valueType string + switch value.(type) { + case string: + valueType = "string" + case int, int32, int64, float32, float64: + valueType = "number" + case bool: + valueType = "boolean" + case map[string]interface{}, []byte: + valueType = "object" + case []interface{}: + valueType = "array" + default: + valueType = "unknown" + } + + now := time.Now().UnixMilli() + + // Create/update entry with TTL + taskCtx.Entries[key] = &ContextEntry{ + Key: key, + Value: value, + Type: valueType, + Timestamp: now, + NodeID: cp.nodeID, + ExpiresAt: now + ttl.Milliseconds(), + } + taskCtx.UpdatedAt = now + + return cp.saveTaskContext(taskCtx) +} + +// CleanExpiredEntries removes expired entries from a task context +func (cp *ContextPool) CleanExpiredEntries(taskID string) (int, error) { + taskCtx, err := cp.GetTaskContext(taskID) + if err != nil { + return 0, err + } + + now := time.Now().UnixMilli() + removed := 0 + + for key, entry := range taskCtx.Entries { + if entry.ExpiresAt > 0 && entry.ExpiresAt < now { + delete(taskCtx.Entries, key) + removed++ + } + } + + if removed > 0 { + taskCtx.UpdatedAt = now + if err := cp.saveTaskContext(taskCtx); err != nil { + return 0, err + } + } + + return removed, nil +} + +// saveTaskContext saves task context to KV store +func (cp *ContextPool) saveTaskContext(taskCtx *TaskContext) error { + data, err := json.Marshal(taskCtx) + if err != nil { + return 
fmt.Errorf("failed to marshal task context: %w", err) + } + + key := cp.taskContextKey(taskCtx.TaskID) + + // Use Update with CreateIfMissing for safety + _, err = cp.kv.Put(key, data) + if err != nil { + return fmt.Errorf("failed to save task context: %w", err) + } + + return nil +} + +// taskContextKey returns the KV store key for a task context +func (cp *ContextPool) taskContextKey(taskID string) string { + return taskContextPrefix + taskID +} + +// GetContextForPrompt returns a formatted string of context entries for use in prompts +func (cp *ContextPool) GetContextForPrompt(taskID string) (string, error) { + entries, err := cp.GetAllEntries(taskID) + if err != nil { + return "", err + } + + if len(entries) == 0 { + return "", nil + } + + var result string + result = fmt.Sprintf("[Shared Context for task %s]\n", taskID) + + for key, entry := range entries { + // Skip expired entries + if entry.ExpiresAt > 0 && entry.ExpiresAt < time.Now().UnixMilli() { + continue + } + + result += fmt.Sprintf("- %s: ", key) + switch v := entry.Value.(type) { + case string: + result += fmt.Sprintf("%q", v) + case map[string]interface{}: + jsonBytes, _ := json.Marshal(v) + result += string(jsonBytes) + case []interface{}: + jsonBytes, _ := json.Marshal(v) + result += string(jsonBytes) + default: + result += fmt.Sprintf("%v", v) + } + result += "\n" + } + + return result, nil +} diff --git a/pkg/swarm/coordinator.go b/pkg/swarm/coordinator.go new file mode 100644 index 000000000..2034a499b --- /dev/null +++ b/pkg/swarm/coordinator.go @@ -0,0 +1,370 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/google/uuid" + "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// 
Coordinator orchestrates task distribution across the swarm +type Coordinator struct { + bridge *NATSBridge + temporal *TemporalClient + discovery *Discovery + agentLoop *agent.AgentLoop + provider providers.LLMProvider + cfg *config.SwarmConfig + localBus *bus.MessageBus + pendingTasks map[string]*SwarmTask + taskResults map[string]chan *TaskResult + mu sync.RWMutex +} + +// NewCoordinator creates a new coordinator +func NewCoordinator( + cfg *config.SwarmConfig, + bridge *NATSBridge, + temporal *TemporalClient, + discovery *Discovery, + agentLoop *agent.AgentLoop, + provider providers.LLMProvider, + localBus *bus.MessageBus, +) *Coordinator { + return &Coordinator{ + bridge: bridge, + temporal: temporal, + discovery: discovery, + agentLoop: agentLoop, + provider: provider, + cfg: cfg, + localBus: localBus, + pendingTasks: make(map[string]*SwarmTask), + taskResults: make(map[string]chan *TaskResult), + } +} + +// Start begins the coordinator +func (c *Coordinator) Start(ctx context.Context) error { + logger.InfoC("swarm", "Coordinator starting") + + // Listen for inbound messages from local bus and dispatch to swarm + go c.processInboundMessages(ctx) + + return nil +} + +// Stop stops the coordinator +func (c *Coordinator) Stop() { + c.mu.Lock() + defer c.mu.Unlock() + + // Close all pending result channels + for _, ch := range c.taskResults { + close(ch) + } + c.taskResults = make(map[string]chan *TaskResult) + c.pendingTasks = make(map[string]*SwarmTask) +} + +// DispatchTask sends a task to the swarm +func (c *Coordinator) DispatchTask(ctx context.Context, task *SwarmTask) (*TaskResult, error) { + // Assign task ID if not set + if task.ID == "" { + task.ID = fmt.Sprintf("task-%s", uuid.New().String()[:8]) + } + + logger.InfoCF("swarm", "Dispatching task", map[string]interface{}{ + "task_id": task.ID, + "type": string(task.Type), + "capability": task.Capability, + }) + + switch task.Type { + case TaskTypeWorkflow: + return c.dispatchWorkflow(ctx, task) + case 
TaskTypeDirect: + return c.dispatchDirect(ctx, task) + case TaskTypeBroadcast: + return c.dispatchBroadcast(ctx, task) + default: + return nil, fmt.Errorf("unknown task type: %s", task.Type) + } +} + +func (c *Coordinator) dispatchWorkflow(ctx context.Context, task *SwarmTask) (*TaskResult, error) { + if !c.temporal.IsConnected() { + // Fall back to direct dispatch + logger.WarnC("swarm", "Temporal not connected, falling back to direct dispatch") + task.Type = TaskTypeDirect + return c.dispatchDirect(ctx, task) + } + + // Start Temporal workflow + workflowID, err := c.temporal.StartWorkflow(ctx, "SwarmWorkflow", task) + if err != nil { + return nil, err + } + task.WorkflowID = workflowID + + // Wait for result + result, err := c.temporal.GetWorkflowResult(ctx, workflowID) + if err != nil { + return &TaskResult{ + TaskID: task.ID, + Status: string(TaskFailed), + Error: err.Error(), + CompletedAt: time.Now().UnixMilli(), + }, nil + } + + return &TaskResult{ + TaskID: task.ID, + Status: string(TaskDone), + Result: result, + CompletedAt: time.Now().UnixMilli(), + }, nil +} + +func (c *Coordinator) dispatchDirect(ctx context.Context, task *SwarmTask) (*TaskResult, error) { + // Find best worker with priority consideration + if task.AssignedTo == "" { + worker := c.discovery.SelectWorkerWithPriority(task.Capability, task.Priority) + if worker == nil { + // No remote worker available, execute locally + logger.InfoC("swarm", "No remote workers, executing locally") + return c.executeLocally(ctx, task) + } + task.AssignedTo = worker.ID + + logger.DebugCF("swarm", "Selected worker for task", map[string]interface{}{ + "task_id": task.ID, + "worker_id": worker.ID, + "priority": task.Priority, + "load": worker.Load, + }) + } + + // Create result channel + resultCh := make(chan *TaskResult, 1) + c.mu.Lock() + c.taskResults[task.ID] = resultCh + c.pendingTasks[task.ID] = task + c.mu.Unlock() + + defer func() { + c.mu.Lock() + delete(c.taskResults, task.ID) + delete(c.pendingTasks, 
task.ID) + c.mu.Unlock() + }() + + // Subscribe to result + sub, err := c.bridge.SubscribeTaskResult(task.ID, func(result *TaskResult) { + select { + case resultCh <- result: + default: + } + }) + if err != nil { + return nil, err + } + defer sub.Unsubscribe() + + // Publish task + if err := c.bridge.PublishTask(task); err != nil { + return nil, err + } + + // Wait for result with timeout + timeout := time.Duration(task.Timeout) * time.Millisecond + if timeout == 0 { + timeout = 10 * time.Minute + } + + select { + case result := <-resultCh: + return result, nil + case <-time.After(timeout): + return &TaskResult{ + TaskID: task.ID, + Status: string(TaskFailed), + Error: "task timeout", + CompletedAt: time.Now().UnixMilli(), + }, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (c *Coordinator) dispatchBroadcast(ctx context.Context, task *SwarmTask) (*TaskResult, error) { + // Same as direct but without specific assignment - NATS queue group handles distribution + task.AssignedTo = "" // Clear any assignment + return c.dispatchDirect(ctx, task) +} + +func (c *Coordinator) executeLocally(ctx context.Context, task *SwarmTask) (*TaskResult, error) { + result, err := c.agentLoop.ProcessDirect(ctx, task.Prompt, "swarm:"+task.ID) + + taskResult := &TaskResult{ + TaskID: task.ID, + CompletedAt: time.Now().UnixMilli(), + } + + if err != nil { + taskResult.Status = string(TaskFailed) + taskResult.Error = err.Error() + } else { + taskResult.Status = string(TaskDone) + taskResult.Result = result + } + + return taskResult, nil +} + +func (c *Coordinator) processInboundMessages(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + default: + msg, ok := c.localBus.ConsumeInbound(ctx) + if !ok { + continue + } + + // Analyze message complexity to decide routing + task := c.analyzeAndCreateTask(ctx, msg) + if task == nil { + // Simple task - process locally by agent + go c.processLocally(ctx, msg) + continue + } + + // Complex task - dispatch to 
swarm + go c.dispatchWorkflowTask(ctx, task, msg) + } + } +} + +// processLocally handles simple tasks by forwarding them to the local agent +func (c *Coordinator) processLocally(ctx context.Context, msg bus.InboundMessage) { + response, err := c.agentLoop.ProcessInboundMessage(ctx, msg) + if err != nil { + response = fmt.Sprintf("Error processing message: %v", err) + } + + if response != "" { + c.localBus.PublishOutbound(bus.OutboundMessage{ + Channel: msg.Channel, + ChatID: msg.ChatID, + Content: response, + }) + } +} + +// dispatchWorkflowTask handles complex tasks by dispatching them to the swarm +func (c *Coordinator) dispatchWorkflowTask(ctx context.Context, task *SwarmTask, msg bus.InboundMessage) { + result, err := c.DispatchTask(ctx, task) + if err != nil { + logger.ErrorCF("swarm", "Task dispatch failed", map[string]interface{}{ + "error": err.Error(), + }) + return + } + + // Send result back to original channel + c.localBus.PublishOutbound(bus.OutboundMessage{ + Channel: msg.Channel, + ChatID: msg.ChatID, + Content: result.Result, + }) +} + +// analyzeAndCreateTask uses heuristics to decide if a task should be distributed +func (c *Coordinator) analyzeAndCreateTask(ctx context.Context, msg bus.InboundMessage) *SwarmTask { + content := msg.Content + + // Check for keywords that indicate workflow/decomposition is needed + workflowKeywords := []string{ + "PARALLEL:", "parallel", "concurrent", + "同时", "分别", "一起", + "analyze all", "compare", "summarize", + "汇总", "分别", "列出", + } + + shouldUseWorkflow := false + for _, keyword := range workflowKeywords { + if contains(content, keyword) { + shouldUseWorkflow = true + break + } + } + + if !shouldUseWorkflow { + // Simple task - process locally + return nil + } + + // Create workflow task for decomposition + task := &SwarmTask{ + ID: fmt.Sprintf("task-%d", time.Now().UnixNano()), + Prompt: content, + Type: TaskTypeWorkflow, + Capability: "general", + Priority: 5, + Status: TaskPending, + CreatedAt: 
time.Now().UnixMilli(), + Timeout: 300000, // 5 minutes + } + + logger.InfoCF("swarm", "Created workflow task from message", map[string]interface{}{ + "task_id": task.ID, + "prompt": truncateString(content, 50), + }) + + return task +} + +// contains checks if a string contains a substring (case-insensitive) +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && ( + // Simple case-insensitive contains + toLower(s) == toLower(substr) || + findSubstring(toLower(s), toLower(substr)))) +} + +func toLower(s string) string { + // Simple ASCII lowercase + result := make([]byte, len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + if c >= 'A' && c <= 'Z' { + result[i] = c + 32 + } else { + result[i] = c + } + } + return string(result) +} + +func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/pkg/swarm/coordinator_test.go b/pkg/swarm/coordinator_test.go new file mode 100644 index 000000000..ef2c9de99 --- /dev/null +++ b/pkg/swarm/coordinator_test.go @@ -0,0 +1,290 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestCoordinator_DispatchDirect(t *testing.T) { + _, url, cleanup := startTestNATS(t) + defer cleanup() + + tests := []struct { + name string + workerID string + task *SwarmTask + }{ + { + name: "task dispatched to discovered worker", + workerID: "coord-worker-1", + task: &SwarmTask{ + ID: "task-coord001", + Type: TaskTypeDirect, + Capability: "code", + Prompt: "test task", + Status: TaskPending, + Timeout: 5000, + }, + }, + { + name: "task with specific AssignedTo", + workerID: "coord-worker-2", + task: &SwarmTask{ + ID: "task-coord002", + 
Type: TaskTypeDirect, + Capability: "code", + Prompt: "specific assignment", + Status: TaskPending, + AssignedTo: "coord-worker-2", + Timeout: 5000, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up worker bridge that will receive and respond + workerNode := newTestNodeInfo(tt.workerID, RoleWorker, []string{"code"}, 4) + workerBridge := connectTestBridge(t, url, workerNode) + defer workerBridge.Stop() + + workerBridge.SetOnTaskReceived(func(task *SwarmTask) { + // Simulate worker execution: publish result + result := &TaskResult{ + TaskID: task.ID, + NodeID: tt.workerID, + Status: string(TaskDone), + Result: "executed: " + task.Prompt, + CompletedAt: time.Now().UnixMilli(), + } + workerBridge.PublishTaskResult(result) + }) + + if err := workerBridge.Start(context.Background()); err != nil { + t.Fatalf("worker Start() error: %v", err) + } + + // Set up coordinator + coordNode := newTestNodeInfo("coord-main", RoleCoordinator, nil, 1) + coordBridge := connectTestBridge(t, url, coordNode) + defer coordBridge.Stop() + + if err := coordBridge.Start(context.Background()); err != nil { + t.Fatalf("coordinator bridge Start() error: %v", err) + } + + swarmCfg := newTestSwarmConfig(0) + discovery := NewDiscovery(coordBridge, coordNode, swarmCfg) + // Register the worker in discovery + discovery.handleNodeJoin(workerNode) + + temporal := NewTemporalClient(&config.TemporalConfig{TaskQueue: "test"}) + agentLoop := newTestAgentLoop(t, "local result", nil) + localBus := bus.NewMessageBus() + + coordinator := NewCoordinator(swarmCfg, coordBridge, temporal, discovery, agentLoop, &mockLLMProvider{}, localBus) + + // Give subscriptions time to propagate + time.Sleep(50 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + result, err := coordinator.DispatchTask(ctx, tt.task) + if err != nil { + t.Fatalf("DispatchTask() error: %v", err) + } + if result == nil { + 
t.Fatal("DispatchTask() returned nil result") + } + if result.Status != string(TaskDone) { + t.Errorf("Status = %q, want %q", result.Status, string(TaskDone)) + } + if !strings.Contains(result.Result, tt.task.Prompt) { + t.Errorf("Result = %q, want it to contain %q", result.Result, tt.task.Prompt) + } + }) + } +} + +func TestCoordinator_DispatchNoWorkers(t *testing.T) { + _, url, cleanup := startTestNATS(t) + defer cleanup() + + tests := []struct { + name string + chatResponse string + chatErr error + wantStatus string + wantContains string // check result or error contains this + }{ + { + name: "local fallback success", + chatResponse: "local execution result", + chatErr: nil, + wantStatus: string(TaskDone), + wantContains: "local execution result", + }, + { + name: "local fallback on error", + chatResponse: "", + chatErr: fmt.Errorf("LLM unavailable"), + wantStatus: string(TaskFailed), + wantContains: "LLM unavailable", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + coordNode := newTestNodeInfo("coord-noworker", RoleCoordinator, nil, 1) + coordBridge := connectTestBridge(t, url, coordNode) + defer coordBridge.Stop() + + if err := coordBridge.Start(context.Background()); err != nil { + t.Fatalf("Start() error: %v", err) + } + + swarmCfg := newTestSwarmConfig(0) + discovery := NewDiscovery(coordBridge, coordNode, swarmCfg) + // No workers registered -- empty discovery + + temporal := NewTemporalClient(&config.TemporalConfig{TaskQueue: "test"}) + agentLoop := newTestAgentLoop(t, tt.chatResponse, tt.chatErr) + localBus := bus.NewMessageBus() + + coordinator := NewCoordinator(swarmCfg, coordBridge, temporal, discovery, agentLoop, &mockLLMProvider{}, localBus) + + task := &SwarmTask{ + ID: "task-local001", + Type: TaskTypeDirect, + Capability: "code", + Prompt: "test prompt", + Status: TaskPending, + Timeout: 5000, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + result, err := 
coordinator.DispatchTask(ctx, task) + if err != nil { + t.Fatalf("DispatchTask() error: %v", err) + } + if result == nil { + t.Fatal("DispatchTask() returned nil result") + } + if result.Status != tt.wantStatus { + t.Errorf("Status = %q, want %q", result.Status, tt.wantStatus) + } + // Check either Result or Error field + combined := result.Result + result.Error + if !strings.Contains(combined, tt.wantContains) { + t.Errorf("Result+Error = %q, want it to contain %q", combined, tt.wantContains) + } + }) + } +} + +func TestCoordinator_TaskTimeout(t *testing.T) { + _, url, cleanup := startTestNATS(t) + defer cleanup() + + // Set up a worker that never responds + workerNode := newTestNodeInfo("timeout-worker", RoleWorker, []string{"code"}, 4) + workerBridge := connectTestBridge(t, url, workerNode) + defer workerBridge.Stop() + + // Intentionally do NOT set onTaskReceived - worker never processes + if err := workerBridge.Start(context.Background()); err != nil { + t.Fatalf("worker Start() error: %v", err) + } + + coordNode := newTestNodeInfo("timeout-coord", RoleCoordinator, nil, 1) + coordBridge := connectTestBridge(t, url, coordNode) + defer coordBridge.Stop() + + if err := coordBridge.Start(context.Background()); err != nil { + t.Fatalf("coord Start() error: %v", err) + } + + swarmCfg := newTestSwarmConfig(0) + discovery := NewDiscovery(coordBridge, coordNode, swarmCfg) + discovery.handleNodeJoin(workerNode) + + temporal := NewTemporalClient(&config.TemporalConfig{TaskQueue: "test"}) + agentLoop := newTestAgentLoop(t, "unused", nil) + localBus := bus.NewMessageBus() + + coordinator := NewCoordinator(swarmCfg, coordBridge, temporal, discovery, agentLoop, &mockLLMProvider{}, localBus) + + time.Sleep(50 * time.Millisecond) + + task := &SwarmTask{ + ID: "task-timeout1", + Type: TaskTypeDirect, + Capability: "code", + Prompt: "will timeout", + Status: TaskPending, + Timeout: 100, // 100ms -- very short + } + + ctx, cancel := context.WithTimeout(context.Background(), 
5*time.Second) + defer cancel() + + result, err := coordinator.DispatchTask(ctx, task) + if err != nil { + t.Fatalf("DispatchTask() error: %v", err) + } + if result == nil { + t.Fatal("DispatchTask() returned nil result") + } + if result.Status != string(TaskFailed) { + t.Errorf("Status = %q, want %q", result.Status, string(TaskFailed)) + } + if !strings.Contains(result.Error, "timeout") { + t.Errorf("Error = %q, want it to contain 'timeout'", result.Error) + } +} + +func TestCoordinator_UnknownTaskType(t *testing.T) { + _, url, cleanup := startTestNATS(t) + defer cleanup() + + coordNode := newTestNodeInfo("unknown-coord", RoleCoordinator, nil, 1) + coordBridge := connectTestBridge(t, url, coordNode) + defer coordBridge.Stop() + + swarmCfg := newTestSwarmConfig(0) + discovery := NewDiscovery(coordBridge, coordNode, swarmCfg) + temporal := NewTemporalClient(&config.TemporalConfig{TaskQueue: "test"}) + agentLoop := newTestAgentLoop(t, "unused", nil) + localBus := bus.NewMessageBus() + + coordinator := NewCoordinator(swarmCfg, coordBridge, temporal, discovery, agentLoop, &mockLLMProvider{}, localBus) + + task := &SwarmTask{ + ID: "task-unknown01", + Type: SwarmTaskType("invalid"), + Prompt: "should fail", + } + + ctx := context.Background() + _, err := coordinator.DispatchTask(ctx, task) + if err == nil { + t.Fatal("DispatchTask() expected error for unknown task type, got nil") + } + if !strings.Contains(err.Error(), "unknown task type") { + t.Errorf("error = %q, want it to contain 'unknown task type'", err.Error()) + } +} diff --git a/pkg/swarm/cross_hid.go b/pkg/swarm/cross_hid.go new file mode 100644 index 000000000..ae4929346 --- /dev/null +++ b/pkg/swarm/cross_hid.go @@ -0,0 +1,441 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/google/uuid" + "github.com/nats-io/nats.go" + 
"github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/relation" +) + +// CrossHIDBridge manages cross H-id communication with authorization +type CrossHIDBridge struct { + bridge *NATSBridge + localHID string + authorizer *relation.Authorizer + mu sync.RWMutex + + // Exported H-ids (H-ids we allow to communicate with us) + exported map[string]bool + + // Imported H-ids (H-ids we allow ourselves to communicate with) + imported map[string]bool + + // Subscription handlers + handlers map[string]func(*CrossHIDMessage) +} + +// CrossHIDMessage represents a message sent between H-ids +type CrossHIDMessage struct { + // FromHID is the sender's H-id + FromHID string `json:"from_hid"` + + // FromSID is the sender's S-id + FromSID string `json:"from_sid"` + + // ToHID is the recipient's H-id + ToHID string `json:"to_hid"` + + // Type is the message type + Type string `json:"type"` + + // Payload is the message payload + Payload map[string]interface{} `json:"payload"` + + // Timestamp is when the message was sent + Timestamp int64 `json:"timestamp"` + + // ID is a unique message identifier + ID string `json:"id"` +} + +// NewCrossHIDBridge creates a new cross H-id communication bridge +func NewCrossHIDBridge(bridge *NATSBridge, localHID string, authorizer *relation.Authorizer) *CrossHIDBridge { + return &CrossHIDBridge{ + bridge: bridge, + localHID: localHID, + authorizer: authorizer, + exported: make(map[string]bool), + imported: make(map[string]bool), + handlers: make(map[string]func(*CrossHIDMessage)), + } +} + +// Start initializes the cross H-id bridge +func (b *CrossHIDBridge) Start(ctx context.Context) error { + // Subscribe to cross-domain messages for our H-id + subject := fmt.Sprintf("picoclaw.x.*.%s.>", b.localHID) + + sub, err := b.bridge.conn.Subscribe(subject, b.handleIncomingMessage) + if err != nil { + return fmt.Errorf("failed to subscribe to cross H-id messages: %w", err) + } + + logger.InfoCF("swarm", "Cross H-id bridge started", 
map[string]interface{}{
+		"subject": subject,
+	})
+
+	// Keep subscription alive
+	go func() {
+		<-ctx.Done()
+		sub.Unsubscribe()
+	}()
+
+	return nil
+}
+
+// Export allows another H-id to send messages to us
+func (b *CrossHIDBridge) Export(hid string) error {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	b.exported[hid] = true
+
+	logger.InfoCF("swarm", "Exported H-id", map[string]interface{}{
+		"hid": hid,
+	})
+
+	return nil
+}
+
+// Import allows us to send messages to another H-id
+func (b *CrossHIDBridge) Import(hid string) error {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	b.imported[hid] = true
+
+	logger.InfoCF("swarm", "Imported H-id", map[string]interface{}{
+		"hid": hid,
+	})
+
+	return nil
+}
+
+// RevokeExport removes an export
+func (b *CrossHIDBridge) RevokeExport(hid string) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	delete(b.exported, hid)
+}
+
+// RevokeImport removes an import
+func (b *CrossHIDBridge) RevokeImport(hid string) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	delete(b.imported, hid)
+}
+
+// Send sends a message to another H-id. The destination must be imported,
+// either explicitly or via the "*" wildcard installed by ApplyConfig.
+func (b *CrossHIDBridge) Send(ctx context.Context, toHID, messageType string, payload map[string]interface{}) error {
+	b.mu.RLock()
+	// BUG FIX: honor the "*" wildcard entry that ApplyConfig installs when
+	// DefaultImportPolicy is "allow"; previously only an exact H-id entry
+	// passed this check, so the wildcard policy had no effect.
+	isImported := b.imported[toHID] || b.imported["*"]
+	b.mu.RUnlock()
+
+	// Check if we're allowed to send to this H-id
+	if !isImported {
+		return fmt.Errorf("H-id %s is not imported", toHID)
+	}
+
+	// NOTE(review): FromSID is left empty here; SendWithAuth authorizes the
+	// S-id but does not stamp it on the message — confirm whether receivers
+	// rely on FromSID being populated.
+	msg := &CrossHIDMessage{
+		FromHID:   b.localHID,
+		ToHID:     toHID,
+		Type:      messageType,
+		Payload:   payload,
+		Timestamp: currentTimeMillis(),
+		ID:        generateMessageID(),
+	}
+
+	data, err := json.Marshal(msg)
+	if err != nil {
+		return fmt.Errorf("failed to marshal message: %w", err)
+	}
+
+	subject := fmt.Sprintf("picoclaw.x.%s.%s.%s", b.localHID, toHID, messageType)
+
+	if err := b.bridge.conn.Publish(subject, data); err != nil {
+		return fmt.Errorf("failed to publish message: %w", err)
+	}
+
+	logger.DebugCF("swarm", "Sent cross H-id message", map[string]interface{}{
+		"to_hid": toHID,
+		"type":   messageType,
+		"msg_id": msg.ID,
+	})
+
+	return nil
+}
+
+// SendWithAuth sends a message with authorization check
+func (b *CrossHIDBridge) SendWithAuth(ctx context.Context, fromSID, toHID, messageType string, payload map[string]interface{}) error {
+	// Check authorization using relation system
+	resource := relation.NewResourceID(relation.ResourceNode, toHID)
+
+	authzReq := &relation.AuthzRequest{
+		SubjectHID: b.localHID,
+		SubjectSID: fromSID,
+		Action:     relation.ActionRead, // Use read as "communicate with"
+		Resource:   resource,
+	}
+
+	result := b.authorizer.Authorize(authzReq)
+	if !result.Allowed {
+		return fmt.Errorf("authorization denied: %s", result.Reason)
+	}
+
+	return b.Send(ctx, toHID, messageType, payload)
+}
+
+// RegisterHandler registers a handler for a specific message type
+func (b *CrossHIDBridge) RegisterHandler(messageType string, handler func(*CrossHIDMessage)) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	b.handlers[messageType] = handler
+}
+
+// UnregisterHandler removes a handler
+func (b *CrossHIDBridge) UnregisterHandler(messageType string) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	delete(b.handlers, messageType)
+}
+
+// handleIncomingMessage handles incoming cross H-id messages: it decodes
+// the payload, drops messages not addressed to us or from non-exported
+// senders, and dispatches to the registered per-type handler.
+func (b *CrossHIDBridge) handleIncomingMessage(msg *nats.Msg) {
+	var message CrossHIDMessage
+	if err := json.Unmarshal(msg.Data, &message); err != nil {
+		logger.WarnCF("swarm", "Failed to unmarshal cross H-id message", map[string]interface{}{
+			"error": err.Error(),
+		})
+		return
+	}
+
+	// Verify the message is for us
+	if message.ToHID != b.localHID {
+		logger.DebugCF("swarm", "Ignoring cross H-id message for different H-id", map[string]interface{}{
+			"to_hid":    message.ToHID,
+			"local_hid": b.localHID,
+		})
+		return
+	}
+
+	b.mu.RLock()
+	// BUG FIX: also accept senders covered by the "*" wildcard export set
+	// by ApplyConfig; exact-key lookup alone made the wildcard inert.
+	isExported := b.exported[message.FromHID] || b.exported["*"]
+	handler, hasHandler := b.handlers[message.Type]
+	b.mu.RUnlock()
+
+	// Check if the sender is exported (allowed to send to us)
+	if !isExported {
+		logger.WarnCF("swarm", "Rejected cross H-id message from non-exported H-id", map[string]interface{}{
+			"from_hid": message.FromHID,
+		})
+		return
+	}
+
+	logger.DebugCF("swarm", "Received cross H-id message", map[string]interface{}{
+		"from_hid": message.FromHID,
+		"type":     message.Type,
+		"msg_id":   message.ID,
+	})
+
+	// Call handler if registered
+	if hasHandler && handler != nil {
+		handler(&message)
+	}
+}
+
+// GetExported returns all exported H-ids
+func (b *CrossHIDBridge) GetExported() []string {
+	b.mu.RLock()
+	defer b.mu.RUnlock()
+
+	hids := make([]string, 0, len(b.exported))
+	for hid := range b.exported {
+		hids = append(hids, hid)
+	}
+	return hids
+}
+
+// GetImported returns all imported H-ids
+func (b *CrossHIDBridge) GetImported() []string {
+	b.mu.RLock()
+	defer b.mu.RUnlock()
+
+	hids := make([]string, 0, len(b.imported))
+	for hid := range b.imported {
+		hids = append(hids, hid)
+	}
+	return hids
+}
+
+// IsExported checks if an H-id is exported (explicitly or via "*")
+func (b *CrossHIDBridge) IsExported(hid string) bool {
+	b.mu.RLock()
+	defer b.mu.RUnlock()
+	return b.exported[hid] || b.exported["*"]
+}
+
+// IsImported checks if an H-id is imported (explicitly or via "*")
+func (b *CrossHIDBridge) IsImported(hid string) bool {
+	b.mu.RLock()
+	defer b.mu.RUnlock()
+	return b.imported[hid] || b.imported["*"]
+}
+
+// currentTimeMillis returns the current time in milliseconds
+func currentTimeMillis() int64 {
+	return time.Now().UnixMilli()
+}
+
+// generateMessageID generates a unique message ID
+func generateMessageID() string {
+	return fmt.Sprintf("xmsg-%s", uuid.New().String()[:8])
+}
+
+// CrossHIDConfig contains configuration for cross H-id communication
+type CrossHIDConfig struct {
+	// DefaultExportPolicy determines the default export policy
+	DefaultExportPolicy string // "allow", "deny", "auth"
+
+	// DefaultImportPolicy determines the default import policy
+	DefaultImportPolicy string // "allow", "deny", "auth"
+
+	// ExportedHIDs is a list of H-ids to export to
+	ExportedHIDs []string
+
+	// ImportedHIDs is a list of H-ids to import from
+
ImportedHIDs []string +} + +// ApplyConfig applies a configuration to the bridge +func (b *CrossHIDBridge) ApplyConfig(cfg *CrossHIDConfig) error { + b.mu.Lock() + defer b.mu.Unlock() + + // Clear existing + b.exported = make(map[string]bool) + b.imported = make(map[string]bool) + + // Apply default policy + if cfg.DefaultExportPolicy == "allow" { + // Wildcard - all H-ids allowed (use with caution) + b.exported["*"] = true + } + + if cfg.DefaultImportPolicy == "allow" { + b.imported["*"] = true + } + + // Apply explicit exports + for _, hid := range cfg.ExportedHIDs { + b.exported[hid] = true + } + + // Apply explicit imports + for _, hid := range cfg.ImportedHIDs { + b.imported[hid] = true + } + + return nil +} + +// Message types for cross H-id communication +const ( + // MessageTypeTaskRequest is for requesting a task across H-ids + MessageTypeTaskRequest = "task.request" + + // MessageTypeTaskResponse is for responding to a task + MessageTypeTaskResponse = "task.response" + + // MessageTypeMemoryQuery is for querying memory across H-ids + MessageTypeMemoryQuery = "memory.query" + + // MessageTypeMemoryResponse is for responding to memory queries + MessageTypeMemoryResponse = "memory.response" + + // MessageTypeDiscovery is for discovering nodes across H-ids + MessageTypeDiscovery = "discovery" + + // MessageTypeHeartbeat is for heartbeat across H-ids + MessageTypeHeartbeat = "heartbeat" +) + +// TaskRequestPayload is the payload for a task request +type TaskRequestPayload struct { + TaskID string `json:"task_id"` + Prompt string `json:"prompt"` + Context map[string]interface{} `json:"context"` + Timeout int64 `json:"timeout"` +} + +// TaskResponsePayload is the payload for a task response +type TaskResponsePayload struct { + TaskID string `json:"task_id"` + Result string `json:"result"` + Error string `json:"error,omitempty"` + Completed bool `json:"completed"` +} + +// SendTaskRequest sends a task request to another H-id +func (b *CrossHIDBridge) 
SendTaskRequest(ctx context.Context, toHID string, payload *TaskRequestPayload) error { + msgPayload := map[string]interface{}{ + "task_id": payload.TaskID, + "prompt": payload.Prompt, + "context": payload.Context, + "timeout": payload.Timeout, + } + return b.Send(ctx, toHID, MessageTypeTaskRequest, msgPayload) +} + +// SendTaskResponse sends a task response to another H-id +func (b *CrossHIDBridge) SendTaskResponse(ctx context.Context, toHID string, payload *TaskResponsePayload) error { + msgPayload := map[string]interface{}{ + "task_id": payload.TaskID, + "result": payload.Result, + "error": payload.Error, + "completed": payload.Completed, + } + return b.Send(ctx, toHID, MessageTypeTaskResponse, msgPayload) +} + +// NewTaskRequest creates a new task request payload +func NewTaskRequest(taskID, prompt string) *TaskRequestPayload { + return &TaskRequestPayload{ + TaskID: taskID, + Prompt: prompt, + Context: make(map[string]interface{}), + Timeout: 600000, // 10 minutes default + } +} + +// NewTaskResponse creates a new task response payload +func NewTaskResponse(taskID, result string, completed bool) *TaskResponsePayload { + return &TaskResponsePayload{ + TaskID: taskID, + Result: result, + Completed: completed, + } +} + +// NewTaskResponseError creates a task response with an error +func NewTaskResponseError(taskID, errMsg string) *TaskResponsePayload { + return &TaskResponsePayload{ + TaskID: taskID, + Error: errMsg, + Completed: false, + } +} diff --git a/pkg/swarm/dag.go b/pkg/swarm/dag.go new file mode 100644 index 000000000..0c1d32034 --- /dev/null +++ b/pkg/swarm/dag.go @@ -0,0 +1,471 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// DAG represents a directed acyclic graph for workflow execution +type DAG struct { + nodes map[string]*DAGNode + edges map[string][]string // 
adjacency list: node -> dependents + mu sync.RWMutex +} + +// NewDAG creates a new empty DAG +func NewDAG() *DAG { + return &DAG{ + nodes: make(map[string]*DAGNode), + edges: make(map[string][]string), + } +} + +// AddNode adds a node to the DAG +func (d *DAG) AddNode(node *DAGNode) error { + d.mu.Lock() + defer d.mu.Unlock() + + if _, exists := d.nodes[node.ID]; exists { + return fmt.Errorf("node %s already exists", node.ID) + } + + node.Status = DAGNodePending + d.nodes[node.ID] = node + d.edges[node.ID] = []string{} + + return nil +} + +// AddDependency adds a dependency edge between nodes +func (d *DAG) AddDependency(from, to string) error { + d.mu.Lock() + defer d.mu.Unlock() + + if _, exists := d.nodes[from]; !exists { + return fmt.Errorf("node %s does not exist", from) + } + if _, exists := d.nodes[to]; !exists { + return fmt.Errorf("node %s does not exist", to) + } + + // Check for cycles + if d.wouldCreateCycle(from, to) { + return fmt.Errorf("adding edge %s -> %s would create a cycle", from, to) + } + + // Add edge and update node dependencies + d.edges[from] = append(d.edges[from], to) + d.nodes[to].Dependencies = append(d.nodes[to].Dependencies, from) + + return nil +} + +// wouldCreateCycle checks if adding an edge would create a cycle +func (d *DAG) wouldCreateCycle(from, to string) bool { + visited := make(map[string]bool) + return d.hasPathDFS(to, from, visited) +} + +// hasPathDFS performs DFS to check if a path exists +func (d *DAG) hasPathDFS(start, end string, visited map[string]bool) bool { + if start == end { + return true + } + visited[start] = true + + for _, neighbor := range d.edges[start] { + if !visited[neighbor] { + if d.hasPathDFS(neighbor, end, visited) { + return true + } + } + } + + return false +} + +// Validate checks if the DAG is valid (no cycles, all dependencies exist) +func (d *DAG) Validate() error { + d.mu.RLock() + defer d.mu.RUnlock() + + // Build adjacency list from both edges and Dependencies field + adjacency := 
make(map[string][]string) + for from, toList := range d.edges { + adjacency[from] = append(adjacency[from], toList...) + } + // Also check Dependencies fields for any additional edges + for nodeID, node := range d.nodes { + for _, depID := range node.Dependencies { + // Check if this creates a duplicate edge + exists := false + for _, existing := range adjacency[depID] { + if existing == nodeID { + exists = true + break + } + } + if !exists { + adjacency[depID] = append(adjacency[depID], nodeID) + } + } + } + + // Check for cycles using DFS + visited := make(map[string]bool) + recStack := make(map[string]bool) + + var checkCycleDFS func(string) bool + checkCycleDFS = func(nodeID string) bool { + visited[nodeID] = true + recStack[nodeID] = true + + for _, neighbor := range adjacency[nodeID] { + if !visited[neighbor] { + if checkCycleDFS(neighbor) { + return true + } + } else if recStack[neighbor] { + return true + } + } + + recStack[nodeID] = false + return false + } + + for nodeID := range d.nodes { + if !visited[nodeID] { + if checkCycleDFS(nodeID) { + return fmt.Errorf("cycle detected in DAG") + } + } + } + + // Validate all dependencies exist + for nodeID, node := range d.nodes { + for _, depID := range node.Dependencies { + if _, exists := d.nodes[depID]; !exists { + return fmt.Errorf("node %s depends on non-existent node %s", nodeID, depID) + } + } + } + + return nil +} + +// detectCycleDFS uses DFS with recursion stack to detect cycles +func (d *DAG) detectCycleDFS(nodeID string, visited, recStack map[string]bool) bool { + visited[nodeID] = true + recStack[nodeID] = true + + for _, neighbor := range d.edges[nodeID] { + if !visited[neighbor] { + if d.detectCycleDFS(neighbor, visited, recStack) { + return true + } + } else if recStack[neighbor] { + return true + } + } + + recStack[nodeID] = false + return false +} + +// GetReadyNodes returns all nodes that are ready to execute (dependencies satisfied) +func (d *DAG) GetReadyNodes() []*DAGNode { + d.mu.RLock() + 
defer d.mu.RUnlock() + + ready := make([]*DAGNode, 0) + + for _, node := range d.nodes { + if node.Status == DAGNodePending { + // Check if all dependencies are completed + allDepsComplete := true + for _, depID := range node.Dependencies { + if depNode, exists := d.nodes[depID]; !exists || depNode.Status != DAGNodeCompleted { + allDepsComplete = false + break + } + } + + if allDepsComplete && len(node.Dependencies) > 0 { + // Has dependencies and they're all complete + ready = append(ready, node) + } else if len(node.Dependencies) == 0 { + // No dependencies - root node + ready = append(ready, node) + } + } + } + + return ready +} + +// GetNode returns a node by ID +func (d *DAG) GetNode(id string) (*DAGNode, bool) { + d.mu.RLock() + defer d.mu.RUnlock() + + node, exists := d.nodes[id] + return node, exists +} + +// UpdateNodeStatus updates the status of a node +func (d *DAG) UpdateNodeStatus(id string, status DAGNodeStatus, result string, err error) { + d.mu.Lock() + defer d.mu.Unlock() + + if node, exists := d.nodes[id]; exists { + node.Status = status + node.Result = result + if err != nil { + node.Error = err.Error() + } + + if status == DAGNodeRunning { + node.StartedAt = time.Now().UnixMilli() + } else if status == DAGNodeCompleted || status == DAGNodeFailed { + node.CompletedAt = time.Now().UnixMilli() + } + } +} + +// IsComplete returns true if all nodes are completed, failed, or skipped +func (d *DAG) IsComplete() bool { + d.mu.RLock() + defer d.mu.RUnlock() + + for _, node := range d.nodes { + if node.Status == DAGNodePending || node.Status == DAGNodeReady || node.Status == DAGNodeRunning { + return false + } + } + + return true +} + +// HasFailed returns true if any node failed +func (d *DAG) HasFailed() bool { + d.mu.RLock() + defer d.mu.RUnlock() + + for _, node := range d.nodes { + if node.Status == DAGNodeFailed { + return true + } + } + + return false +} + +// GetCompletedNodes returns all completed nodes +func (d *DAG) GetCompletedNodes() 
[]*DAGNode { + d.mu.RLock() + defer d.mu.RUnlock() + + completed := make([]*DAGNode, 0) + for _, node := range d.nodes { + if node.Status == DAGNodeCompleted { + completed = append(completed, node) + } + } + + return completed +} + +// GetFailedNodes returns all failed nodes +func (d *DAG) GetFailedNodes() []*DAGNode { + d.mu.RLock() + defer d.mu.RUnlock() + + failed := make([]*DAGNode, 0) + for _, node := range d.nodes { + if node.Status == DAGNodeFailed { + failed = append(failed, node) + } + } + + return failed +} + +// NodeCount returns the total number of nodes +func (d *DAG) NodeCount() int { + d.mu.RLock() + defer d.mu.RUnlock() + + return len(d.nodes) +} + +// DAGExecutor executes DAG workflows with parallel execution support +type DAGExecutor struct { + dag *DAG + coordinator *Coordinator + maxParallel int + nodeResults map[string]string + mu sync.Mutex +} + +// NewDAGExecutor creates a new DAG executor +func NewDAGExecutor(dag *DAG, coordinator *Coordinator, maxParallel int) *DAGExecutor { + if maxParallel <= 0 { + maxParallel = 5 // Default parallelism + } + + return &DAGExecutor{ + dag: dag, + coordinator: coordinator, + maxParallel: maxParallel, + nodeResults: make(map[string]string), + } +} + +// Execute runs the DAG workflow +func (e *DAGExecutor) Execute(ctx context.Context) (map[string]string, error) { + logger.InfoCF("swarm", "Starting DAG execution", map[string]interface{}{ + "nodes": e.dag.NodeCount(), + "max_parallel": e.maxParallel, + }) + + // Validate DAG before execution + if err := e.dag.Validate(); err != nil { + return nil, fmt.Errorf("DAG validation failed: %w", err) + } + + // Track running nodes with semaphore + semaphore := make(chan struct{}, e.maxParallel) + var wg sync.WaitGroup + + // Continue until DAG is complete + for !e.dag.IsComplete() { + // Check for failures + if e.dag.HasFailed() { + return e.nodeResults, fmt.Errorf("DAG execution failed: some nodes failed") + } + + // Get ready nodes + readyNodes := 
e.dag.GetReadyNodes() + if len(readyNodes) == 0 { + // No ready nodes but DAG not complete - might be waiting for running nodes + if e.hasRunningNodes() { + time.Sleep(100 * time.Millisecond) + continue + } + // No running nodes and no ready nodes - should be complete + break + } + + // Execute ready nodes in parallel + for _, node := range readyNodes { + // Mark as running + e.dag.UpdateNodeStatus(node.ID, DAGNodeRunning, "", nil) + + wg.Add(1) + go func(n *DAGNode) { + defer wg.Done() + + // Acquire semaphore + semaphore <- struct{}{} + defer func() { <-semaphore }() + + // Execute the node's task + result, err := e.executeNode(ctx, n) + + e.mu.Lock() + e.nodeResults[n.ID] = result + e.mu.Unlock() + + if err != nil { + e.dag.UpdateNodeStatus(n.ID, DAGNodeFailed, "", err) + logger.WarnCF("swarm", "DAG node failed", map[string]interface{}{ + "node_id": n.ID, + "error": err.Error(), + }) + } else { + e.dag.UpdateNodeStatus(n.ID, DAGNodeCompleted, result, nil) + logger.DebugCF("swarm", "DAG node completed", map[string]interface{}{ + "node_id": n.ID, + }) + } + }(node) + } + + // Wait a bit before checking for more ready nodes + time.Sleep(50 * time.Millisecond) + } + + // Wait for all running nodes to complete + wg.Wait() + + if e.dag.HasFailed() { + return e.nodeResults, fmt.Errorf("DAG execution completed with failures") + } + + logger.InfoCF("swarm", "DAG execution completed successfully", map[string]interface{}{ + "nodes": len(e.nodeResults), + }) + + return e.nodeResults, nil +} + +// executeNode executes a single DAG node +func (e *DAGExecutor) executeNode(ctx context.Context, node *DAGNode) (string, error) { + // For now, execute through coordinator + // In a full implementation, this could dispatch to specialist workers + result, err := e.coordinator.executeLocally(ctx, node.Task) + if err != nil { + return "", err + } + return result.Result, nil +} + +// hasRunningNodes checks if any nodes are currently running +func (e *DAGExecutor) hasRunningNodes() bool { 
+ _ = e.dag.GetReadyNodes() // Check for any ready nodes + + // Check actual node status in DAG + e.dag.mu.RLock() + defer e.dag.mu.RUnlock() + + for _, node := range e.dag.nodes { + if node.Status == DAGNodeRunning { + return true + } + } + + return false +} + +// BuildDAGFromTasks creates a DAG from a list of tasks with dependencies +func BuildDAGFromTasks(tasks []*DAGNode) (*DAG, error) { + dag := NewDAG() + + // Add all nodes first + for _, task := range tasks { + if err := dag.AddNode(task); err != nil { + return nil, fmt.Errorf("failed to add node %s: %w", task.ID, err) + } + } + + // Add dependencies + for _, task := range tasks { + for _, dep := range task.Dependencies { + if err := dag.AddDependency(dep, task.ID); err != nil { + return nil, fmt.Errorf("failed to add dependency %s -> %s: %w", dep, task.ID, err) + } + } + } + + return dag, nil +} diff --git a/pkg/swarm/dag_test.go b/pkg/swarm/dag_test.go new file mode 100644 index 000000000..e5056d3d1 --- /dev/null +++ b/pkg/swarm/dag_test.go @@ -0,0 +1,245 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewDAG(t *testing.T) { + dag := NewDAG() + assert.NotNil(t, dag) + assert.NotNil(t, dag.nodes) + assert.NotNil(t, dag.edges) + assert.Equal(t, 0, dag.NodeCount()) +} + +func TestDAG_AddNode(t *testing.T) { + dag := NewDAG() + + node := &DAGNode{ + ID: "node-1", + Task: &SwarmTask{ + ID: "task-1", + Prompt: "Task 1", + }, + Status: DAGNodePending, + } + + err := dag.AddNode(node) + require.NoError(t, err) + assert.Equal(t, 1, dag.NodeCount()) + + // Try adding duplicate node + err = dag.AddNode(node) + assert.Error(t, err) + assert.Contains(t, err.Error(), "already exists") +} + +func TestDAG_AddDependency(t *testing.T) { + dag := NewDAG() + + node1 := &DAGNode{ID: "node-1", Task: &SwarmTask{ID: 
"task-1"}, Status: DAGNodePending} + node2 := &DAGNode{ID: "node-2", Task: &SwarmTask{ID: "task-2"}, Status: DAGNodePending} + + require.NoError(t, dag.AddNode(node1)) + require.NoError(t, dag.AddNode(node2)) + + // Add dependency: node-2 depends on node-1 + err := dag.AddDependency("node-1", "node-2") + require.NoError(t, err) + + // Check that dependency was added + retrieved, _ := dag.GetNode("node-2") + assert.Contains(t, retrieved.Dependencies, "node-1") +} + +func TestDAG_AddDependency_CycleDetection(t *testing.T) { + dag := NewDAG() + + node1 := &DAGNode{ID: "node-1", Task: &SwarmTask{ID: "task-1"}, Status: DAGNodePending} + node2 := &DAGNode{ID: "node-2", Task: &SwarmTask{ID: "task-2"}, Status: DAGNodePending} + node3 := &DAGNode{ID: "node-3", Task: &SwarmTask{ID: "task-3"}, Status: DAGNodePending} + + require.NoError(t, dag.AddNode(node1)) + require.NoError(t, dag.AddNode(node2)) + require.NoError(t, dag.AddNode(node3)) + + // Create dependencies: node-1 -> node-2 -> node-3 + require.NoError(t, dag.AddDependency("node-1", "node-2")) + require.NoError(t, dag.AddDependency("node-2", "node-3")) + + // Try to create a cycle: node-3 -> node-1 + err := dag.AddDependency("node-3", "node-1") + assert.Error(t, err) + assert.Contains(t, err.Error(), "cycle") +} + +func TestDAG_Validate(t *testing.T) { + t.Run("valid DAG", func(t *testing.T) { + dag := NewDAG() + + node1 := &DAGNode{ID: "node-1", Task: &SwarmTask{ID: "task-1"}, Status: DAGNodePending} + node2 := &DAGNode{ID: "node-2", Task: &SwarmTask{ID: "task-2"}, Status: DAGNodePending} + + require.NoError(t, dag.AddNode(node1)) + require.NoError(t, dag.AddNode(node2)) + require.NoError(t, dag.AddDependency("node-1", "node-2")) + + err := dag.Validate() + assert.NoError(t, err) + }) + + t.Run("DAG with cycle", func(t *testing.T) { + dag := NewDAG() + + node1 := &DAGNode{ID: "node-1", Task: &SwarmTask{ID: "task-1"}, Status: DAGNodePending} + node2 := &DAGNode{ID: "node-2", Task: &SwarmTask{ID: "task-2"}, Status: 
DAGNodePending} + + require.NoError(t, dag.AddNode(node1)) + require.NoError(t, dag.AddNode(node2)) + + // Manually create a cycle + node1.Dependencies = []string{"node-2"} + node2.Dependencies = []string{"node-1"} + + err := dag.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "cycle") + }) +} + +func TestDAG_GetReadyNodes(t *testing.T) { + dag := NewDAG() + + node1 := &DAGNode{ID: "node-1", Task: &SwarmTask{ID: "task-1"}, Status: DAGNodePending} + node2 := &DAGNode{ID: "node-2", Task: &SwarmTask{ID: "task-2"}, Status: DAGNodePending, Dependencies: []string{"node-1"}} + node3 := &DAGNode{ID: "node-3", Task: &SwarmTask{ID: "task-3"}, Status: DAGNodePending, Dependencies: []string{"node-1"}} + + require.NoError(t, dag.AddNode(node1)) + require.NoError(t, dag.AddNode(node2)) + require.NoError(t, dag.AddNode(node3)) + + // Initially only node-1 is ready (no dependencies) + ready := dag.GetReadyNodes() + assert.Len(t, ready, 1) + assert.Equal(t, "node-1", ready[0].ID) + + // Mark node-1 as completed + dag.UpdateNodeStatus("node-1", DAGNodeCompleted, "done", nil) + + // Now node-2 and node-3 should be ready + ready = dag.GetReadyNodes() + assert.Len(t, ready, 2) +} + +func TestDAG_IsComplete(t *testing.T) { + dag := NewDAG() + + node1 := &DAGNode{ID: "node-1", Task: &SwarmTask{ID: "task-1"}, Status: DAGNodePending} + + require.NoError(t, dag.AddNode(node1)) + + assert.False(t, dag.IsComplete()) + + dag.UpdateNodeStatus("node-1", DAGNodeCompleted, "done", nil) + assert.True(t, dag.IsComplete()) +} + +func TestDAG_HasFailed(t *testing.T) { + dag := NewDAG() + + node1 := &DAGNode{ID: "node-1", Task: &SwarmTask{ID: "task-1"}, Status: DAGNodePending} + + require.NoError(t, dag.AddNode(node1)) + + assert.False(t, dag.HasFailed()) + + dag.UpdateNodeStatus("node-1", DAGNodeFailed, "", assert.AnError) + assert.True(t, dag.HasFailed()) +} + +func TestDAG_GetCompletedNodes(t *testing.T) { + dag := NewDAG() + + node1 := &DAGNode{ID: "node-1", Task: &SwarmTask{ID: 
"task-1"}, Status: DAGNodePending} + node2 := &DAGNode{ID: "node-2", Task: &SwarmTask{ID: "task-2"}, Status: DAGNodePending} + + require.NoError(t, dag.AddNode(node1)) + require.NoError(t, dag.AddNode(node2)) + + dag.UpdateNodeStatus("node-1", DAGNodeCompleted, "done", nil) + dag.UpdateNodeStatus("node-2", DAGNodeRunning, "", nil) + + completed := dag.GetCompletedNodes() + assert.Len(t, completed, 1) + assert.Equal(t, "node-1", completed[0].ID) +} + +func TestDAG_GetFailedNodes(t *testing.T) { + dag := NewDAG() + + node1 := &DAGNode{ID: "node-1", Task: &SwarmTask{ID: "task-1"}, Status: DAGNodePending} + node2 := &DAGNode{ID: "node-2", Task: &SwarmTask{ID: "task-2"}, Status: DAGNodePending} + + require.NoError(t, dag.AddNode(node1)) + require.NoError(t, dag.AddNode(node2)) + + dag.UpdateNodeStatus("node-1", DAGNodeFailed, "", assert.AnError) + dag.UpdateNodeStatus("node-2", DAGNodeCompleted, "done", nil) + + failed := dag.GetFailedNodes() + assert.Len(t, failed, 1) + assert.Equal(t, "node-1", failed[0].ID) +} + +func TestDAGNodeStatus(t *testing.T) { + statuses := []DAGNodeStatus{ + DAGNodePending, + DAGNodeReady, + DAGNodeRunning, + DAGNodeCompleted, + DAGNodeFailed, + DAGNodeSkipped, + } + + for _, status := range statuses { + assert.NotEmpty(t, string(status), "Status should not be empty") + } +} + +func TestBuildDAGFromTasks(t *testing.T) { + tasks := []*DAGNode{ + { + ID: "task-1", + Task: &SwarmTask{ID: "task-1", Prompt: "Task 1"}, + }, + { + ID: "task-2", + Task: &SwarmTask{ID: "task-2", Prompt: "Task 2"}, + Dependencies: []string{"task-1"}, + }, + { + ID: "task-3", + Task: &SwarmTask{ID: "task-3", Prompt: "Task 3"}, + Dependencies: []string{"task-1"}, + }, + } + + dag, err := BuildDAGFromTasks(tasks) + require.NoError(t, err) + assert.Equal(t, 3, dag.NodeCount()) + + // Verify dependencies + task2, _ := dag.GetNode("task-2") + assert.Contains(t, task2.Dependencies, "task-1") + + task3, _ := dag.GetNode("task-3") + assert.Contains(t, task3.Dependencies, 
"task-1") +} diff --git a/pkg/swarm/dashboard.go b/pkg/swarm/dashboard.go new file mode 100644 index 000000000..75775331d --- /dev/null +++ b/pkg/swarm/dashboard.go @@ -0,0 +1,519 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// Dashboard provides a text-based UI for monitoring swarm status +type Dashboard struct { + manager *Manager + stopChan chan struct{} + mu sync.RWMutex + enabled bool + refresh time.Duration + lastState *DashboardState +} + +// DashboardState represents the current state of the swarm +type DashboardState struct { + Timestamp int64 + ThisNode *NodeInfoSnapshot + Nodes []*NodeInfoSnapshot + Connections *ConnectionStatus + Stats *SwarmStats + LeaderStatus *LeaderInfo +} + +// NodeInfoSnapshot is a serializable snapshot of NodeInfo +type NodeInfoSnapshot struct { + ID string + Role string + Status string + Capabilities []string + Model string + Load float64 + TasksRunning int + MaxTasks int + LastSeen int64 + StartedAt int64 + Uptime string +} + +// ConnectionStatus shows connection states +type ConnectionStatus struct { + NATSConnected bool + TemporalConnected bool + EmbeddedNATS bool + NATSURL string +} + +// SwarmStats provides aggregate statistics +type SwarmStats struct { + TotalNodes int + OnlineNodes int + OfflineNodes int + CoordinatorCount int + WorkerCount int + SpecialistCount int + TotalCapacity int + UsedCapacity int +} + +// LeaderInfo shows election status +type LeaderInfo struct { + Enabled bool + IsLeader bool + LeaderID string + LeaseExpiry int64 +} + +// NewDashboard creates a new dashboard +func NewDashboard(manager *Manager) *Dashboard { + return &Dashboard{ + manager: manager, + stopChan: make(chan struct{}), + refresh: 2 * time.Second, + } +} + +// SetRefreshInterval sets the dashboard refresh interval +func (d *Dashboard) 
SetRefreshInterval(interval time.Duration) { + d.mu.Lock() + defer d.mu.Unlock() + d.refresh = interval +} + +// Start begins the dashboard update loop +func (d *Dashboard) Start(ctx context.Context) error { + d.mu.Lock() + d.enabled = true + d.mu.Unlock() + + // Initial state + d.updateState() + + logger.InfoC("swarm", "Dashboard started") + + // Start background update loop + go d.runLoop(ctx) + + return nil +} + +// Stop stops the dashboard +func (d *Dashboard) Stop() { + d.mu.Lock() + if d.enabled { + close(d.stopChan) + d.enabled = false + } + d.mu.Unlock() + logger.InfoC("swarm", "Dashboard stopped") +} + +// GetState returns the current dashboard state +func (d *Dashboard) GetState() *DashboardState { + d.mu.RLock() + defer d.mu.RUnlock() + return d.lastState +} + +// runLoop runs the background update loop +func (d *Dashboard) runLoop(ctx context.Context) { + ticker := time.NewTicker(d.refresh) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + d.updateState() + case <-d.stopChan: + return + case <-ctx.Done(): + return + } + } +} + +// updateState updates the dashboard state from the manager +func (d *Dashboard) updateState() { + d.mu.Lock() + defer d.mu.Unlock() + + state := &DashboardState{ + Timestamp: time.Now().UnixMilli(), + } + + // Get this node info + if d.manager != nil && d.manager.nodeInfo != nil { + state.ThisNode = snapshotNodeInfo(d.manager.nodeInfo) + } + + // Get discovered nodes + if d.manager != nil && d.manager.discovery != nil { + nodes := d.manager.discovery.GetAllNodes() + state.Nodes = make([]*NodeInfoSnapshot, 0, len(nodes)) + for _, node := range nodes { + state.Nodes = append(state.Nodes, snapshotNodeInfo(node)) + } + } + + // Get connection status + if d.manager != nil { + state.Connections = &ConnectionStatus{ + NATSConnected: d.manager.IsNATSConnected(), + TemporalConnected: d.manager.IsTemporalConnected(), + } + if d.manager.embeddedNATS != nil { + state.Connections.EmbeddedNATS = true + state.Connections.NATSURL 
= d.manager.embeddedNATS.ClientURL() + } + + // Get leader status + state.LeaderStatus = &LeaderInfo{ + Enabled: d.manager.electionMgr != nil, + IsLeader: d.manager.IsLeader(), + LeaderID: d.manager.GetLeaderID(), + } + } + + // Calculate statistics + state.Stats = calculateStats(state.Nodes) + + d.lastState = state +} + +// snapshotNodeInfo creates a snapshot of NodeInfo +func snapshotNodeInfo(node *NodeInfo) *NodeInfoSnapshot { + now := time.Now().UnixMilli() + uptime := time.Duration(now - node.StartedAt) + + return &NodeInfoSnapshot{ + ID: node.ID, + Role: string(node.Role), + Status: string(node.Status), + Capabilities: node.Capabilities, + Model: node.Model, + Load: node.Load, + TasksRunning: node.TasksRunning, + MaxTasks: node.MaxTasks, + LastSeen: node.LastSeen, + StartedAt: node.StartedAt, + Uptime: formatUptime(uptime), + } +} + +// calculateStats calculates aggregate statistics +func calculateStats(nodes []*NodeInfoSnapshot) *SwarmStats { + stats := &SwarmStats{ + TotalNodes: len(nodes), + } + + for _, node := range nodes { + switch node.Status { + case "online", "busy": + stats.OnlineNodes++ + case "offline": + stats.OfflineNodes++ + } + + switch node.Role { + case "coordinator": + stats.CoordinatorCount++ + case "worker": + stats.WorkerCount++ + case "specialist": + stats.SpecialistCount++ + } + + stats.TotalCapacity += node.MaxTasks + stats.UsedCapacity += node.TasksRunning + } + + return stats +} + +// formatUptime formats a duration as a human-readable uptime +func formatUptime(d time.Duration) string { + if d < time.Minute { + return fmt.Sprintf("%ds", int(d.Seconds())) + } else if d < time.Hour { + return fmt.Sprintf("%dm", int(d.Minutes())) + } else if d < 24*time.Hour { + return fmt.Sprintf("%dh%dm", int(d.Hours()), int(d.Minutes())%60) + } + return fmt.Sprintf("%dd%dh", int(d.Hours()/24), int(d.Hours())%24) +} + +// Render returns a formatted string representation of the dashboard +func (d *Dashboard) Render() string { + d.mu.RLock() + state := 
d.lastState + d.mu.RUnlock() + + if state == nil { + return "Dashboard not initialized" + } + + var sb strings.Builder + + // Header + sb.WriteString("\n") + sb.WriteString("╔════════════════════════════════════════════════════════════╗\n") + sb.WriteString("║ PicoClaw Swarm Status Dashboard ║\n") + sb.WriteString("╚════════════════════════════════════════════════════════════╝\n") + sb.WriteString("\n") + + // This Node + if state.ThisNode != nil { + sb.WriteString("【This Node】\n") + sb.WriteString(fmt.Sprintf(" ID: %s\n", state.ThisNode.ID)) + sb.WriteString(fmt.Sprintf(" Role: %s\n", formatRole(state.ThisNode.Role))) + sb.WriteString(fmt.Sprintf(" Status: %s\n", formatStatus(state.ThisNode.Status))) + sb.WriteString(fmt.Sprintf(" Load: %s\n", formatLoadBar(state.ThisNode.Load, state.ThisNode.TasksRunning, state.ThisNode.MaxTasks))) + sb.WriteString(fmt.Sprintf(" Uptime: %s\n", state.ThisNode.Uptime)) + sb.WriteString("\n") + } + + // Connections + if state.Connections != nil { + sb.WriteString("【Connections】\n") + sb.WriteString(fmt.Sprintf(" NATS: %s %s\n", formatBool(state.Connections.NATSConnected), state.Connections.NATSURL)) + sb.WriteString(fmt.Sprintf(" Temporal: %s\n", formatBool(state.Connections.TemporalConnected))) + if state.Connections.EmbeddedNATS { + sb.WriteString(" (Embedded NATS Server)\n") + } + sb.WriteString("\n") + } + + // Leader Status + if state.LeaderStatus != nil && state.LeaderStatus.Enabled { + sb.WriteString("【Leader Election】\n") + sb.WriteString(fmt.Sprintf(" Enabled: Yes\n")) + sb.WriteString(fmt.Sprintf(" IsLeader: %s\n", formatBool(state.LeaderStatus.IsLeader))) + sb.WriteString(fmt.Sprintf(" LeaderID: %s\n", state.LeaderStatus.LeaderID)) + if state.LeaderStatus.LeaseExpiry > 0 { + remaining := time.Until(time.UnixMilli(state.LeaderStatus.LeaseExpiry)) + sb.WriteString(fmt.Sprintf(" Lease: %s\n", formatDuration(remaining))) + } + sb.WriteString("\n") + } + + // Statistics + if state.Stats != nil { + sb.WriteString("【Swarm 
Statistics】\n") + sb.WriteString(fmt.Sprintf(" Nodes: %d total, %d online, %d offline\n", + state.Stats.TotalNodes, state.Stats.OnlineNodes, state.Stats.OfflineNodes)) + sb.WriteString(fmt.Sprintf(" Roles: %d coordinator(s), %d worker(s), %d specialist(s)\n", + state.Stats.CoordinatorCount, state.Stats.WorkerCount, state.Stats.SpecialistCount)) + sb.WriteString(fmt.Sprintf(" Capacity: %d/%d tasks used\n", + state.Stats.UsedCapacity, state.Stats.TotalCapacity)) + sb.WriteString("\n") + } + + // Nodes List + if len(state.Nodes) > 0 { + sb.WriteString("【Discovered Nodes】\n") + for _, node := range state.Nodes { + statusIcon := getNodeStatusIcon(node.Status) + roleIcon := getRoleIcon(node.Role) + sb.WriteString(fmt.Sprintf(" %s %-20s %-2s %-8s %s\n", + statusIcon, + truncateID(node.ID), + roleIcon, + node.Status, + formatLoadBar(node.Load, node.TasksRunning, node.MaxTasks), + )) + } + sb.WriteString("\n") + } + + // Legend + sb.WriteString("【Legend】\n") + sb.WriteString(" ● = Online ○ = Offline ? 
= Unknown ◐ = Suspicious\n") + sb.WriteString(" C = Coordinator W = Worker S = Specialist\n") + sb.WriteString(fmt.Sprintf(" Updated: %s\n", time.UnixMilli(state.Timestamp).Format("15:04:05"))) + + return sb.String() +} + +// RenderCompact returns a compact one-line status +func (d *Dashboard) RenderCompact() string { + d.mu.RLock() + state := d.lastState + d.mu.RUnlock() + + if state == nil { + return "Swarm: initializing" + } + + var parts []string + + if state.ThisNode != nil { + parts = append(parts, fmt.Sprintf("%s:%s", state.ThisNode.Role, state.ThisNode.Status)) + } + + if state.Stats != nil { + parts = append(parts, fmt.Sprintf("%d/%d nodes", state.Stats.OnlineNodes, state.Stats.TotalNodes)) + } + + if state.Connections != nil { + if !state.Connections.NATSConnected { + parts = append(parts, "NATS:down") + } + } + + if state.LeaderStatus != nil && state.LeaderStatus.Enabled { + if state.LeaderStatus.IsLeader { + parts = append(parts, "LEADER") + } else if state.LeaderStatus.LeaderID != "" { + parts = append(parts, fmt.Sprintf("leader:%s", truncateID(state.LeaderStatus.LeaderID))) + } + } + + return fmt.Sprintf("Swarm[%s]", strings.Join(parts, " ")) +} + +// formatRole formats a role with an icon +func formatRole(role string) string { + switch role { + case "coordinator": + return "📋 " + role + case "worker": + return "⚙️ " + role + case "specialist": + return "🔧 " + role + default: + return role + } +} + +// formatStatus formats a status with an icon +func formatStatus(status string) string { + switch status { + case "online": + return "● " + status + case "busy": + return "🔄 " + status + case "offline": + return "○ " + status + case "suspicious": + return "◐ " + status + default: + return "? 
" + status + } +} + +// formatBool formats a boolean as Yes/No +func formatBool(b bool) string { + if b { + return "✓ Yes" + } + return "✗ No" +} + +// formatLoadBar creates a visual load bar +func formatLoadBar(load float64, tasksRunning, maxTasks int) string { + width := 10 + filled := int(load * float64(width)) + if filled > width { + filled = width + } + + bar := strings.Repeat("█", filled) + strings.Repeat("░", width-filled) + return fmt.Sprintf("[%s] %.0f%% (%d/%d)", bar, load*100, tasksRunning, maxTasks) +} + +// formatDuration formats a duration for display +func formatDuration(d time.Duration) string { + if d < 0 { + return "expired" + } + if d < time.Second { + return fmt.Sprintf("%dms", d.Milliseconds()) + } + if d < time.Minute { + return fmt.Sprintf("%ds", int(d.Seconds())) + } + return fmt.Sprintf("%dm%ds", int(d.Minutes()), int(d.Seconds())%60) +} + +// getNodeStatusIcon returns an icon for node status +func getNodeStatusIcon(status string) string { + switch status { + case "online": + return "●" + case "busy": + return "🔄" + case "offline": + return "○" + case "suspicious": + return "◐" + default: + return "?" + } +} + +// getRoleIcon returns an icon for role +func getRoleIcon(role string) string { + switch role { + case "coordinator": + return "C" + case "worker": + return "W" + case "specialist": + return "S" + default: + return "?" + } +} + +// truncateID truncates an ID for display +func truncateID(id string) string { + if len(id) <= 20 { + return id + } + return id[:17] + "..." 
+} + +// GetStatusSummary returns a quick status summary for logging +func (d *Dashboard) GetStatusSummary() map[string]interface{} { + d.mu.RLock() + defer d.mu.RUnlock() + + if d.lastState == nil { + return map[string]interface{}{ + "status": "initializing", + } + } + + summary := map[string]interface{}{ + "timestamp": d.lastState.Timestamp, + "total_nodes": d.lastState.Stats.TotalNodes, + "online_nodes": d.lastState.Stats.OnlineNodes, + "nats": d.lastState.Connections.NATSConnected, + "temporal": d.lastState.Connections.TemporalConnected, + } + + if d.lastState.ThisNode != nil { + summary["node_id"] = d.lastState.ThisNode.ID + summary["role"] = d.lastState.ThisNode.Role + } + + if d.lastState.LeaderStatus != nil && d.lastState.LeaderStatus.Enabled { + summary["is_leader"] = d.lastState.LeaderStatus.IsLeader + } + + return summary +} diff --git a/pkg/swarm/dashboard_test.go b/pkg/swarm/dashboard_test.go new file mode 100644 index 000000000..43b8d5c8a --- /dev/null +++ b/pkg/swarm/dashboard_test.go @@ -0,0 +1,217 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 Picooclaw contributors + +package swarm + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewDashboard(t *testing.T) { + dash := NewDashboard(nil) + assert.NotNil(t, dash) + assert.False(t, dash.enabled) + assert.Equal(t, 2*time.Second, dash.refresh) +} + +func TestDashboard_SetRefreshInterval(t *testing.T) { + dash := NewDashboard(nil) + dash.SetRefreshInterval(5 * time.Second) + assert.Equal(t, 5*time.Second, dash.refresh) +} + +func TestDashboard_StopWithoutStart(t *testing.T) { + dash := NewDashboard(nil) + // Should not panic + dash.Stop() + assert.False(t, dash.enabled) +} + +func TestDashboard_RenderEmpty(t *testing.T) { + dash := NewDashboard(nil) + output := dash.Render() + assert.Contains(t, output, "Dashboard not initialized") +} + +func TestDashboard_RenderCompactEmpty(t *testing.T) { + dash := 
NewDashboard(nil) + output := dash.RenderCompact() + assert.Contains(t, output, "initializing") +} + +func TestDashboard_SnapshotNodeInfo(t *testing.T) { + node := &NodeInfo{ + ID: "test-node-1", + Role: RoleWorker, + Status: StatusOnline, + Capabilities: []string{"code", "research"}, + Model: "test-model", + Load: 0.5, + TasksRunning: 2, + MaxTasks: 5, + LastSeen: time.Now().UnixMilli(), + StartedAt: time.Now().UnixMilli(), // Started now + } + + snapshot := snapshotNodeInfo(node) + + assert.Equal(t, "test-node-1", snapshot.ID) + assert.Equal(t, "worker", snapshot.Role) + assert.Equal(t, "online", snapshot.Status) + assert.Equal(t, []string{"code", "research"}, snapshot.Capabilities) + assert.Equal(t, 0.5, snapshot.Load) + assert.Equal(t, 2, snapshot.TasksRunning) + assert.Equal(t, 5, snapshot.MaxTasks) + // Just started, should show seconds or 0s + assert.Contains(t, snapshot.Uptime, "s") +} + +func TestDashboard_CalculateStats(t *testing.T) { + nodes := []*NodeInfoSnapshot{ + {Role: "coordinator", Status: "online", MaxTasks: 10, TasksRunning: 2}, + {Role: "worker", Status: "online", MaxTasks: 5, TasksRunning: 3}, + {Role: "worker", Status: "busy", MaxTasks: 5, TasksRunning: 5}, + {Role: "worker", Status: "offline", MaxTasks: 5, TasksRunning: 0}, + {Role: "specialist", Status: "online", MaxTasks: 3, TasksRunning: 1}, + } + + stats := calculateStats(nodes) + + assert.Equal(t, 5, stats.TotalNodes) + assert.Equal(t, 4, stats.OnlineNodes) // online + busy nodes + assert.Equal(t, 1, stats.OfflineNodes) + assert.Equal(t, 1, stats.CoordinatorCount) + assert.Equal(t, 3, stats.WorkerCount) + assert.Equal(t, 1, stats.SpecialistCount) + assert.Equal(t, 28, stats.TotalCapacity) + assert.Equal(t, 11, stats.UsedCapacity) +} + +func TestDashboard_FormatUptime(t *testing.T) { + tests := []struct { + name string + duration time.Duration + want string + }{ + {"seconds", 30 * time.Second, "30s"}, + {"minutes", 5 * time.Minute, "5m"}, + {"hours", 3 * time.Hour, "3h0m"}, + {"days", 
2*24*time.Hour + 5*time.Hour, "2d5h"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatUptime(tt.duration) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestDashboard_FormatLoadBar(t *testing.T) { + tests := []struct { + name string + load float64 + tasksRunning int + maxTasks int + }{ + {"empty", 0.0, 0, 10}, + {"half", 0.5, 5, 10}, + {"full", 1.0, 10, 10}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatLoadBar(tt.load, tt.tasksRunning, tt.maxTasks) + assert.Contains(t, got, "[") + assert.Contains(t, got, "]") + assert.Contains(t, got, "%") + assert.Contains(t, got, "/") + }) + } +} + +func TestDashboard_GetStatusSummary(t *testing.T) { + dash := NewDashboard(nil) + summary := dash.GetStatusSummary() + assert.NotNil(t, summary) + assert.Contains(t, summary, "status") + assert.Equal(t, "initializing", summary["status"]) +} + +func TestDashboard_RenderWithState(t *testing.T) { + dash := NewDashboard(nil) + + // Manually set state + dash.lastState = &DashboardState{ + Timestamp: time.Now().UnixMilli(), + ThisNode: &NodeInfoSnapshot{ + ID: "test-node", + Role: "worker", + Status: "online", + Load: 0.3, + }, + Stats: &SwarmStats{ + TotalNodes: 3, + OnlineNodes: 3, + }, + Connections: &ConnectionStatus{ + NATSConnected: true, + TemporalConnected: false, + }, + } + + output := dash.Render() + assert.Contains(t, output, "PicoClaw Swarm Status") + assert.Contains(t, output, "test-node") + assert.Contains(t, output, "worker") + assert.Contains(t, output, "online") +} + +func TestDashboard_TruncateID(t *testing.T) { + tests := []struct { + name string + id string + want string + }{ + {"short", "abc", "abc"}, + {"exact", "12345678901234567890", "12345678901234567890"}, + {"long", "12345678901234567890123", "12345678901234567..."}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := truncateID(tt.id) + assert.Equal(t, tt.want, got) + }) + } +} + +func 
TestDashboard_RenderCompactWithState(t *testing.T) { + dash := NewDashboard(nil) + + dash.lastState = &DashboardState{ + Timestamp: time.Now().UnixMilli(), + ThisNode: &NodeInfoSnapshot{ + ID: "test-node", + Role: "worker", + Status: "online", + }, + Stats: &SwarmStats{ + TotalNodes: 5, + OnlineNodes: 4, + }, + Connections: &ConnectionStatus{ + NATSConnected: true, + }, + } + + output := dash.RenderCompact() + assert.Contains(t, output, "Swarm[") + assert.Contains(t, output, "worker:online") + assert.Contains(t, output, "4/5") +} diff --git a/pkg/swarm/discovery.go b/pkg/swarm/discovery.go new file mode 100644 index 000000000..10ecaa0ae --- /dev/null +++ b/pkg/swarm/discovery.go @@ -0,0 +1,395 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// Discovery manages node discovery and heartbeats +type Discovery struct { + bridge *NATSBridge + nodeInfo *NodeInfo + cfg *config.SwarmConfig + registry map[string]*NodeInfo + mu sync.RWMutex + heartbeatStop chan struct{} + cleanupStop chan struct{} +} + +// NewDiscovery creates a new discovery service +func NewDiscovery(bridge *NATSBridge, nodeInfo *NodeInfo, cfg *config.SwarmConfig) *Discovery { + return &Discovery{ + bridge: bridge, + nodeInfo: nodeInfo, + cfg: cfg, + registry: make(map[string]*NodeInfo), + heartbeatStop: make(chan struct{}), + cleanupStop: make(chan struct{}), + } +} + +// Start begins heartbeat publishing and node cleanup +func (d *Discovery) Start(ctx context.Context) error { + // Register callbacks + d.bridge.SetOnNodeJoin(d.handleNodeJoin) + d.bridge.SetOnNodeLeave(d.handleNodeLeave) + + // Subscribe to all heartbeats + if _, err := d.bridge.SubscribeAllHeartbeats(d.handleHeartbeat); err != nil { + return fmt.Errorf("failed to subscribe to heartbeats: %w", 
err) + } + + // Subscribe to shutdown notices + if _, err := d.bridge.SubscribeShutdown(d.handleNodeLeave); err != nil { + return fmt.Errorf("failed to subscribe to shutdown notices: %w", err) + } + + // Start heartbeat goroutine + go d.heartbeatLoop(ctx) + + // Start cleanup goroutine + go d.cleanupLoop(ctx) + + // Query for existing nodes + d.queryNodes() + + logger.InfoC("swarm", "Discovery service started") + return nil +} + +// Stop stops the discovery service +func (d *Discovery) Stop() { + close(d.heartbeatStop) + close(d.cleanupStop) +} + +// GetNodes returns all known nodes (optionally filtered) +func (d *Discovery) GetNodes(role NodeRole, capability string) []*NodeInfo { + d.mu.RLock() + defer d.mu.RUnlock() + + nodes := make([]*NodeInfo, 0) + for _, node := range d.registry { + // Skip offline nodes + if node.Status == StatusOffline { + continue + } + + // Filter by role if specified + if role != "" && node.Role != role { + continue + } + + // Filter by capability if specified + if capability != "" && !containsCapability(node.Capabilities, capability) { + continue + } + + nodes = append(nodes, node) + } + return nodes +} + +// GetNode returns a specific node by ID +func (d *Discovery) GetNode(nodeID string) (*NodeInfo, bool) { + d.mu.RLock() + defer d.mu.RUnlock() + node, ok := d.registry[nodeID] + return node, ok +} + +// GetAllNodes returns all known nodes including offline ones +func (d *Discovery) GetAllNodes() []*NodeInfo { + d.mu.RLock() + defer d.mu.RUnlock() + nodes := make([]*NodeInfo, 0, len(d.registry)) + for _, node := range d.registry { + nodes = append(nodes, node) + } + return nodes +} + +// NodeCount returns the total number of known online nodes +func (d *Discovery) NodeCount() int { + d.mu.RLock() + defer d.mu.RUnlock() + count := 0 + for _, node := range d.registry { + if node.Status != StatusOffline { + count++ + } + } + return count +} + +// SelectWorker selects the best worker for a capability using load balancing +func (d 
*Discovery) SelectWorker(capability string) *NodeInfo { + return d.SelectWorkerWithPriority(capability, 1) // Default to normal priority +} + +// SelectWorkerWithPriority selects the best worker considering task priority +// Priority levels: 0=low, 1=normal, 2=high, 3=critical +// Higher priority tasks prefer nodes with lower current load and more available capacity +func (d *Discovery) SelectWorkerWithPriority(capability string, priority int) *NodeInfo { + workers := d.GetNodes(RoleWorker, capability) + if len(workers) == 0 { + // Try specialists + workers = d.GetNodes(RoleSpecialist, capability) + } + if len(workers) == 0 { + return nil + } + + // Calculate selection score based on priority + // For high priority tasks, prefer nodes with: + // 1. Lower current load + // 2. More available capacity (maxTasks - tasksRunning) + // 3. Online status + var best *NodeInfo + var bestScore float64 = -1 + + for _, w := range workers { + if w.Status == StatusOffline { + continue + } + if w.TasksRunning >= w.MaxTasks { + continue // Skip full nodes + } + + score := d.calculateNodeScore(w, priority) + if score > bestScore { + best = w + bestScore = score + } + } + + return best +} + +// calculateNodeScore calculates a node's suitability score for a given priority +// Higher score = better candidate +func (d *Discovery) calculateNodeScore(node *NodeInfo, priority int) float64 { + // Base score: inverse of load (0-1 range, where 1 = idle) + loadScore := 1.0 - node.Load + + // Capacity score: ratio of available tasks + capacityScore := float64(node.MaxTasks-node.TasksRunning) / float64(node.MaxTasks) + + // Priority multiplier: + // - Low priority (0): Prefer busy nodes (0.5x), spread load + // - Normal priority (1): Standard selection (1.0x) + // - High priority (2): Prefer idle nodes (1.5x) + // - Critical priority (3): Strongly prefer idle nodes (2.0x) + var priorityMult float64 + switch priority { + case 0: + priorityMult = 0.5 + case 1: + priorityMult = 1.0 + case 2: + 
// heartbeatLoop periodically publishes this node's heartbeat until Stop is
// called or the context is cancelled. The interval comes from config.
func (d *Discovery) heartbeatLoop(ctx context.Context) {
	interval := d.cfg.GetHeartbeatInterval()
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			d.publishHeartbeat()
		case <-d.heartbeatStop:
			return
		case <-ctx.Done():
			return
		}
	}
}

// publishHeartbeat snapshots this node's current status/load into a Heartbeat
// and publishes it over the bridge. Publish failures are logged at debug
// level only — a missed heartbeat is recoverable.
// NOTE(review): reads d.nodeInfo fields (including Metadata) without holding
// d.mu — confirm the node's own info has a single writer.
func (d *Discovery) publishHeartbeat() {
	// Metadata carries the identity keys; missing keys yield empty strings.
	hid, _ := d.nodeInfo.Metadata["hid"]
	sid, _ := d.nodeInfo.Metadata["sid"]
	hb := &Heartbeat{
		NodeID:       d.nodeInfo.ID,
		Role:         d.nodeInfo.Role,
		Status:       d.nodeInfo.Status,
		Load:         d.nodeInfo.Load,
		TasksRunning: d.nodeInfo.TasksRunning,
		Timestamp:    time.Now().UnixMilli(),
		Capabilities: d.nodeInfo.Capabilities,
		HID:          hid,
		SID:          sid,
	}
	if err := d.bridge.PublishHeartbeat(hb); err != nil {
		logger.DebugCF("swarm", "Failed to publish heartbeat", map[string]interface{}{
			"error": err.Error(),
		})
	}
}

// handleHeartbeat updates the registry entry for a peer from its heartbeat.
// Our own heartbeats are ignored. Heartbeats from nodes not yet in the
// registry are dropped silently — registration happens only via
// handleNodeJoin (which validates the node).
func (d *Discovery) handleHeartbeat(hb *Heartbeat) {
	// Skip our own heartbeats
	if hb.NodeID == d.nodeInfo.ID {
		return
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	if node, ok := d.registry[hb.NodeID]; ok {
		node.Status = hb.Status
		node.Load = hb.Load
		node.TasksRunning = hb.TasksRunning
		node.LastSeen = hb.Timestamp
	}
}

// cleanupLoop periodically sweeps the registry for stale nodes until Stop is
// called or the context is cancelled. Sweeping at half the node timeout
// bounds detection latency to ~1.5x the timeout.
func (d *Discovery) cleanupLoop(ctx context.Context) {
	timeout := d.cfg.GetNodeTimeout()
	ticker := time.NewTicker(timeout / 2)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			d.cleanupStaleNodes()
		case <-d.cleanupStop:
			return
		case <-ctx.Done():
			return
		}
	}
}
// cleanupStaleNodes applies two thresholds in one registry sweep:
//   - stale (LastSeen older than the node timeout): mark the node offline
//     but keep it, so late heartbeats can revive it;
//   - dead (LastSeen older than 10x the timeout): delete the entry entirely
//     to prevent unbounded registry growth.
// Deleting from the map while ranging over it is safe in Go.
func (d *Discovery) cleanupStaleNodes() {
	d.mu.Lock()
	defer d.mu.Unlock()

	timeout := d.cfg.GetNodeTimeout()
	now := time.Now().UnixMilli()
	staleThreshold := now - int64(timeout.Milliseconds())
	// GC threshold: 10x timeout to remove long-dead nodes and prevent memory leak
	gcThreshold := now - int64(timeout.Milliseconds()*10)

	for id, node := range d.registry {
		if node.LastSeen < staleThreshold && node.Status != StatusOffline {
			node.Status = StatusOffline
			logger.WarnCF("swarm", "Node marked offline (heartbeat timeout)", map[string]interface{}{
				"node_id":   id,
				"last_seen": time.UnixMilli(node.LastSeen).Format(time.RFC3339),
			})
		}
		// GC long-dead nodes to prevent memory leak
		if node.LastSeen < gcThreshold {
			delete(d.registry, id)
			logger.InfoCF("swarm", "Node removed from registry (GC)", map[string]interface{}{
				"node_id": id,
			})
		}
	}
}

// queryNodes performs a one-shot discovery request (2s deadline) and feeds
// every returned node through handleNodeJoin, which validates and registers
// it. Failures are logged and tolerated — the registry will still fill in
// from live heartbeats and join events.
func (d *Discovery) queryNodes() {
	query := &DiscoveryQuery{
		RequesterID: d.nodeInfo.ID,
	}

	nodes, err := d.bridge.RequestDiscovery(query, 2*time.Second)
	if err != nil {
		logger.WarnCF("swarm", "Failed to query existing nodes", map[string]interface{}{
			"error": err.Error(),
		})
		return
	}

	for _, node := range nodes {
		d.handleNodeJoin(node)
	}

	logger.InfoCF("swarm", "Discovery query completed", map[string]interface{}{
		"nodes_found": len(nodes),
	})
}

// handleNodeJoin validates and registers a node announced by the bridge or
// discovery query. Rejected: nil nodes, empty IDs, unknown roles, and our
// own node. A duplicate ID overwrites the existing entry (last-writer-wins).
// NOTE(review): the caller's *NodeInfo pointer is stored directly, so the
// registry aliases the caller's struct — confirm callers do not reuse it.
func (d *Discovery) handleNodeJoin(node *NodeInfo) {
	if node == nil {
		logger.WarnC("swarm", "Attempted to register nil node")
		return
	}
	if node.ID == "" {
		logger.WarnC("swarm", "Attempted to register node with empty ID")
		return
	}
	if node.ID == d.nodeInfo.ID {
		return // Skip self
	}

	// Validate role
	validRoles := map[NodeRole]bool{RoleCoordinator: true, RoleWorker: true, RoleSpecialist: true}
	if !validRoles[node.Role] {
		logger.WarnCF("swarm", "Invalid node role", map[string]interface{}{
			"node_id": node.ID,
			"role":    string(node.Role),
		})
		return
	}

	d.mu.Lock()
	defer d.mu.Unlock()

	// Stamp LastSeen at registration so a just-joined node is never
	// immediately treated as stale.
	node.LastSeen = time.Now().UnixMilli()
	d.registry[node.ID] = node

	logger.InfoCF("swarm", "Node registered", map[string]interface{}{
		"node_id":      node.ID,
		"role":         string(node.Role),
		"capabilities": fmt.Sprintf("%v", node.Capabilities),
	})
}
+ logger.InfoCF("swarm", "Node registered", map[string]interface{}{ + "node_id": node.ID, + "role": string(node.Role), + "capabilities": fmt.Sprintf("%v", node.Capabilities), + }) +} + +func (d *Discovery) handleNodeLeave(nodeID string) { + d.mu.Lock() + defer d.mu.Unlock() + + if node, ok := d.registry[nodeID]; ok { + node.Status = StatusOffline + logger.InfoCF("swarm", "Node left swarm", map[string]interface{}{ + "node_id": nodeID, + }) + } +} + +// MarshalRegistryJSON returns the current registry as JSON (for debugging) +func (d *Discovery) MarshalRegistryJSON() ([]byte, error) { + d.mu.RLock() + defer d.mu.RUnlock() + return json.Marshal(d.registry) +} diff --git a/pkg/swarm/discovery_test.go b/pkg/swarm/discovery_test.go new file mode 100644 index 000000000..83426e411 --- /dev/null +++ b/pkg/swarm/discovery_test.go @@ -0,0 +1,432 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "testing" + "time" +) + +func TestDiscovery_NodeRegistration(t *testing.T) { + tests := []struct { + name string + selfID string + nodes []*NodeInfo + wantCnt int + }{ + { + name: "register single node", + selfID: "self-1", + nodes: []*NodeInfo{ + newTestNodeInfo("other-1", RoleWorker, []string{"code"}, 4), + }, + wantCnt: 1, + }, + { + name: "register multiple nodes", + selfID: "self-2", + nodes: []*NodeInfo{ + newTestNodeInfo("other-a", RoleWorker, []string{"code"}, 4), + newTestNodeInfo("other-b", RoleSpecialist, []string{"ml"}, 2), + newTestNodeInfo("other-c", RoleWorker, []string{"research"}, 4), + }, + wantCnt: 3, + }, + { + name: "skip self registration", + selfID: "self-3", + nodes: []*NodeInfo{ + newTestNodeInfo("self-3", RoleWorker, []string{"code"}, 4), // same as self + newTestNodeInfo("other-d", RoleWorker, []string{"code"}, 4), + }, + wantCnt: 1, // self-3 should be skipped + }, + { + name: "duplicate registration overwrites", + selfID: "self-4", + nodes: []*NodeInfo{ + 
{ID: "dup-1", Role: RoleWorker, Capabilities: []string{"code"}, Status: StatusOnline, Load: 0.1, MaxTasks: 4}, + {ID: "dup-1", Role: RoleWorker, Capabilities: []string{"code"}, Status: StatusBusy, Load: 0.9, MaxTasks: 4}, + }, + wantCnt: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + selfNode := newTestNodeInfo(tt.selfID, RoleCoordinator, nil, 1) + cfg := newTestSwarmConfig(0) + d := NewDiscovery(nil, selfNode, cfg) // bridge not needed for direct handleNodeJoin + + for _, node := range tt.nodes { + d.handleNodeJoin(node) + } + + if got := d.NodeCount(); got != tt.wantCnt { + t.Errorf("NodeCount() = %d, want %d", got, tt.wantCnt) + } + + // For duplicate test, check the final state + if tt.name == "duplicate registration overwrites" { + node, ok := d.GetNode("dup-1") + if !ok { + t.Fatal("GetNode('dup-1') returned false") + } + if node.Status != StatusBusy { + t.Errorf("Status = %q, want %q (overwritten)", node.Status, StatusBusy) + } + } + }) + } +} + +func TestDiscovery_HeartbeatUpdatesNode(t *testing.T) { + tests := []struct { + name string + selfID string + registerNode *NodeInfo + heartbeat Heartbeat + expectUpdate bool + wantStatus NodeStatus + wantLoad float64 + wantTasks int + }{ + { + name: "update load", + selfID: "hb-self-1", + registerNode: &NodeInfo{ID: "hb-node-1", Role: RoleWorker, Status: StatusOnline, Load: 0.1, MaxTasks: 4}, + heartbeat: Heartbeat{NodeID: "hb-node-1", Status: StatusOnline, Load: 0.7, TasksRunning: 3, Timestamp: time.Now().UnixMilli()}, + expectUpdate: true, + wantStatus: StatusOnline, + wantLoad: 0.7, + wantTasks: 3, + }, + { + name: "update status to busy", + selfID: "hb-self-2", + registerNode: &NodeInfo{ID: "hb-node-2", Role: RoleWorker, Status: StatusOnline, Load: 0.5, MaxTasks: 4}, + heartbeat: Heartbeat{NodeID: "hb-node-2", Status: StatusBusy, Load: 1.0, TasksRunning: 4, Timestamp: time.Now().UnixMilli()}, + expectUpdate: true, + wantStatus: StatusBusy, + wantLoad: 1.0, + wantTasks: 4, + 
}, + { + name: "skip own heartbeat", + selfID: "hb-self-3", + registerNode: nil, // don't register anything + heartbeat: Heartbeat{NodeID: "hb-self-3", Status: StatusOnline, Load: 0.5, Timestamp: time.Now().UnixMilli()}, + expectUpdate: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + selfNode := newTestNodeInfo(tt.selfID, RoleCoordinator, nil, 1) + cfg := newTestSwarmConfig(0) + d := NewDiscovery(nil, selfNode, cfg) + + if tt.registerNode != nil { + d.handleNodeJoin(tt.registerNode) + } + + d.handleHeartbeat(&tt.heartbeat) + + if tt.expectUpdate { + node, ok := d.GetNode(tt.heartbeat.NodeID) + if !ok { + t.Fatal("GetNode() returned false after heartbeat") + } + if node.Status != tt.wantStatus { + t.Errorf("Status = %q, want %q", node.Status, tt.wantStatus) + } + if node.Load != tt.wantLoad { + t.Errorf("Load = %f, want %f", node.Load, tt.wantLoad) + } + if node.TasksRunning != tt.wantTasks { + t.Errorf("TasksRunning = %d, want %d", node.TasksRunning, tt.wantTasks) + } + } + }) + } +} + +func TestDiscovery_SelectWorker(t *testing.T) { + tests := []struct { + name string + nodes []*NodeInfo + capability string + wantID string // empty means nil expected + }{ + { + name: "picks lowest load", + nodes: []*NodeInfo{ + {ID: "w-a", Role: RoleWorker, Capabilities: []string{"code"}, Status: StatusOnline, Load: 0.3, TasksRunning: 1, MaxTasks: 4}, + {ID: "w-b", Role: RoleWorker, Capabilities: []string{"code"}, Status: StatusOnline, Load: 0.1, TasksRunning: 0, MaxTasks: 4}, + }, + capability: "code", + wantID: "w-b", + }, + { + name: "skips full worker", + nodes: []*NodeInfo{ + {ID: "w-full", Role: RoleWorker, Capabilities: []string{"code"}, Status: StatusOnline, Load: 1.0, TasksRunning: 2, MaxTasks: 2}, + {ID: "w-avail", Role: RoleWorker, Capabilities: []string{"code"}, Status: StatusOnline, Load: 0.5, TasksRunning: 2, MaxTasks: 4}, + }, + capability: "code", + wantID: "w-avail", + }, + { + name: "no workers with capability", + nodes: 
[]*NodeInfo{ + {ID: "w-research", Role: RoleWorker, Capabilities: []string{"research"}, Status: StatusOnline, Load: 0.1, MaxTasks: 4}, + }, + capability: "code", + wantID: "", + }, + { + name: "falls back to specialist", + nodes: []*NodeInfo{ + {ID: "s-code", Role: RoleSpecialist, Capabilities: []string{"code"}, Status: StatusOnline, Load: 0.2, TasksRunning: 0, MaxTasks: 2}, + }, + capability: "code", + wantID: "s-code", + }, + { + name: "all workers full", + nodes: []*NodeInfo{ + {ID: "w-f1", Role: RoleWorker, Capabilities: []string{"code"}, Status: StatusBusy, Load: 1.0, TasksRunning: 1, MaxTasks: 1}, + {ID: "w-f2", Role: RoleWorker, Capabilities: []string{"code"}, Status: StatusBusy, Load: 1.0, TasksRunning: 1, MaxTasks: 1}, + }, + capability: "code", + wantID: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + selfNode := newTestNodeInfo("select-self", RoleCoordinator, nil, 1) + cfg := newTestSwarmConfig(0) + d := NewDiscovery(nil, selfNode, cfg) + + for _, node := range tt.nodes { + d.handleNodeJoin(node) + } + + got := d.SelectWorker(tt.capability) + + if tt.wantID == "" { + if got != nil { + t.Errorf("SelectWorker() = %q, want nil", got.ID) + } + } else { + if got == nil { + t.Fatalf("SelectWorker() = nil, want %q", tt.wantID) + } + if got.ID != tt.wantID { + t.Errorf("SelectWorker().ID = %q, want %q", got.ID, tt.wantID) + } + } + }) + } +} + +func TestDiscovery_StaleNodeCleanup(t *testing.T) { + tests := []struct { + name string + lastSeen int64 // milliseconds ago + initStatus NodeStatus + wantStatus NodeStatus + wantExists bool // whether node should still exist after cleanup + }{ + { + name: "stale node marked offline", + lastSeen: 300, // 300ms ago, > timeout (200ms) but < GC threshold (2s) + initStatus: StatusOnline, + wantStatus: StatusOffline, + wantExists: true, + }, + { + name: "fresh node untouched", + lastSeen: 10, // 10ms ago, well within timeout + initStatus: StatusOnline, + wantStatus: StatusOnline, + 
wantExists: true, + }, + { + name: "already offline node unchanged", + lastSeen: 300, // 300ms ago, > timeout but < GC threshold + initStatus: StatusOffline, + wantStatus: StatusOffline, + wantExists: true, + }, + { + name: "long dead node GC'd", + lastSeen: 5000, // 5s ago, > GC threshold (2s) + initStatus: StatusOffline, + wantExists: false, // should be removed + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + selfNode := newTestNodeInfo("cleanup-self", RoleCoordinator, nil, 1) + cfg := newTestSwarmConfig(0) // NodeTimeout is 200ms + d := NewDiscovery(nil, selfNode, cfg) + + node := &NodeInfo{ + ID: "cleanup-node", + Role: RoleWorker, + Status: tt.initStatus, + MaxTasks: 4, + LastSeen: time.Now().UnixMilli() - tt.lastSeen, + } + // Directly insert into registry (bypass handleNodeJoin which sets LastSeen) + d.mu.Lock() + d.registry["cleanup-node"] = node + d.mu.Unlock() + + d.cleanupStaleNodes() + + got, ok := d.GetNode("cleanup-node") + if tt.wantExists { + if !ok { + t.Fatal("GetNode() returned false, expected node to exist") + } + if got.Status != tt.wantStatus { + t.Errorf("Status = %q, want %q", got.Status, tt.wantStatus) + } + } else { + if ok { + t.Errorf("GetNode() returned true, expected node to be GC'd") + } + } + }) + } +} + +func TestDiscovery_GetNodesFiltering(t *testing.T) { + // Set up a discovery with mixed nodes + selfNode := newTestNodeInfo("filter-self", RoleCoordinator, nil, 1) + cfg := newTestSwarmConfig(0) + d := NewDiscovery(nil, selfNode, cfg) + + nodes := []*NodeInfo{ + {ID: "f-w1", Role: RoleWorker, Capabilities: []string{"code", "research"}, Status: StatusOnline, MaxTasks: 4}, + {ID: "f-w2", Role: RoleWorker, Capabilities: []string{"code"}, Status: StatusOnline, MaxTasks: 4}, + {ID: "f-s1", Role: RoleSpecialist, Capabilities: []string{"ml"}, Status: StatusOnline, MaxTasks: 2}, + {ID: "f-off", Role: RoleWorker, Capabilities: []string{"code"}, Status: StatusOffline, MaxTasks: 4}, + } + for _, n := range nodes 
{ + d.mu.Lock() + d.registry[n.ID] = n + d.mu.Unlock() + } + + tests := []struct { + name string + role NodeRole + capability string + wantCount int + }{ + { + name: "no filter returns all online", + role: "", + capability: "", + wantCount: 3, // f-w1, f-w2, f-s1 (f-off excluded) + }, + { + name: "filter by role worker", + role: RoleWorker, + capability: "", + wantCount: 2, // f-w1, f-w2 (f-off is offline) + }, + { + name: "filter by capability code", + role: "", + capability: "code", + wantCount: 2, // f-w1, f-w2 + }, + { + name: "filter by role and capability", + role: RoleWorker, + capability: "research", + wantCount: 1, // f-w1 + }, + { + name: "offline nodes excluded", + role: RoleWorker, + capability: "code", + wantCount: 2, // f-w1, f-w2 (not f-off) + }, + { + name: "filter by specialist role", + role: RoleSpecialist, + capability: "", + wantCount: 1, // f-s1 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := d.GetNodes(tt.role, tt.capability) + if len(got) != tt.wantCount { + ids := make([]string, len(got)) + for i, n := range got { + ids[i] = n.ID + } + t.Errorf("GetNodes(%q, %q) returned %d nodes %v, want %d", tt.role, tt.capability, len(got), ids, tt.wantCount) + } + }) + } +} + +func TestDiscovery_NodeLeave(t *testing.T) { + tests := []struct { + name string + registerNode bool + leaveID string + wantStatus NodeStatus + }{ + { + name: "known node leaves", + registerNode: true, + leaveID: "leave-node", + wantStatus: StatusOffline, + }, + { + name: "unknown node leave no panic", + registerNode: false, + leaveID: "unknown-node", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + selfNode := newTestNodeInfo("leave-self", RoleCoordinator, nil, 1) + cfg := newTestSwarmConfig(0) + d := NewDiscovery(nil, selfNode, cfg) + + if tt.registerNode { + node := newTestNodeInfo("leave-node", RoleWorker, []string{"code"}, 4) + d.handleNodeJoin(node) + } + + // Should not panic + 
d.handleNodeLeave(tt.leaveID) + + if tt.registerNode { + node, ok := d.GetNode(tt.leaveID) + if !ok { + t.Fatal("GetNode() returned false for left node") + } + if node.Status != tt.wantStatus { + t.Errorf("Status = %q, want %q", node.Status, tt.wantStatus) + } + } + }) + } +} diff --git a/pkg/swarm/edge.go b/pkg/swarm/edge.go new file mode 100644 index 000000000..6d3d82492 --- /dev/null +++ b/pkg/swarm/edge.go @@ -0,0 +1,345 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// EdgeWorkerConfig contains configuration for edge-optimized workers +type EdgeWorkerConfig struct { + // Resource limits + MaxMemoryMB int64 // Maximum memory usage in MB + MaxCPUPercent int // Maximum CPU percentage (1-100) + EnableGCThreshold uint64 // GC trigger threshold in bytes + + // Network optimization + DisableHeartbeat bool // Disable periodic heartbeats + HeartbeatInterval time.Duration // Custom heartbeat interval + CompressionLevel int // Message compression (0-9) + + // Feature flags for minimal footprint + DisableWorkflow bool // Disable Temporal workflow support + DisableDashboard bool // Disable dashboard features + DisableDiscovery bool // Disable full node discovery + MinimalMode bool // Enable absolute minimal mode +} + +// DefaultEdgeWorkerConfig returns default edge worker configuration +func DefaultEdgeWorkerConfig() *EdgeWorkerConfig { + return &EdgeWorkerConfig{ + MaxMemoryMB: 50, // 50MB default + MaxCPUPercent: 50, // 50% CPU max + EnableGCThreshold: 10 * 1024 * 1024, // 10MB + DisableHeartbeat: false, + HeartbeatInterval: 30 * time.Second, // Less frequent + CompressionLevel: 2, // Light compression + DisableWorkflow: true, // Workflows disabled by default on edge + DisableDashboard: true, + DisableDiscovery: false, + MinimalMode: false, + } +} + +// EdgeWorker is a 
// EdgeWorker is a resource-optimized worker for edge devices. It embeds
// *Worker and adds memory/CPU tracking, periodic resource monitoring, and
// optional feature reduction (minimal mode).
type EdgeWorker struct {
	*Worker
	config   *EdgeWorkerConfig
	edgeStop atomic.Bool  // set once by Stop; makes Stop idempotent
	mu       sync.RWMutex // guards config/feature mutations

	// Resource tracking
	memoryUsed atomic.Int64 // estimated memory usage in bytes
	cpuPercent atomic.Int64 // estimated CPU usage in percent
	lastGCTime atomic.Value // time.Time of the last forced GC pass

	// Edge-specific optimizations
	compressionEnabled bool
	batchMode          bool
	batchSize          int
	batchTimeout       time.Duration
}

// NewEdgeWorker creates an edge-optimized worker wrapping baseWorker.
// A nil config falls back to DefaultEdgeWorkerConfig. Batch parameters are
// fixed at 5 items / 5s here; batchMode itself stays off until enabled.
func NewEdgeWorker(
	baseWorker *Worker,
	config *EdgeWorkerConfig,
) *EdgeWorker {
	if config == nil {
		config = DefaultEdgeWorkerConfig()
	}

	ew := &EdgeWorker{
		Worker:       baseWorker,
		config:       config,
		batchSize:    5,
		batchTimeout: 5 * time.Second,
	}

	ew.compressionEnabled = config.CompressionLevel > 0
	ew.lastGCTime.Store(time.Now())

	// Apply resource limits
	ew.setupResourceLimits()

	logger.InfoCF("swarm", "Edge worker created", map[string]interface{}{
		"max_memory_mb":     config.MaxMemoryMB,
		"max_cpu_percent":   config.MaxCPUPercent,
		"compression_level": config.CompressionLevel,
		"minimal_mode":      config.MinimalMode,
	})

	return ew
}

// Start starts the edge worker: launches the resource monitor goroutine,
// starts the embedded base worker, and (in minimal mode) disables
// non-essential features. Returns the base worker's start error, if any.
func (ew *EdgeWorker) Start(ctx context.Context) error {
	logger.InfoC("swarm", "Starting edge worker")

	// Set up memory monitoring
	go ew.monitorResources(ctx)

	// Start base worker
	if err := ew.Worker.Start(ctx); err != nil {
		return fmt.Errorf("failed to start base worker: %w", err)
	}

	// Disable unnecessary features in minimal mode
	if ew.config.MinimalMode {
		ew.disableNonEssentialFeatures()
	}

	logger.InfoC("swarm", "Edge worker started")
	return nil
}

// Stop gracefully stops the edge worker. Idempotent: the CompareAndSwap on
// edgeStop makes repeated calls no-ops. A final GC pass runs before the
// base worker is stopped.
func (ew *EdgeWorker) Stop() {
	if !ew.edgeStop.CompareAndSwap(false, true) {
		return // Already stopped
	}

	logger.InfoC("swarm", "Stopping edge worker")

	// Force GC before stopping to free memory
	ew.forceGC()

	ew.Worker.Stop()
	logger.InfoC("swarm", "Edge worker stopped")
}
+ +// setupResourceLimits configures resource constraints +func (ew *EdgeWorker) setupResourceLimits() { + // Set GC target based on config + if ew.config.MaxMemoryMB > 0 { + target := uint64(ew.config.MaxMemoryMB * 1024 * 1024 / 2) // Target 50% + ew.memoryUsed.Store(int64(target)) + } +} + +// monitorResources periodically checks and manages resource usage +func (ew *EdgeWorker) monitorResources(ctx context.Context) { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if ew.shouldGC() { + ew.forceGC() + } + + if ew.isOverLimit() { + ew.handleResourceOverrun() + } + } + } +} + +// shouldGC determines if garbage collection should run +func (ew *EdgeWorker) shouldGC() bool { + if ew.config.EnableGCThreshold == 0 { + return false + } + + // Get current memory usage estimate + memUsed := ew.memoryUsed.Load() + return uint64(memUsed) > ew.config.EnableGCThreshold +} + +// forceGC forces garbage collection +func (ew *EdgeWorker) forceGC() { + // In Go, we can't force GC directly, but we can hint + // and clear internal caches + ew.mu.Lock() + + // Clear any internal caches + if ew.nodeInfo != nil { + ew.nodeInfo.Metadata = make(map[string]string) + } + + ew.mu.Unlock() + + ew.lastGCTime.Store(time.Now()) + logger.DebugC("swarm", "Edge worker GC performed") +} + +// isOverLimit checks if resource limits are exceeded +func (ew *EdgeWorker) isOverLimit() bool { + if ew.config.MaxMemoryMB == 0 && ew.config.MaxCPUPercent == 0 { + return false + } + + // Check memory + if ew.config.MaxMemoryMB > 0 { + memMB := ew.memoryUsed.Load() / (1024 * 1024) + if memMB > ew.config.MaxMemoryMB { + logger.WarnCF("swarm", "Memory limit exceeded", map[string]interface{}{ + "used_mb": memMB, + "max_mb": ew.config.MaxMemoryMB, + }) + return true + } + } + + // Check CPU + if ew.config.MaxCPUPercent > 0 { + cpu := ew.cpuPercent.Load() + if cpu > int64(ew.config.MaxCPUPercent) { + logger.WarnCF("swarm", "CPU 
// isOverLimit reports whether either configured resource limit (memory or
// CPU) is exceeded. A limit of 0 means "unlimited" for that dimension; with
// both limits at 0 this always returns false. Violations are logged at warn
// level.
func (ew *EdgeWorker) isOverLimit() bool {
	if ew.config.MaxMemoryMB == 0 && ew.config.MaxCPUPercent == 0 {
		return false
	}

	// Check memory
	if ew.config.MaxMemoryMB > 0 {
		memMB := ew.memoryUsed.Load() / (1024 * 1024)
		if memMB > ew.config.MaxMemoryMB {
			logger.WarnCF("swarm", "Memory limit exceeded", map[string]interface{}{
				"used_mb": memMB,
				"max_mb":  ew.config.MaxMemoryMB,
			})
			return true
		}
	}

	// Check CPU
	if ew.config.MaxCPUPercent > 0 {
		cpu := ew.cpuPercent.Load()
		if cpu > int64(ew.config.MaxCPUPercent) {
			logger.WarnCF("swarm", "CPU limit exceeded", map[string]interface{}{
				"used_percent": cpu,
				"max_percent":  ew.config.MaxCPUPercent,
			})
			return true
		}
	}

	return false
}

// handleResourceOverrun reacts to a limit violation by shedding capacity:
// it decrements the base worker's MaxConcurrent (never below 1) and forces
// a GC pass.
// NOTE(review): cfg.MaxConcurrent is mutated under ew.mu, but the base
// Worker presumably reads it under its own synchronization — confirm this
// write is actually coordinated with the worker's readers.
func (ew *EdgeWorker) handleResourceOverrun() {
	// Reduce load by reducing max concurrent tasks
	ew.mu.Lock()
	if ew.Worker.cfg.MaxConcurrent > 1 {
		ew.Worker.cfg.MaxConcurrent-- // Reduce concurrent task limit
	}
	ew.mu.Unlock()

	// Force GC
	ew.forceGC()

	logger.InfoCF("swarm", "Resource overrun handled", map[string]interface{}{
		"new_max_tasks": ew.Worker.cfg.MaxConcurrent,
	})
}

// disableNonEssentialFeatures trims the worker for minimal mode: the
// dashboard is turned off and, when heartbeats remain enabled, their
// interval is stretched to 60s to cut network usage.
func (ew *EdgeWorker) disableNonEssentialFeatures() {
	// Disable internal statistics collection
	ew.config.DisableDashboard = true

	// Increase heartbeat interval to reduce network usage
	if !ew.config.DisableHeartbeat {
		ew.config.HeartbeatInterval = 60 * time.Second
	}

	logger.InfoC("swarm", "Non-essential features disabled for minimal mode")
}

// GetResourceUsage returns a snapshot of current resource usage: memory
// (used/max, MB), CPU (used/max, percent), last GC time (RFC3339), and
// whether batch mode is on. Values come from the internal trackers, not
// from the OS.
func (ew *EdgeWorker) GetResourceUsage() map[string]interface{} {
	ew.mu.RLock()
	defer ew.mu.RUnlock()

	memMB := ew.memoryUsed.Load() / (1024 * 1024)
	cpu := ew.cpuPercent.Load()

	return map[string]interface{}{
		"memory_mb":       memMB,
		"memory_max_mb":   ew.config.MaxMemoryMB,
		"cpu_percent":     cpu,
		"cpu_max_percent": ew.config.MaxCPUPercent,
		"last_gc":         ew.lastGCTime.Load().(time.Time).Format(time.RFC3339),
		"batch_mode":      ew.batchMode,
	}
}

// IsHealthy reports whether the edge worker is healthy: tracked memory is
// below 90% of the configured cap (when a cap is set) and the base worker
// is still running.
func (ew *EdgeWorker) IsHealthy() bool {
	// Check if we're not critically over resource limits
	if ew.config.MaxMemoryMB > 0 {
		memMB := ew.memoryUsed.Load() / (1024 * 1024)
		if memMB > ew.config.MaxMemoryMB*9/10 { // 90% threshold
			return false
		}
	}

	// Check if base worker is running
	return ew.Worker.running.Load()
}
edge worker configuration +func (ew *EdgeWorker) GetEdgeConfig() *EdgeWorkerConfig { + return ew.config +} + +// EdgeBuildInfo provides build information for edge deployment +type EdgeBuildInfo struct { + GOOS string + GOARCH string + Version string + Commit string + BuiltAt string +} + +// GetEdgeBuildInfo returns build information for the current binary +func GetEdgeBuildInfo() *EdgeBuildInfo { + return &EdgeBuildInfo{ + GOOS: "linux", // Default target + GOARCH: "arm64", // Default target + Version: "1.0.0", + Commit: "unknown", + BuiltAt: time.Now().Format(time.RFC3339), + } +} + +// SupportedPlatforms returns platforms supported for edge deployment +func SupportedPlatforms() []string { + return []string{ + "linux/arm64", // Raspberry Pi 4+, ARM servers + "linux/arm", // Raspberry Pi Zero, older ARM + "linux/amd64", // Intel/AMD x86_64 + "linux/386", // Intel/AMD 32-bit + "freebsd/arm64", // ARM FreeBSD + "freebsd/amd64", // AMD64 FreeBSD + } +} + +// BuildCommands returns cross-compile commands for edge platforms +func BuildCommands(appName string) map[string]string { + return map[string]string{ + "linux/arm64": fmt.Sprintf("GOOS=linux GOARCH=arm64 go build -o %s-linux-arm64 %s", appName, appName), + "linux/arm": fmt.Sprintf("GOOS=linux GOARCH=arm go build -o %s-linux-arm %s", appName, appName), + "linux/amd64": fmt.Sprintf("GOOS=linux GOARCH=amd64 go build -o %s-linux-amd64 %s", appName, appName), + "linux/386": fmt.Sprintf("GOOS=linux GOARCH=386 go build -o %s-linux-386 %s", appName, appName), + "freebsd/arm64": fmt.Sprintf("GOOS=freebsd GOARCH=arm64 go build -o %s-freebsd-arm64 %s", appName, appName), + "freebsd/amd64": fmt.Sprintf("GOOS=freebsd GOARCH=amd64 go build -o %s-freebsd-amd64 %s", appName, appName), + } +} + +// OptimizeForEdge returns build flags optimized for edge deployment +func OptimizeForEdge() string { + return "-ldflags='-s -w' -trimpath" // Strip debug info, reduce binary size +} diff --git a/pkg/swarm/edge_test.go 
b/pkg/swarm/edge_test.go new file mode 100644 index 000000000..7fefd5ad5 --- /dev/null +++ b/pkg/swarm/edge_test.go @@ -0,0 +1,276 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultEdgeWorkerConfig(t *testing.T) { + cfg := DefaultEdgeWorkerConfig() + + assert.Equal(t, int64(50), cfg.MaxMemoryMB) + assert.Equal(t, 50, cfg.MaxCPUPercent) + assert.Equal(t, uint64(10*1024*1024), cfg.EnableGCThreshold) + assert.False(t, cfg.DisableHeartbeat) + assert.Equal(t, 30*time.Second, cfg.HeartbeatInterval) + assert.Equal(t, 2, cfg.CompressionLevel) + assert.True(t, cfg.DisableWorkflow) + assert.True(t, cfg.DisableDashboard) + assert.False(t, cfg.MinimalMode) +} + +func TestEdgeWorker_GetResourceUsage(t *testing.T) { + cfg := &EdgeWorkerConfig{ + MaxMemoryMB: 100, + MaxCPUPercent: 75, + EnableGCThreshold: 20 * 1024 * 1024, + } + + ew := &EdgeWorker{ + config: cfg, + } + ew.memoryUsed.Store(50 * 1024 * 1024) // 50MB + ew.cpuPercent.Store(25) + ew.lastGCTime.Store(time.Now()) + + usage := ew.GetResourceUsage() + + assert.Equal(t, int64(50), usage["memory_mb"]) + assert.Equal(t, int64(100), usage["memory_max_mb"]) + assert.Equal(t, int64(25), usage["cpu_percent"]) + assert.Equal(t, 75, usage["cpu_max_percent"]) +} + +func TestEdgeWorker_IsHealthy(t *testing.T) { + cfg := &EdgeWorkerConfig{ + MaxMemoryMB: 100, + } + + ew := &EdgeWorker{ + config: cfg, + Worker: &Worker{ + running: atomic.Bool{}, + }, + } + ew.Worker.running.Store(true) + + // Healthy when under limit + ew.memoryUsed.Store(50 * 1024 * 1024) // 50MB < 100MB + assert.True(t, ew.IsHealthy()) + + // Unhealthy when over 90% limit + ew.memoryUsed.Store(95 * 1024 * 1024) // 95MB > 90MB threshold + assert.False(t, ew.IsHealthy()) +} + +func TestEdgeWorker_shouldGC(t *testing.T) { + tests := []struct { + name string + threshold 
uint64 + memoryUsed int64 + expectedResult bool + }{ + { + name: "under threshold", + threshold: 10 * 1024 * 1024, + memoryUsed: 5 * 1024 * 1024, + expectedResult: false, + }, + { + name: "over threshold", + threshold: 10 * 1024 * 1024, + memoryUsed: 15 * 1024 * 1024, + expectedResult: true, + }, + { + name: "no threshold set", + threshold: 0, + memoryUsed: 100 * 1024 * 1024, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ew := &EdgeWorker{ + config: &EdgeWorkerConfig{ + EnableGCThreshold: tt.threshold, + }, + } + ew.memoryUsed.Store(tt.memoryUsed) + + result := ew.shouldGC() + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestEdgeWorker_isOverLimit(t *testing.T) { + tests := []struct { + name string + maxMemoryMB int64 + memoryUsed int64 + maxCPUPercent int + cpuPercent int64 + expectedResult bool + }{ + { + name: "under all limits", + maxMemoryMB: 100, + memoryUsed: 50 * 1024 * 1024, + maxCPUPercent: 80, + cpuPercent: 40, + expectedResult: false, + }, + { + name: "over memory limit", + maxMemoryMB: 100, + memoryUsed: 150 * 1024 * 1024, + maxCPUPercent: 80, + cpuPercent: 40, + expectedResult: true, + }, + { + name: "over CPU limit", + maxMemoryMB: 100, + memoryUsed: 50 * 1024 * 1024, + maxCPUPercent: 80, + cpuPercent: 90, + expectedResult: true, + }, + { + name: "no limits set", + maxMemoryMB: 0, + memoryUsed: 1000 * 1024 * 1024, + maxCPUPercent: 0, + cpuPercent: 100, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ew := &EdgeWorker{ + config: &EdgeWorkerConfig{ + MaxMemoryMB: tt.maxMemoryMB, + MaxCPUPercent: tt.maxCPUPercent, + }, + } + ew.memoryUsed.Store(tt.memoryUsed) + ew.cpuPercent.Store(tt.cpuPercent) + + result := ew.isOverLimit() + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestEdgeWorker_GetEdgeConfig(t *testing.T) { + cfg := &EdgeWorkerConfig{ + MaxMemoryMB: 200, + MinimalMode: true, + } + + ew := 
&EdgeWorker{ + config: cfg, + } + + assert.Equal(t, cfg, ew.GetEdgeConfig()) + assert.Same(t, cfg, ew.GetEdgeConfig()) +} + +func TestGetEdgeBuildInfo(t *testing.T) { + info := GetEdgeBuildInfo() + + assert.Equal(t, "linux", info.GOOS) + assert.Equal(t, "arm64", info.GOARCH) + assert.Equal(t, "1.0.0", info.Version) + assert.NotEmpty(t, info.BuiltAt) +} + +func TestSupportedPlatforms(t *testing.T) { + platforms := SupportedPlatforms() + + assert.Contains(t, platforms, "linux/arm64") + assert.Contains(t, platforms, "linux/arm") + assert.Contains(t, platforms, "linux/amd64") + assert.Contains(t, platforms, "linux/386") + assert.Contains(t, platforms, "freebsd/arm64") + assert.Contains(t, platforms, "freebsd/amd64") +} + +func TestBuildCommands(t *testing.T) { + commands := BuildCommands("picoclaw") + + // Check that key platforms have build commands + assert.Contains(t, commands, "linux/arm64") + assert.Contains(t, commands, "linux/amd64") + assert.Contains(t, commands, "freebsd/arm64") + + // Verify command format + arm64Cmd := commands["linux/arm64"] + assert.Contains(t, arm64Cmd, "GOOS=linux") + assert.Contains(t, arm64Cmd, "GOARCH=arm64") + assert.Contains(t, arm64Cmd, "go build") + assert.Contains(t, arm64Cmd, "picoclaw-linux-arm64") +} + +func TestOptimizeForEdge(t *testing.T) { + flags := OptimizeForEdge() + + assert.Contains(t, flags, "-ldflags") + assert.Contains(t, flags, "-s -w") + assert.Contains(t, flags, "-trimpath") +} + +func TestEdgeWorker_StopIdempotent(t *testing.T) { + // Create a minimal Worker that won't panic on Stop + worker := &Worker{ + taskQueue: make(chan *SwarmTask), + } + worker.running.Store(true) + + ew := &EdgeWorker{ + Worker: worker, + config: DefaultEdgeWorkerConfig(), + } + + // First stop should work + ew.Stop() + + // Second stop should be no-op + ew.Stop() + + assert.True(t, ew.edgeStop.Load()) +} + +func TestDefaultEdgeWorkerConfig_CompressionLevels(t *testing.T) { + cfg := DefaultEdgeWorkerConfig() + + // Verify compression is 
enabled but not maximum + assert.Greater(t, cfg.CompressionLevel, 0) + assert.Less(t, cfg.CompressionLevel, 9) +} + +func TestEdgeWorkerConfig_DefaultsAreReasonable(t *testing.T) { + cfg := DefaultEdgeWorkerConfig() + + // Memory limit should be reasonable for edge devices (10-100MB) + assert.GreaterOrEqual(t, cfg.MaxMemoryMB, int64(10)) + assert.LessOrEqual(t, cfg.MaxMemoryMB, int64(100)) + + // CPU limit should be reasonable (10-100%) + assert.GreaterOrEqual(t, cfg.MaxCPUPercent, 10) + assert.LessOrEqual(t, cfg.MaxCPUPercent, 100) + + // Heartbeat should be less frequent than default + assert.GreaterOrEqual(t, cfg.HeartbeatInterval, 30*time.Second) +} diff --git a/pkg/swarm/election.go b/pkg/swarm/election.go new file mode 100644 index 000000000..2e5612fb8 --- /dev/null +++ b/pkg/swarm/election.go @@ -0,0 +1,694 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/nats-io/nats.go" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// ElectionManager manages leader election using NATS JetStream KV store +type ElectionManager struct { + // NATS connection + nc *nats.Conn + js nats.JetStreamContext + + // Identity + nodeID string + hid string + sid string + + // Election configuration + electionSubject string // Subject for election coordination + leaseDuration time.Duration + leaderKey string // KV store key for leader lease + + // State + mu sync.RWMutex + isLeader bool + isParticipant bool + currentLeaderID string + leaseExpiry int64 + lastRevision uint64 // Last known revision for optimistic updates + + // Channels + leaderChan chan bool // true when becomes leader, false when loses leadership + stopChan chan struct{} + electionTimer *time.Timer + + // NATS subscription + leaseSub *nats.Subscription + + // Callbacks + onBecameLeader func() + onLostLeadership func() + onNewLeader func(leaderID string) +} + +// 
ElectionConfig holds configuration for leader election +type ElectionConfig struct { + // ElectionSubject is the NATS subject for election messages + ElectionSubject string + + // LeaseDuration is how long a leader lease is valid + LeaseDuration time.Duration + + // ElectionInterval is how often to check/renew leadership + ElectionInterval time.Duration + + // PreVoteDelay is delay before attempting election (for staggered starts) + PreVoteDelay time.Duration +} + +// DefaultElectionConfig returns default election configuration +func DefaultElectionConfig() *ElectionConfig { + return &ElectionConfig{ + ElectionSubject: "picoclaw.election", + LeaseDuration: 10 * time.Second, + ElectionInterval: 3 * time.Second, + PreVoteDelay: time.Duration(0), + } +} + +// NewElectionManager creates a new election manager +func NewElectionManager(nc *nats.Conn, js nats.JetStreamContext, nodeID, hid, sid string) *ElectionManager { + return &ElectionManager{ + nc: nc, + js: js, + nodeID: nodeID, + hid: hid, + sid: sid, + leaderKey: fmt.Sprintf("picoclaw.election.%s.leader", hid), + leaderChan: make(chan bool, 1), + stopChan: make(chan struct{}), + } +} + +// Start begins participating in leader election +func (em *ElectionManager) Start(ctx context.Context, cfg *ElectionConfig) error { + em.mu.Lock() + defer em.mu.Unlock() + + if em.isParticipant { + return nil + } + + if cfg == nil { + cfg = DefaultElectionConfig() + } + + em.electionSubject = cfg.ElectionSubject + em.leaseDuration = cfg.LeaseDuration + + // Create KV store for leader lease (or get if exists) + bucketName := fmt.Sprintf("PICOCLAW_ELECTION_%s", em.hid) + _, err := em.js.CreateKeyValue(&nats.KeyValueConfig{ + Bucket: bucketName, + TTL: cfg.LeaseDuration * 2, + }) + if err != nil { + // Check if it already exists - try to bind to it + _, bindErr := em.js.KeyValue(bucketName) + if bindErr != nil { + return fmt.Errorf("failed to create or bind to election KV store: %w", err) + } + } + + // Store bucket name for later use 
+ em.leaderKey = fmt.Sprintf("leader.%s", em.hid) + + // Subscribe to leadership change notifications + subject := fmt.Sprintf("$KV.%s.>", bucketName) + sub, err := em.nc.Subscribe(subject, func(msg *nats.Msg) { + em.handleLeaseUpdate(msg) + }) + if err != nil { + return fmt.Errorf("failed to subscribe to lease updates: %w", err) + } + em.leaseSub = sub + + em.isParticipant = true + + // Start election goroutine + go em.electionLoop(ctx, cfg) + + logger.InfoCF("swarm", "Election manager started", map[string]interface{}{ + "node_id": em.nodeID, + "hid": em.hid, + "lease_ttl": cfg.LeaseDuration.String(), + }) + + return nil +} + +// Stop stops participating in leader election +func (em *ElectionManager) Stop() { + em.mu.Lock() + if !em.isParticipant { + em.mu.Unlock() + return + } + + wasLeader := em.isLeader + em.isParticipant = false + em.mu.Unlock() + + // Close NATS subscription first to prevent new callbacks + if em.leaseSub != nil { + em.leaseSub.Unsubscribe() + em.leaseSub = nil + } + + // Stop the election loop + close(em.stopChan) + + if em.electionTimer != nil { + em.electionTimer.Stop() + } + + // Step down from leadership without holding the lock + if wasLeader { + em.stepDown() + } + + logger.InfoCF("swarm", "Election manager stopped", map[string]interface{}{ + "node_id": em.nodeID, + }) +} + +// electionLoop runs the election logic +func (em *ElectionManager) electionLoop(ctx context.Context, cfg *ElectionConfig) { + // Initial delay for staggered starts across nodes + if cfg.PreVoteDelay > 0 { + time.Sleep(cfg.PreVoteDelay) + } + + ticker := time.NewTicker(cfg.ElectionInterval) + defer ticker.Stop() + + // Try to become leader immediately + em.attemptLeadership() + + for { + select { + case <-ctx.Done(): + return + case <-em.stopChan: + return + case isLeader := <-em.leaderChan: + em.mu.Lock() + em.isLeader = isLeader + em.mu.Unlock() + + if isLeader { + logger.InfoCF("swarm", "Became leader", map[string]interface{}{ + "node_id": em.nodeID, + }) + if 
em.onBecameLeader != nil { + em.onBecameLeader() + } + } else { + logger.WarnCF("swarm", "Lost leadership", map[string]interface{}{ + "node_id": em.nodeID, + }) + if em.onLostLeadership != nil { + em.onLostLeadership() + } + // Try to regain leadership + em.attemptLeadership() + } + + case <-ticker.C: + // Renew lease if we're leader + em.mu.RLock() + isLeader := em.isLeader + em.mu.RUnlock() + + if isLeader { + em.renewLease() + } else { + // Periodically attempt to acquire leadership + em.attemptLeadership() + } + } + } +} + +// attemptLeadership tries to acquire leadership lease +func (em *ElectionManager) attemptLeadership() { + kv, err := em.js.KeyValue(fmt.Sprintf("PICOCLAW_ELECTION_%s", em.hid)) + if err != nil { + logger.DebugCF("swarm", "Failed to get election KV", map[string]interface{}{ + "error": err.Error(), + }) + return + } + + // Try to get current entry first + entry, err := kv.Get(em.leaderKey) + + now := time.Now().UnixMilli() + leaseExpiry := now + em.leaseDuration.Milliseconds() + infoBytes := []byte(fmt.Sprintf("%s|%s|%d", em.nodeID, em.sid, leaseExpiry)) + + if err != nil { + // No entry exists, try to create + revision, err := kv.Create(em.leaderKey, infoBytes) + if err != nil { + // Someone else created it first, or other error + return + } + + // Verify we're actually the leader by reading back + entry, err := kv.Get(em.leaderKey) + if err != nil { + return + } + + entryLeader, _, entryExpiry, ok := parseLeaderInfo(string(entry.Value())) + if !ok || entryLeader != em.nodeID { + // Not our entry, someone else won the race + em.mu.Lock() + em.currentLeaderID = entryLeader + em.leaseExpiry = entryExpiry + em.mu.Unlock() + return + } + + // Successfully became leader + em.mu.Lock() + em.isLeader = true + em.currentLeaderID = em.nodeID + em.leaseExpiry = entryExpiry + em.lastRevision = revision + em.mu.Unlock() + + select { + case em.leaderChan <- true: + default: + } + return + } + + // Entry exists, parse it + currentLeader, _, expiry, ok := 
parseLeaderInfo(string(entry.Value())) + if !ok { + // Corrupted entry, try to take over + em.tryUpdateLeader(kv, infoBytes, entry.Revision()) + return + } + + // Update our view of current leader + em.mu.Lock() + em.currentLeaderID = currentLeader + em.leaseExpiry = expiry + wasLeader := em.isLeader + em.mu.Unlock() + + // Check if we are already the leader + if currentLeader == em.nodeID { + // Try to renew our lease if it's expiring soon + if expiry < now+em.leaseDuration.Milliseconds()/2 { + em.renewLease() + } + return + } + + // Someone else is the leader + if wasLeader { + // We lost leadership + em.mu.Lock() + em.isLeader = false + em.mu.Unlock() + + select { + case em.leaderChan <- false: + default: + } + } + + // Check if current lease is expired + if expiry < now { + // Lease expired, try to take over + em.tryUpdateLeader(kv, infoBytes, entry.Revision()) + } +} + +// tryUpdateLeader attempts to update the leader entry +func (em *ElectionManager) tryUpdateLeader(kv nats.KeyValue, infoBytes []byte, lastRevision uint64) { + revision, err := kv.Update(em.leaderKey, infoBytes, lastRevision) + if err != nil { + // Failed to update (race condition) + return + } + + // Verify by reading back + entry, err := kv.Get(em.leaderKey) + if err != nil { + return + } + + entryLeader, _, entryExpiry, ok := parseLeaderInfo(string(entry.Value())) + if !ok || entryLeader != em.nodeID { + // Not our entry after update + em.mu.Lock() + em.currentLeaderID = entryLeader + em.leaseExpiry = entryExpiry + em.mu.Unlock() + return + } + + // Successfully became leader + em.mu.Lock() + wasNotLeader := !em.isLeader + em.isLeader = true + em.currentLeaderID = em.nodeID + em.leaseExpiry = entryExpiry + em.lastRevision = revision + em.mu.Unlock() + + if em.onNewLeader != nil { + em.onNewLeader(em.nodeID) + } + + if wasNotLeader { + select { + case em.leaderChan <- true: + default: + } + } +} + +// renewLease renews the leadership lease +func (em *ElectionManager) renewLease() { + kv, err 
:= em.js.KeyValue(fmt.Sprintf("PICOCLAW_ELECTION_%s", em.hid)) + if err != nil { + return + } + + em.mu.RLock() + lastRevision := em.lastRevision + em.mu.RUnlock() + + now := time.Now().UnixMilli() + leaseExpiry := now + em.leaseDuration.Milliseconds() + + infoBytes := []byte(fmt.Sprintf("%s|%s|%d", em.nodeID, em.sid, leaseExpiry)) + + // Use the last revision for optimistic update + revision, err := kv.Update(em.leaderKey, infoBytes, lastRevision) + if err != nil { + // Lost leadership - revision mismatch or other error + em.mu.Lock() + em.isLeader = false + em.lastRevision = 0 + em.mu.Unlock() + + select { + case em.leaderChan <- false: + default: + } + + logger.WarnCF("swarm", "Failed to renew leadership lease", map[string]interface{}{ + "node_id": em.nodeID, + "error": err.Error(), + }) + return + } + + em.mu.Lock() + em.leaseExpiry = leaseExpiry + em.lastRevision = revision + em.mu.Unlock() + + logger.DebugCF("swarm", "Renewed leadership lease", map[string]interface{}{ + "node_id": em.nodeID, + "expires": time.UnixMilli(leaseExpiry).Format(time.RFC3339), + }) +} + +// stepDown voluntarily gives up leadership +func (em *ElectionManager) stepDown() { + kv, err := em.js.KeyValue(fmt.Sprintf("PICOCLAW_ELECTION_%s", em.hid)) + if err != nil { + return + } + + kv.Delete(em.leaderKey) + + em.mu.Lock() + em.isLeader = false + em.currentLeaderID = "" + em.lastRevision = 0 + em.mu.Unlock() + + logger.InfoCF("swarm", "Stepped down from leadership", map[string]interface{}{ + "node_id": em.nodeID, + }) +} + +// handleLeaseUpdate handles KV updates for leadership changes +func (em *ElectionManager) handleLeaseUpdate(msg *nats.Msg) { + em.mu.RLock() + isParticipant := em.isParticipant + wasLeader := em.isLeader + em.mu.RUnlock() + + if !isParticipant { + return + } + + // Check if this is a deletion or update + if len(msg.Data) == 0 || msg.Header.Get("Operation") == "DEL" { + em.mu.Lock() + em.currentLeaderID = "" + em.leaseExpiry = 0 + if !wasLeader { + em.mu.Unlock() + // 
Leader stepped down, try to acquire + go em.attemptLeadership() + return + } + em.mu.Unlock() + return + } + + // Parse the new leader info + newLeader, _, expiry, ok := parseLeaderInfo(string(msg.Data)) + if !ok { + return + } + + // Update our view of current leader + em.mu.Lock() + oldLeaderID := em.currentLeaderID + em.currentLeaderID = newLeader + em.leaseExpiry = expiry + + // Check if we were leader but are no longer + if wasLeader && newLeader != em.nodeID { + em.isLeader = false + em.mu.Unlock() + + select { + case em.leaderChan <- false: + default: + } + return + } + + // Check if we thought we were leader but someone else is + if em.isLeader && newLeader != em.nodeID { + em.isLeader = false + em.mu.Unlock() + + select { + case em.leaderChan <- false: + default: + } + return + } + + em.mu.Unlock() + + // Notify about new leader if it changed + if em.onNewLeader != nil && newLeader != em.nodeID && oldLeaderID != newLeader { + em.onNewLeader(newLeader) + } +} + +// IsLeader returns true if this node is currently the leader +func (em *ElectionManager) IsLeader() bool { + em.mu.RLock() + defer em.mu.RUnlock() + + // If we are leader, return our ID + if em.isLeader && time.Now().UnixMilli() <= em.leaseExpiry { + return true + } + + return false +} + +// GetLeaderID returns the current leader's node ID +func (em *ElectionManager) GetLeaderID() string { + em.mu.RLock() + defer em.mu.RUnlock() + + // If we are leader, return our ID + if em.isLeader && time.Now().UnixMilli() <= em.leaseExpiry { + return em.nodeID + } + + // Return the known leader ID if lease is valid + if em.leaseExpiry > 0 && time.Now().UnixMilli() <= em.leaseExpiry { + return em.currentLeaderID + } + + return "" +} + +// OnBecameLeader sets callback for when this node becomes leader +func (em *ElectionManager) OnBecameLeader(fn func()) { + em.onBecameLeader = fn +} + +// OnLostLeadership sets callback for when this node loses leadership +func (em *ElectionManager) OnLostLeadership(fn func()) { 
+ em.onLostLeadership = fn +} + +// OnNewLeader sets callback for when any node becomes leader +func (em *ElectionManager) OnNewLeader(fn func(leaderID string)) { + em.onNewLeader = fn +} + +// parseLeaderInfo parses the leader info string +func parseLeaderInfo(s string) (nodeID, sid string, expiry int64, ok bool) { + parts := splitN(s, 3, "|") + if len(parts) != 3 { + return "", "", 0, false + } + + var exp int64 + _, err := fmt.Sscanf(parts[2], "%d", &exp) + if err != nil { + return "", "", 0, false + } + + return parts[0], parts[1], exp, true +} + +// splitN splits a string into at most n parts +func splitN(s string, n int, sep string) []string { + if n <= 0 { + return nil + } + if n == 1 { + return []string{s} + } + + parts := make([]string, 0, n) + start := 0 + sepLen := len(sep) + + for i := 0; i < n-1; i++ { + idx := indexOf(s, sep, start) + if idx == -1 { + // Fewer than n parts, return what we have + parts = append(parts, s[start:]) + return parts + } + parts = append(parts, s[start:idx]) + start = idx + sepLen + } + + // Add the rest + parts = append(parts, s[start:]) + return parts +} + +// indexOf finds the index of sep in s starting from start +func indexOf(s, sep string, start int) int { + if start >= len(s) { + return -1 + } + for i := start; i <= len(s)-len(sep); i++ { + if s[i:i+len(sep)] == sep { + return i + } + } + return -1 +} + +// RoleSwitcher handles dynamic role switching based on election results +type RoleSwitcher struct { + electionMgr *ElectionManager + nodeInfo *NodeInfo + manager *Manager + originalRole NodeRole + mu sync.RWMutex +} + +// NewRoleSwitcher creates a new role switcher +func NewRoleSwitcher(em *ElectionManager, nodeInfo *NodeInfo, manager *Manager) *RoleSwitcher { + return &RoleSwitcher{ + electionMgr: em, + nodeInfo: nodeInfo, + manager: manager, + originalRole: nodeInfo.Role, + } +} + +// GetCurrentRole returns the current role +func (rs *RoleSwitcher) GetCurrentRole() NodeRole { + rs.mu.RLock() + defer rs.mu.RUnlock() 
+ return rs.nodeInfo.Role +} + +// Start begins monitoring election results +func (rs *RoleSwitcher) Start() { + rs.electionMgr.OnBecameLeader(rs.onBecameLeader) + rs.electionMgr.OnLostLeadership(rs.onLostLeadership) +} + +func (rs *RoleSwitcher) onBecameLeader() { + rs.mu.Lock() + defer rs.mu.Unlock() + + // Promote to coordinator if not already + if rs.nodeInfo.Role != RoleCoordinator { + logger.InfoCF("swarm", "Promoting to coordinator", map[string]interface{}{ + "node_id": rs.nodeInfo.ID, + "from": string(rs.nodeInfo.Role), + }) + rs.nodeInfo.Metadata = map[string]string{ + "original_role": string(rs.nodeInfo.Role), + } + rs.nodeInfo.Role = RoleCoordinator + } +} + +func (rs *RoleSwitcher) onLostLeadership() { + rs.mu.Lock() + defer rs.mu.Unlock() + + // Demote back to original role + originalRole := rs.nodeInfo.Metadata["original_role"] + if originalRole != "" && originalRole != string(rs.nodeInfo.Role) { + logger.InfoCF("swarm", "Demoting from coordinator", map[string]interface{}{ + "node_id": rs.nodeInfo.ID, + "to": originalRole, + }) + rs.nodeInfo.Role = NodeRole(originalRole) + } +} diff --git a/pkg/swarm/election_test.go b/pkg/swarm/election_test.go new file mode 100644 index 000000000..3371d1f1f --- /dev/null +++ b/pkg/swarm/election_test.go @@ -0,0 +1,283 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestElectionManager_SingleCandidate(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + em := NewElectionManager(tn.NC(), tn.JS(), "node-1", "test-hid", "test-sid") + + becameLeader := make(chan bool, 1) + em.OnBecameLeader(func() { + select { + case becameLeader <- true: + default: + } + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + cfg := &ElectionConfig{ + 
ElectionSubject: "picoclaw.test.election", + LeaseDuration: 2 * time.Second, + ElectionInterval: 500 * time.Millisecond, + } + + err := em.Start(ctx, cfg) + require.NoError(t, err) + + // Should become leader immediately + select { + case <-becameLeader: + assert.True(t, em.IsLeader()) + assert.Equal(t, "node-1", em.GetLeaderID()) + case <-time.After(1 * time.Second): + t.Fatal("Did not become leader in time") + } + + em.Stop() + }) +} + +func TestElectionManager_MultipleCandidates(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cfg := &ElectionConfig{ + ElectionSubject: "picoclaw.test.election.multi", + LeaseDuration: 3 * time.Second, + ElectionInterval: 1 * time.Second, + } + + em1 := NewElectionManager(tn.NC(), tn.JS(), "node-1", "test-hid-multi", "sid-1") + em2 := NewElectionManager(tn.NC(), tn.JS(), "node-2", "test-hid-multi", "sid-2") + em3 := NewElectionManager(tn.NC(), tn.JS(), "node-3", "test-hid-multi", "sid-3") + + leaderChanges := make(chan string, 10) + + for _, em := range []*ElectionManager{em1, em2, em3} { + em := em + em.OnBecameLeader(func() { + leaderChanges <- em.nodeID + }) + em.OnNewLeader(func(leaderID string) { + leaderChanges <- "new:" + leaderID + }) + } + + // Start all with staggered delays + require.NoError(t, em1.Start(ctx, cfg)) + + time.Sleep(100 * time.Millisecond) + require.NoError(t, em2.Start(ctx, cfg)) + + time.Sleep(100 * time.Millisecond) + require.NoError(t, em3.Start(ctx, cfg)) + + // First node should be leader + assert.Eventually(t, func() bool { + return em1.IsLeader() + }, 2*time.Second, 100*time.Millisecond, "node-1 should become leader") + + // Wait for all nodes to discover the leader + assert.Eventually(t, func() bool { + return em1.GetLeaderID() == "node-1" && + em2.GetLeaderID() == "node-1" && + em3.GetLeaderID() == "node-1" + }, 2*time.Second, 100*time.Millisecond, "all nodes should discover node-1 as leader") 
+ + // Stop leader, second should take over + em1.Stop() + + assert.Eventually(t, func() bool { + return !em1.IsLeader() && (em2.IsLeader() || em3.IsLeader()) + }, 5*time.Second, 200*time.Millisecond, "new leader should be elected") + + // Wait for all nodes to agree on the leader + assert.Eventually(t, func() bool { + l2 := em2.GetLeaderID() + l3 := em3.GetLeaderID() + return l2 != "" && l2 == l3 + }, 2*time.Second, 100*time.Millisecond, "all nodes should agree on leader") + + // Only one should be leader + leaders := 0 + if em2.IsLeader() { + leaders++ + } + if em3.IsLeader() { + leaders++ + } + assert.Equal(t, 1, leaders, "Only one node should be leader") + + // All nodes should agree on the leader ID + assert.Equal(t, em2.GetLeaderID(), em3.GetLeaderID()) + + em2.Stop() + em3.Stop() + }) +} + +func TestElectionManager_LeaseRenewal(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + em := NewElectionManager(tn.NC(), tn.JS(), "node-1", "test-hid-renew", "test-sid") + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cfg := &ElectionConfig{ + ElectionSubject: "picoclaw.test.election.renew", + LeaseDuration: 1 * time.Second, + ElectionInterval: 300 * time.Millisecond, + } + + require.NoError(t, em.Start(ctx, cfg)) + + // Become leader + assert.Eventually(t, func() bool { + return em.IsLeader() + }, 2*time.Second, 100*time.Millisecond) + + // Stay leader for multiple lease periods + for i := 0; i < 5; i++ { + time.Sleep(500 * time.Millisecond) + assert.True(t, em.IsLeader(), "Should remain leader during lease period %d", i) + } + + em.Stop() + }) +} + +func TestElectionManager_StepDown(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + em1 := NewElectionManager(tn.NC(), tn.JS(), "node-1", "test-hid-stepdown", "sid-1") + em2 := NewElectionManager(tn.NC(), tn.JS(), "node-2", "test-hid-stepdown", "sid-2") + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cfg 
:= &ElectionConfig{ + ElectionSubject: "picoclaw.test.election.stepdown", + LeaseDuration: 2 * time.Second, + ElectionInterval: 500 * time.Millisecond, + } + + require.NoError(t, em1.Start(ctx, cfg)) + require.NoError(t, em2.Start(ctx, cfg)) + + // em1 should become leader + assert.Eventually(t, func() bool { + return em1.IsLeader() && !em2.IsLeader() + }, 2*time.Second, 100*time.Millisecond) + + // em1 steps down + em1.Stop() + + // em2 should take over + assert.Eventually(t, func() bool { + return !em1.IsLeader() && em2.IsLeader() + }, 3*time.Second, 100*time.Millisecond) + + em2.Stop() + }) +} + +func TestRoleSwitcher_PromotionToCoordinator(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cfg := &ElectionConfig{ + ElectionSubject: "picoclaw.test.role.switch", + LeaseDuration: 2 * time.Second, + ElectionInterval: 500 * time.Millisecond, + } + + em := NewElectionManager(tn.NC(), tn.JS(), "node-1", "test-hid-switch", "sid-1") + nodeInfo := CreateTestNodeInfo("node-1", string(RoleWorker), []string{"test"}) + + // Create a minimal manager wrapper + manager := &Manager{ + nodeInfo: nodeInfo, + } + + rs := NewRoleSwitcher(em, nodeInfo, manager) + rs.Start() // Start the role switcher to register callbacks + + require.NoError(t, em.Start(ctx, cfg)) + + // Should become leader and promote to coordinator + assert.Eventually(t, func() bool { + return rs.GetCurrentRole() == RoleCoordinator + }, 2*time.Second, 100*time.Millisecond) + + assert.Equal(t, RoleCoordinator, nodeInfo.Role) + assert.Equal(t, string(RoleWorker), nodeInfo.Metadata["original_role"]) + + em.Stop() + }) +} + +func TestParseLeaderInfo(t *testing.T) { + tests := []struct { + name string + input string + nodeID string + sid string + expiry int64 + ok bool + }{ + { + name: "valid input", + input: "node-1|sid-1|1234567890", + nodeID: "node-1", + sid: "sid-1", + expiry: 1234567890, + ok: true, + }, + { + 
name: "invalid format", + input: "invalid", + ok: false, + }, + { + name: "missing parts", + input: "node-1|sid-1", + ok: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nodeID, sid, expiry, ok := parseLeaderInfo(tt.input) + assert.Equal(t, tt.ok, ok) + if ok { + assert.Equal(t, tt.nodeID, nodeID) + assert.Equal(t, tt.sid, sid) + assert.Equal(t, tt.expiry, expiry) + } + }) + } +} + +func TestDefaultElectionConfig(t *testing.T) { + cfg := DefaultElectionConfig() + + assert.Equal(t, "picoclaw.election", cfg.ElectionSubject) + assert.Equal(t, 10*time.Second, cfg.LeaseDuration) + assert.Equal(t, 3*time.Second, cfg.ElectionInterval) + assert.Equal(t, time.Duration(0), cfg.PreVoteDelay) +} diff --git a/pkg/swarm/embedded.go b/pkg/swarm/embedded.go new file mode 100644 index 000000000..65b3741ac --- /dev/null +++ b/pkg/swarm/embedded.go @@ -0,0 +1,105 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "fmt" + "time" + + "github.com/nats-io/nats-server/v2/server" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// EmbeddedNATS wraps an embedded NATS server for development mode +type EmbeddedNATS struct { + server *server.Server + cfg *config.NATSConfig +} + +// NewEmbeddedNATS creates a new embedded NATS server +func NewEmbeddedNATS(cfg *config.NATSConfig) *EmbeddedNATS { + return &EmbeddedNATS{cfg: cfg} +} + +// Start starts the embedded NATS server +func (e *EmbeddedNATS) Start() error { + port := e.cfg.EmbeddedPort + if port == 0 { + port = 4222 + } + + host := e.cfg.EmbeddedHost + if host == "" { + host = "0.0.0.0" // Default to listening on all interfaces for external access + } + + opts := &server.Options{ + Host: host, + Port: port, + NoLog: false, + NoSigs: true, + MaxControlLine: 2048, + MaxPayload: 4 * 1024 * 1024, // 4MB + JetStream: true, // Enable JetStream + // Use memory storage 
for JetStream (no persistence) + StoreDir: "memory://", + } + + ns, err := server.NewServer(opts) + if err != nil { + return fmt.Errorf("failed to create embedded NATS server: %w", err) + } + + go ns.Start() + + // Wait for server to be ready + if !ns.ReadyForConnections(10 * time.Second) { + return fmt.Errorf("embedded NATS server failed to start within timeout") + } + + e.server = ns + logger.InfoCF("swarm", "Embedded NATS server started", map[string]interface{}{ + "host": host, + "port": port, + }) + + return nil +} + +// Stop stops the embedded NATS server +func (e *EmbeddedNATS) Stop() { + if e.server != nil { + e.server.Shutdown() + logger.InfoC("swarm", "Embedded NATS server stopped") + } +} + +// ClientURL returns the URL for local clients to connect +func (e *EmbeddedNATS) ClientURL() string { + port := e.cfg.EmbeddedPort + if port == 0 { + port = 4222 + } + return fmt.Sprintf("nats://localhost:%d", port) +} + +// ExternalURL returns the URL for external clients to connect (uses the actual hostname) +func (e *EmbeddedNATS) ExternalURL(hostname string) string { + port := e.cfg.EmbeddedPort + if port == 0 { + port = 4222 + } + if hostname == "" { + hostname = "localhost" + } + return fmt.Sprintf("nats://%s:%d", hostname, port) +} + +// IsRunning returns true if the embedded server is running +func (e *EmbeddedNATS) IsRunning() bool { + return e.server != nil && e.server.Running() +} diff --git a/pkg/swarm/embedded_test.go b/pkg/swarm/embedded_test.go new file mode 100644 index 000000000..f493d763b --- /dev/null +++ b/pkg/swarm/embedded_test.go @@ -0,0 +1,138 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "fmt" + "strings" + "testing" + + "github.com/nats-io/nats.go" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestEmbeddedNATS(t *testing.T) { + tests := []struct { + name string + fn func(t *testing.T) + }{ + { + name: "start and stop", + fn: func(t 
*testing.T) { + port := freePort(t) + cfg := &config.NATSConfig{EmbeddedPort: port} + e := NewEmbeddedNATS(cfg) + + if e.IsRunning() { + t.Error("IsRunning() = true before Start, want false") + } + + if err := e.Start(); err != nil { + t.Fatalf("Start() error: %v", err) + } + + if !e.IsRunning() { + t.Error("IsRunning() = false after Start, want true") + } + + e.Stop() + + if e.IsRunning() { + t.Error("IsRunning() = true after Stop, want false") + } + }, + }, + { + name: "client URL format", + fn: func(t *testing.T) { + port := freePort(t) + cfg := &config.NATSConfig{EmbeddedPort: port} + e := NewEmbeddedNATS(cfg) + if err := e.Start(); err != nil { + t.Fatalf("Start() error: %v", err) + } + defer e.Stop() + + url := e.ClientURL() + want := fmt.Sprintf("nats://127.0.0.1:%d", port) + if url != want { + t.Errorf("ClientURL() = %q, want %q", url, want) + } + }, + }, + { + name: "connect client", + fn: func(t *testing.T) { + port := freePort(t) + cfg := &config.NATSConfig{EmbeddedPort: port} + e := NewEmbeddedNATS(cfg) + if err := e.Start(); err != nil { + t.Fatalf("Start() error: %v", err) + } + defer e.Stop() + + nc, err := nats.Connect(e.ClientURL()) + if err != nil { + t.Fatalf("nats.Connect() error: %v", err) + } + defer nc.Close() + + if !nc.IsConnected() { + t.Error("IsConnected() = false, want true") + } + }, + }, + { + name: "multiple start stop cycles", + fn: func(t *testing.T) { + port := freePort(t) + cfg := &config.NATSConfig{EmbeddedPort: port} + + for i := 0; i < 3; i++ { + e := NewEmbeddedNATS(cfg) + if err := e.Start(); err != nil { + t.Fatalf("cycle %d: Start() error: %v", i, err) + } + if !e.IsRunning() { + t.Errorf("cycle %d: IsRunning() = false after Start", i) + } + e.Stop() + if e.IsRunning() { + t.Errorf("cycle %d: IsRunning() = true after Stop", i) + } + } + }, + }, + { + name: "custom port", + fn: func(t *testing.T) { + port := freePort(t) + cfg := &config.NATSConfig{EmbeddedPort: port} + e := NewEmbeddedNATS(cfg) + if err := e.Start(); err != 
nil { + t.Fatalf("Start() error: %v", err) + } + defer e.Stop() + + url := e.ClientURL() + if !strings.Contains(url, fmt.Sprintf(":%d", port)) { + t.Errorf("ClientURL() = %q, does not contain port %d", url, port) + } + + // Verify we can actually connect on this port + nc, err := nats.Connect(url) + if err != nil { + t.Fatalf("Connect to custom port %d error: %v", port, err) + } + nc.Close() + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, tt.fn) + } +} diff --git a/pkg/swarm/failover.go b/pkg/swarm/failover.go new file mode 100644 index 000000000..f42b3d398 --- /dev/null +++ b/pkg/swarm/failover.go @@ -0,0 +1,420 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/nats-io/nats.go" + "github.com/sipeed/picoclaw/pkg/logger" +) + +const ( + // DefaultHeartbeatTimeout is how long before a node is considered failed + DefaultHeartbeatTimeout = 60 * time.Second + // DefaultProgressStallTimeout is how long without progress before a task is failed + DefaultProgressStallTimeout = 2 * time.Minute + // FailoverCheckInterval is how often to check for failures + FailoverCheckInterval = 10 * time.Second + // ClaimLockTTL is how long a claim lock is valid + ClaimLockTTL = 30 * time.Second +) + +// FailoverManager manages task failure detection and recovery +type FailoverManager struct { + discovery *Discovery + lifecycle *TaskLifecycleStore + checkpointStore *CheckpointStore + bridge *NATSBridge + nodeInfo *NodeInfo + js nats.JetStreamContext + + // Configuration + heartbeatTimeout time.Duration + progressStallTimeout time.Duration + + // State + running bool + mu sync.RWMutex + claimedTasks map[string]*ClaimInfo +} + +// ClaimInfo tracks claimed task information +type ClaimInfo struct { + TaskID string + ClaimedBy string + ClaimedAt time.Time + ExpiresAt time.Time + Checkpoint *TaskCheckpoint +} + +// 
NewFailoverManager creates a new failover manager +func NewFailoverManager( + discovery *Discovery, + lifecycle *TaskLifecycleStore, + checkpointStore *CheckpointStore, + bridge *NATSBridge, + nodeInfo *NodeInfo, + js nats.JetStreamContext, +) *FailoverManager { + return &FailoverManager{ + discovery: discovery, + lifecycle: lifecycle, + checkpointStore: checkpointStore, + bridge: bridge, + nodeInfo: nodeInfo, + js: js, + heartbeatTimeout: DefaultHeartbeatTimeout, + progressStallTimeout: DefaultProgressStallTimeout, + claimedTasks: make(map[string]*ClaimInfo), + } +} + +// Start begins failure detection and recovery +func (fm *FailoverManager) Start(ctx context.Context) error { + fm.mu.Lock() + fm.running = true + fm.mu.Unlock() + + logger.InfoC("swarm", "Failover manager starting") + + // Start failure detection loop + go fm.detectFailuresLoop(ctx) + + // Start claim expiration cleanup + go fm.cleanupExpiredClaims(ctx) + + return nil +} + +// Stop stops the failover manager +func (fm *FailoverManager) Stop() { + fm.mu.Lock() + defer fm.mu.Unlock() + fm.running = false +} + +// detectFailuresLoop continuously checks for node and task failures +func (fm *FailoverManager) detectFailuresLoop(ctx context.Context) { + ticker := time.NewTicker(FailoverCheckInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + fm.DetectFailures(ctx) + } + } +} + +// DetectFailures checks for node failures and stalled tasks +func (fm *FailoverManager) DetectFailures(ctx context.Context) { + // Get all discovered nodes + nodes := fm.discovery.GetAllNodes() + + now := time.Now() + + for _, node := range nodes { + // Skip self and offline nodes + if node.ID == fm.nodeInfo.ID || node.Status == StatusOffline { + continue + } + + // Check heartbeat timeout + lastSeen := time.UnixMilli(node.LastSeen) + if now.Sub(lastSeen) > fm.heartbeatTimeout { + logger.WarnCF("swarm", "Node heartbeat timeout detected", map[string]interface{}{ + "node_id": 
node.ID, + "last_seen": lastSeen.Format(time.RFC3339), + "timeout": fm.heartbeatTimeout, + }) + fm.handleNodeFailure(ctx, node) + } + } + + // Check for stalled tasks + fm.detectStalledTasks(ctx) +} + +// detectStalledTasks looks for tasks that haven't made progress +func (fm *FailoverManager) detectStalledTasks(ctx context.Context) { + // Get active tasks + activeTasks, err := fm.lifecycle.GetActiveTasks(ctx) + if err != nil { + logger.WarnCF("swarm", "Failed to get active tasks for stall detection", map[string]interface{}{ + "error": err.Error(), + }) + return + } + + now := time.Now() + + for _, task := range activeTasks { + // Get task history to check last progress update + history, err := fm.lifecycle.GetTaskHistory(ctx, task.ID) + if err != nil { + continue + } + + // Find the most recent event + if len(history) == 0 { + continue + } + + latestEvent := history[len(history)-1] + eventTime := time.UnixMilli(latestEvent.Timestamp) + + // Check if task is running but stalled + if task.Status == TaskRunning && now.Sub(eventTime) > fm.progressStallTimeout { + logger.WarnCF("swarm", "Stalled task detected", map[string]interface{}{ + "task_id": task.ID, + "assigned_to": task.AssignedTo, + "last_update": eventTime.Format(time.RFC3339), + }) + + // Attempt to claim and recover the task + go fm.attemptTaskRecovery(ctx, task) + } + } +} + +// handleNodeFailure processes a detected node failure +func (fm *FailoverManager) handleNodeFailure(ctx context.Context, failedNode *NodeInfo) { + // Get tasks assigned to this node + tasks, err := fm.lifecycle.GetTasksByNode(ctx, failedNode.ID) + if err != nil { + logger.WarnCF("swarm", "Failed to get tasks for failed node", map[string]interface{}{ + "node_id": failedNode.ID, + "error": err.Error(), + }) + return + } + + for _, task := range tasks { + if task.Status == TaskRunning || task.Status == TaskAssigned { + logger.WarnCF("swarm", "Attempting recovery of task from failed node", map[string]interface{}{ + "task_id": task.ID, + 
"failed_node": failedNode.ID,
			})
			go fm.attemptTaskRecovery(ctx, task)
		}
	}
}

// attemptTaskRecovery tries to win the distributed claim for a failed task
// and, on success, records a retry transition so the task can be re-executed.
// Safe to call concurrently from multiple detection paths: the atomic KV
// claim lock guarantees that only one node proceeds past ClaimTask.
func (fm *FailoverManager) attemptTaskRecovery(ctx context.Context, task *SwarmTask) {
	claimed, checkpoint, err := fm.ClaimTask(ctx, task.ID)
	if err != nil {
		logger.WarnCF("swarm", "Failed to claim task for recovery", map[string]interface{}{
			"task_id": task.ID,
			"error":   err.Error(),
		})
		return
	}

	if !claimed {
		// Another node won the claim; nothing to do here.
		return
	}

	logger.InfoCF("swarm", "Successfully claimed task for recovery", map[string]interface{}{
		"task_id": task.ID,
	})

	// Persist the retry transition. A persistence failure is logged but does
	// not abort recovery: the claim is already held by this node.
	if err := fm.lifecycle.SaveTaskStatus(task, TaskEventRetry, "Task claimed for failover recovery"); err != nil {
		logger.WarnCF("swarm", "Failed to save task status during recovery", map[string]interface{}{
			"task_id": task.ID,
			"error":   err.Error(),
		})
	}

	// Dispatch to worker for recovery.
	// In a full implementation, this would dispatch to the worker with checkpoint data.
	logger.InfoCF("swarm", "Task ready for recovery execution", map[string]interface{}{
		"task_id":        task.ID,
		"has_checkpoint": checkpoint != nil,
	})
}

// claimsBucket returns the shared claims KV bucket, creating it on first use.
//
// The bucket is created with a per-entry TTL of ClaimLockTTL so that a claim
// held by a node that crashes without releasing it expires server-side
// instead of locking the task forever (the local cleanupExpiredClaims loop
// only covers claims tracked by a *surviving* node). RenewClaim's Put resets
// the entry age, keeping live claims valid. NOTE(review): a pre-existing
// bucket created without a TTL is used as-is; NATS does not reconfigure
// buckets on lookup.
func (fm *FailoverManager) claimsBucket() (nats.KeyValue, error) {
	bucket, err := fm.js.KeyValue("PICOCLAW_CLAIMS")
	if err == nil {
		return bucket, nil
	}

	bucket, err = fm.js.CreateKeyValue(&nats.KeyValueConfig{
		Bucket: "PICOCLAW_CLAIMS",
		TTL:    ClaimLockTTL,
	})
	if err == nil {
		return bucket, nil
	}

	// Two nodes can race to create the bucket; if our create lost the race,
	// the lookup will now succeed, so retry it once before reporting failure.
	if existing, lookupErr := fm.js.KeyValue("PICOCLAW_CLAIMS"); lookupErr == nil {
		return existing, nil
	}
	return nil, fmt.Errorf("failed to create claims bucket: %w", err)
}

// ClaimTask attempts to claim a failed task using distributed locking via
// NATS JetStream KV.
//
// Returns (true, checkpoint, nil) when this node won the claim (checkpoint
// may be nil if none was saved), (false, nil, nil) when another node already
// holds the claim, and a non-nil error only on infrastructure failure.
func (fm *FailoverManager) ClaimTask(ctx context.Context, taskID string) (bool, *TaskCheckpoint, error) {
	lockKey := fmt.Sprintf("claim_%s", taskID)

	bucket, err := fm.claimsBucket()
	if err != nil {
		return false, nil, err
	}

	claimData := []byte(fmt.Sprintf(`{"claimed_by":"%s","claimed_at":%d,"expires_at":%d}`,
		fm.nodeInfo.ID,
		time.Now().UnixMilli(),
		time.Now().Add(ClaimLockTTL).UnixMilli(),
	))

	// Create is atomic: it only succeeds if the key does not already exist,
	// so exactly one node can win the claim.
	_, err = bucket.Create(lockKey, claimData)
	if err != nil {
		if err == nats.ErrKeyExists {
			// Already claimed by someone else.
			return false, nil, nil
		}
		return false, nil, fmt.Errorf("failed to create claim lock: %w", err)
	}

	// Claim won. Best-effort load of the latest checkpoint so the recovered
	// task can resume instead of restarting from scratch.
	checkpoint, err := fm.checkpointStore.LoadCheckpoint(ctx, taskID)
	if err != nil {
		logger.WarnCF("swarm", "No checkpoint found for task", map[string]interface{}{
			"task_id": taskID,
			"error":   err.Error(),
		})
		// Continue without checkpoint - will restart from beginning.
	}

	// Track the claim locally so RenewClaim / cleanup can manage it.
	fm.mu.Lock()
	fm.claimedTasks[taskID] = &ClaimInfo{
		TaskID:     taskID,
		ClaimedBy:  fm.nodeInfo.ID,
		ClaimedAt:  time.Now(),
		ExpiresAt:  time.Now().Add(ClaimLockTTL),
		Checkpoint: checkpoint,
	}
	fm.mu.Unlock()

	return true, checkpoint, nil
}

// ReleaseClaim releases this node's claim on a task. Release is idempotent:
// deleting an already-missing key is not an error.
func (fm *FailoverManager) ReleaseClaim(ctx context.Context, taskID string) error {
	lockKey := fmt.Sprintf("claim_%s", taskID)

	bucket, err := fm.claimsBucket()
	if err != nil {
		return err
	}

	err = bucket.Delete(lockKey)
	if err != nil && err != nats.ErrKeyNotFound {
		return fmt.Errorf("failed to release claim: %w", err)
	}

	fm.mu.Lock()
	delete(fm.claimedTasks, taskID)
	fm.mu.Unlock()

	logger.DebugCF("swarm", "Released task claim", map[string]interface{}{
		"task_id": taskID,
	})

	return nil
}

// cleanupExpiredClaims periodically drops locally-tracked claims whose TTL
// has lapsed, until the context is canceled.
func (fm *FailoverManager) cleanupExpiredClaims(ctx context.Context) {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			fm.cleanupExpiredClaimsOnce(ctx)
		}
	}
}

func (fm *FailoverManager) cleanupExpiredClaimsOnce(ctx 
context.Context) { + fm.mu.Lock() + defer fm.mu.Unlock() + + now := time.Now() + for taskID, claim := range fm.claimedTasks { + if now.After(claim.ExpiresAt) { + logger.WarnCF("swarm", "Claim expired, releasing", map[string]interface{}{ + "task_id": taskID, + }) + delete(fm.claimedTasks, taskID) + + // Also remove from KV + lockKey := fmt.Sprintf("claim_%s", taskID) + if bucket, err := fm.js.KeyValue("PICOCLAW_CLAIMS"); err == nil { + _ = bucket.Delete(lockKey) + } + } + } +} + +// RenewClaim renews a claim lock before it expires +func (fm *FailoverManager) RenewClaim(ctx context.Context, taskID string) error { + lockKey := fmt.Sprintf("claim_%s", taskID) + + bucket, err := fm.js.KeyValue("PICOCLAW_CLAIMS") + if err != nil { + return fmt.Errorf("failed to get claims bucket: %w", err) + } + + // Update the claim with new expiration + claimData := []byte(fmt.Sprintf(`{"claimed_by":"%s","claimed_at":%d,"expires_at":%d}`, + fm.nodeInfo.ID, + time.Now().UnixMilli(), + time.Now().Add(ClaimLockTTL).UnixMilli(), + )) + + _, err = bucket.Put(lockKey, claimData) + if err != nil { + return fmt.Errorf("failed to renew claim: %w", err) + } + + // Update local tracking + fm.mu.Lock() + if claim, exists := fm.claimedTasks[taskID]; exists { + claim.ExpiresAt = time.Now().Add(ClaimLockTTL) + } + fm.mu.Unlock() + + return nil +} + +// GetClaimedTasks returns the list of tasks claimed by this node +func (fm *FailoverManager) GetClaimedTasks() []string { + fm.mu.RLock() + defer fm.mu.RUnlock() + + tasks := make([]string, 0, len(fm.claimedTasks)) + for taskID := range fm.claimedTasks { + tasks = append(tasks, taskID) + } + return tasks +} + +// IsClaimedByThisNode checks if a task is claimed by this node +func (fm *FailoverManager) IsClaimedByThisNode(taskID string) bool { + fm.mu.RLock() + defer fm.mu.RUnlock() + + claim, exists := fm.claimedTasks[taskID] + if !exists { + return false + } + return claim.ClaimedBy == fm.nodeInfo.ID +} diff --git a/pkg/swarm/failover_test.go 
b/pkg/swarm/failover_test.go new file mode 100644 index 000000000..577dea61c --- /dev/null +++ b/pkg/swarm/failover_test.go @@ -0,0 +1,252 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testCfg provides a default test configuration +var testCfg = &config.Config{ + Swarm: config.SwarmConfig{ + NATS: config.NATSConfig{ + HeartbeatInterval: "10s", + NodeTimeout: "60s", + }, + Temporal: config.TemporalConfig{ + Host: "localhost:7233", + Namespace: "default", + TaskQueue: "picoclaw-test", + }, + MaxConcurrent: 5, + }, +} + +func TestNewFailoverManager(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + + // Create test config + swarmCfg := &testCfg.Swarm + swarmCfg.NATS.URLs = []string{tn.URL()} + + // Create discovery + bridge := NewNATSBridge(swarmCfg, nil, CreateTestNodeInfo("test-node", "coordinator", []string{})) + err := bridge.Connect(ctx) + require.NoError(t, err) + + discovery := NewDiscovery(bridge, CreateTestNodeInfo("test-node", "coordinator", []string{}), swarmCfg) + + // Create lifecycle store + lifecycle := NewTaskLifecycleStore(tn.JS()) + err = lifecycle.Initialize(ctx) + require.NoError(t, err) + + // Create checkpoint store + checkpointStore, err := NewCheckpointStore(tn.JS()) + require.NoError(t, err) + + // Create failover manager + fm := NewFailoverManager(discovery, lifecycle, checkpointStore, bridge, + CreateTestNodeInfo("test-node", "coordinator", []string{}), tn.JS()) + + assert.NotNil(t, fm) + }) +} + +func TestFailoverManager_ClaimTask(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + + // Setup + swarmCfg := &testCfg.Swarm + swarmCfg.NATS.URLs = []string{tn.URL()} + + bridge := NewNATSBridge(swarmCfg, 
nil, CreateTestNodeInfo("node-1", "coordinator", []string{})) + err := bridge.Connect(ctx) + require.NoError(t, err) + + discovery := NewDiscovery(bridge, CreateTestNodeInfo("node-1", "coordinator", []string{}), swarmCfg) + + lifecycle := NewTaskLifecycleStore(tn.JS()) + err = lifecycle.Initialize(ctx) + require.NoError(t, err) + + checkpointStore, err := NewCheckpointStore(tn.JS()) + require.NoError(t, err) + + // Clean up KV buckets from previous tests + _ = tn.JS().DeleteKeyValue("PICOCLAW_CLAIMS") + _ = tn.JS().DeleteStream("PICOCLAW_TASKS") + + fm := NewFailoverManager(discovery, lifecycle, checkpointStore, bridge, + CreateTestNodeInfo("node-1", "coordinator", []string{}), tn.JS()) + + err = fm.Start(ctx) + require.NoError(t, err) + defer fm.Stop() + + // Create a task + task := CreateTestTask("task-claim", "direct", "Test claim", "test") + + // Claim the task + claimed, checkpoint, err := fm.ClaimTask(ctx, task.ID) + require.NoError(t, err) + assert.True(t, claimed) + // checkpoint may be nil if task hasn't saved one yet + _ = checkpoint // We got a claim, that's what matters + + // Try to claim again - should fail (key already exists) + // Note: This will return an error since Create fails on existing key + _, _, err = fm.ClaimTask(ctx, task.ID) + assert.Error(t, err, "Should fail when trying to claim an already claimed task") + }) +} + +func TestFailoverManager_ReleaseClaim(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + + // Setup + swarmCfg := &testCfg.Swarm + swarmCfg.NATS.URLs = []string{tn.URL()} + + bridge := NewNATSBridge(swarmCfg, nil, CreateTestNodeInfo("node-1", "coordinator", []string{})) + err := bridge.Connect(ctx) + require.NoError(t, err) + + discovery := NewDiscovery(bridge, CreateTestNodeInfo("node-1", "coordinator", []string{}), swarmCfg) + + lifecycle := NewTaskLifecycleStore(tn.JS()) + err = lifecycle.Initialize(ctx) + require.NoError(t, err) + + checkpointStore, err := 
NewCheckpointStore(tn.JS()) + require.NoError(t, err) + + // Clean up KV buckets from previous tests + _ = tn.JS().DeleteKeyValue("PICOCLAW_CLAIMS") + _ = tn.JS().DeleteStream("PICOCLAW_TASKS") + + fm := NewFailoverManager(discovery, lifecycle, checkpointStore, bridge, + CreateTestNodeInfo("node-1", "coordinator", []string{}), tn.JS()) + + err = fm.Start(ctx) + require.NoError(t, err) + defer fm.Stop() + + task := CreateTestTask("task-release", "direct", "Test release", "test") + + // Claim the task + claimed, _, err := fm.ClaimTask(ctx, task.ID) + require.NoError(t, err) + assert.True(t, claimed) + + // Release the claim + err = fm.ReleaseClaim(ctx, task.ID) + require.NoError(t, err) + + // Now should be claimable again + claimed2, _, err := fm.ClaimTask(ctx, task.ID) + require.NoError(t, err) + assert.True(t, claimed2, "Task should be claimable after release") + }) +} + +func TestFailoverManager_RenewClaim(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + + // Setup + swarmCfg := &testCfg.Swarm + swarmCfg.NATS.URLs = []string{tn.URL()} + + bridge := NewNATSBridge(swarmCfg, nil, CreateTestNodeInfo("node-1", "coordinator", []string{})) + err := bridge.Connect(ctx) + require.NoError(t, err) + + discovery := NewDiscovery(bridge, CreateTestNodeInfo("node-1", "coordinator", []string{}), swarmCfg) + + lifecycle := NewTaskLifecycleStore(tn.JS()) + err = lifecycle.Initialize(ctx) + require.NoError(t, err) + + checkpointStore, err := NewCheckpointStore(tn.JS()) + require.NoError(t, err) + + // Clean up KV buckets from previous tests + _ = tn.JS().DeleteKeyValue("PICOCLAW_CLAIMS") + _ = tn.JS().DeleteStream("PICOCLAW_TASKS") + + fm := NewFailoverManager(discovery, lifecycle, checkpointStore, bridge, + CreateTestNodeInfo("node-1", "coordinator", []string{}), tn.JS()) + + err = fm.Start(ctx) + require.NoError(t, err) + defer fm.Stop() + + task := CreateTestTask("task-renew", "direct", "Test renew", "test") + + // Initial claim + 
claimed, _, err := fm.ClaimTask(ctx, task.ID) + require.NoError(t, err) + assert.True(t, claimed) + + // Renew claim + err = fm.RenewClaim(ctx, task.ID) + require.NoError(t, err) + + // Verify claim is still held by trying to claim again (should fail) + _, _, err = fm.ClaimTask(ctx, task.ID) + assert.Error(t, err, "Task should still be claimed") + }) +} + +func TestClaimInfo(t *testing.T) { + info := &ClaimInfo{ + TaskID: "task-1", + ClaimedBy: "node-1", + ClaimedAt: time.Now(), + ExpiresAt: time.Now().Add(30 * time.Second), + } + + assert.Equal(t, "task-1", info.TaskID) + assert.Equal(t, "node-1", info.ClaimedBy) + assert.True(t, time.Until(info.ExpiresAt) < 30*time.Second, "Claim should expire within 30 seconds") +} + +func TestDefaultTimeouts(t *testing.T) { + assert.Equal(t, 60*time.Second, DefaultHeartbeatTimeout) + assert.Equal(t, 2*time.Minute, DefaultProgressStallTimeout) + assert.Equal(t, 10*time.Second, FailoverCheckInterval) + assert.Equal(t, 30*time.Second, ClaimLockTTL) +} + +func TestClaimInfo_IsExpired(t *testing.T) { + expiredInfo := &ClaimInfo{ + TaskID: "task-1", + ClaimedBy: "node-1", + ClaimedAt: time.Now().Add(-1 * time.Hour), + ExpiresAt: time.Now().Add(-1 * time.Hour), + } + + assert.True(t, expiredInfo.ExpiresAt.Before(time.Now()), "Old claim should be expired") + + validInfo := &ClaimInfo{ + TaskID: "task-2", + ClaimedBy: "node-1", + ClaimedAt: time.Now(), + ExpiresAt: time.Now().Add(30 * time.Second), + } + + assert.True(t, validInfo.ExpiresAt.After(time.Now()), "Recent claim should not be expired") +} diff --git a/pkg/swarm/heartbeat.go b/pkg/swarm/heartbeat.go new file mode 100644 index 000000000..0a9c74c4b --- /dev/null +++ b/pkg/swarm/heartbeat.go @@ -0,0 +1,293 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +const ( + // HeartbeatInterval is how often nodes 
send heartbeat messages + HeartbeatInterval = 10 * time.Second + // HeartbeatSuspiciousThreshold is how long before a node is marked suspicious + HeartbeatSuspiciousThreshold = 30 * time.Second + // HeartbeatOfflineThreshold is how long before a node is marked offline + HeartbeatOfflineThreshold = 60 * time.Second +) + +// HeartbeatConfig configures heartbeat behavior +type HeartbeatConfig struct { + Interval time.Duration // Heartbeat send interval + SuspiciousTimeout time.Duration // Time before marking suspicious + OfflineTimeout time.Duration // Time before marking offline +} + +// DefaultHeartbeatConfig returns the default heartbeat configuration +func DefaultHeartbeatConfig() *HeartbeatConfig { + return &HeartbeatConfig{ + Interval: HeartbeatInterval, + SuspiciousTimeout: HeartbeatSuspiciousThreshold, + OfflineTimeout: HeartbeatOfflineThreshold, + } +} + +// HeartbeatPublisher sends periodic heartbeat messages for a node +type HeartbeatPublisher struct { + bridge *NATSBridge + nodeInfo *NodeInfo + cfg *HeartbeatConfig + ticker *time.Ticker + stopChan chan struct{} + running bool + mu sync.RWMutex +} + +// NewHeartbeatPublisher creates a new heartbeat publisher +func NewHeartbeatPublisher(bridge *NATSBridge, nodeInfo *NodeInfo, cfg *HeartbeatConfig) *HeartbeatPublisher { + if cfg == nil { + cfg = DefaultHeartbeatConfig() + } + return &HeartbeatPublisher{ + bridge: bridge, + nodeInfo: nodeInfo, + cfg: cfg, + stopChan: make(chan struct{}), + } +} + +// Start begins sending heartbeat messages +func (hp *HeartbeatPublisher) Start(ctx context.Context) error { + hp.mu.Lock() + defer hp.mu.Unlock() + + if hp.running { + return nil + } + + hp.ticker = time.NewTicker(hp.cfg.Interval) + hp.running = true + + go hp.run(ctx) + + logger.InfoCF("swarm", "Heartbeat publisher started", map[string]interface{}{ + "node_id": hp.nodeInfo.ID, + "interval": hp.cfg.Interval.String(), + }) + + return nil +} + +// Stop stops sending heartbeat messages +func (hp *HeartbeatPublisher) 
Stop() { + hp.mu.Lock() + defer hp.mu.Unlock() + + if !hp.running { + return + } + + close(hp.stopChan) + if hp.ticker != nil { + hp.ticker.Stop() + } + hp.running = false + + logger.InfoC("swarm", "Heartbeat publisher stopped") +} + +func (hp *HeartbeatPublisher) run(ctx context.Context) { + // Send first heartbeat immediately + hp.sendHeartbeat() + + for { + select { + case <-hp.ticker.C: + hp.sendHeartbeat() + case <-hp.stopChan: + return + case <-ctx.Done(): + return + } + } +} + +func (hp *HeartbeatPublisher) sendHeartbeat() { + hb := &Heartbeat{ + NodeID: hp.nodeInfo.ID, + Timestamp: time.Now().UnixMilli(), + Load: hp.nodeInfo.Load, + TasksRunning: hp.nodeInfo.TasksRunning, + Status: hp.nodeInfo.Status, + Capabilities: hp.nodeInfo.Capabilities, + } + + if err := hp.bridge.PublishHeartbeat(hb); err != nil { + logger.DebugCF("swarm", "Failed to publish heartbeat", map[string]interface{}{ + "error": err.Error(), + }) + } +} + +// HeartbeatMonitor tracks heartbeats from other nodes +type HeartbeatMonitor struct { + cfg *HeartbeatConfig + discovery *Discovery + heartbeats map[string]int64 // node_id -> last heartbeat timestamp + mu sync.RWMutex + stopChan chan struct{} + running bool +} + +// NewHeartbeatMonitor creates a new heartbeat monitor +func NewHeartbeatMonitor(discovery *Discovery, cfg *HeartbeatConfig) *HeartbeatMonitor { + if cfg == nil { + cfg = DefaultHeartbeatConfig() + } + return &HeartbeatMonitor{ + cfg: cfg, + discovery: discovery, + heartbeats: make(map[string]int64), + stopChan: make(chan struct{}), + } +} + +// Start begins monitoring heartbeats +func (hm *HeartbeatMonitor) Start(ctx context.Context) error { + hm.mu.Lock() + defer hm.mu.Unlock() + + if hm.running { + return nil + } + + hm.running = true + + // Start checker goroutine + go hm.runChecker(ctx) + + logger.InfoC("swarm", "Heartbeat monitor started") + return nil +} + +// Stop stops monitoring +func (hm *HeartbeatMonitor) Stop() { + hm.mu.Lock() + defer hm.mu.Unlock() + + if 
!hm.running { + return + } + + close(hm.stopChan) + hm.running = false + + logger.InfoC("swarm", "Heartbeat monitor stopped") +} + +// UpdateHeartbeat records a heartbeat from a node +func (hm *HeartbeatMonitor) UpdateHeartbeat(hb *Heartbeat) { + hm.mu.Lock() + defer hm.mu.Unlock() + + now := time.Now().UnixMilli() + hm.heartbeats[hb.NodeID] = now + + // Update node status in discovery based on heartbeat + node, ok := hm.discovery.GetNode(hb.NodeID) + if ok { + node.Load = hb.Load + node.TasksRunning = hb.TasksRunning + node.LastSeen = now + // If node was offline/suspicious, mark it back online + if node.Status == StatusOffline || node.Status == StatusSuspicious { + node.Status = hb.Status + } + } + + logger.DebugCF("swarm", "Heartbeat received", map[string]interface{}{ + "node_id": hb.NodeID, + "status": string(hb.Status), + "load": hb.Load, + "tasks": hb.TasksRunning, + }) +} + +// runChecker periodically checks for missed heartbeats +func (hm *HeartbeatMonitor) runChecker(ctx context.Context) { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + hm.checkHeartbeats() + case <-hm.stopChan: + return + case <-ctx.Done(): + return + } + } +} + +// checkHeartbeats checks all tracked nodes for missed heartbeats +func (hm *HeartbeatMonitor) checkHeartbeats() { + hm.mu.Lock() + defer hm.mu.Unlock() + + now := time.Now().UnixMilli() + suspiciousThreshold := now - hm.cfg.SuspiciousTimeout.Milliseconds() + offlineThreshold := now - hm.cfg.OfflineTimeout.Milliseconds() + + for nodeID, lastHB := range hm.heartbeats { + node, ok := hm.discovery.GetNode(nodeID) + if !ok { + continue + } + + if lastHB < offlineThreshold { + // Node is offline + if node.Status != StatusOffline { + logger.WarnCF("swarm", "Node marked offline", map[string]interface{}{ + "node_id": nodeID, + "last_heartbeat": time.UnixMilli(lastHB).Format(time.RFC3339), + }) + node.Status = StatusOffline + hm.discovery.handleNodeLeave(nodeID) + } + } else if 
lastHB < suspiciousThreshold { + // Node is suspicious + if node.Status != StatusSuspicious && node.Status != StatusOffline { + logger.WarnCF("swarm", "Node marked suspicious", map[string]interface{}{ + "node_id": nodeID, + }) + node.Status = StatusSuspicious + } + } + } +} + +// RemoveNode stops monitoring a node +func (hm *HeartbeatMonitor) RemoveNode(nodeID string) { + hm.mu.Lock() + defer hm.mu.Unlock() + delete(hm.heartbeats, nodeID) +} + +// GetLastHeartbeat returns the last heartbeat time for a node +func (hm *HeartbeatMonitor) GetLastHeartbeat(nodeID string) time.Time { + hm.mu.RLock() + defer hm.mu.RUnlock() + + ts, ok := hm.heartbeats[nodeID] + if !ok { + return time.Time{} + } + return time.UnixMilli(ts) +} diff --git a/pkg/swarm/heartbeat_test.go b/pkg/swarm/heartbeat_test.go new file mode 100644 index 000000000..92a8a1e9a --- /dev/null +++ b/pkg/swarm/heartbeat_test.go @@ -0,0 +1,245 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultHeartbeatConfig(t *testing.T) { + cfg := DefaultHeartbeatConfig() + + assert.Equal(t, HeartbeatInterval, cfg.Interval) + assert.Equal(t, HeartbeatSuspiciousThreshold, cfg.SuspiciousTimeout) + assert.Equal(t, HeartbeatOfflineThreshold, cfg.OfflineTimeout) +} + +func TestHeartbeatPublisher_StartStop(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + nodeInfo := CreateTestNodeInfo("hb-pub-test", string(RoleWorker), []string{"test"}) + + // Use connectTestBridge which properly configures the URL + bridge := connectTestBridge(t, tn.url, nodeInfo) + defer bridge.Stop() + + pub := NewHeartbeatPublisher(bridge, nodeInfo, nil) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + require.NoError(t, 
pub.Start(ctx)) + assert.True(t, pub.IsRunning()) + + pub.Stop() + // Give it a moment to stop + time.Sleep(10 * time.Millisecond) + }) +} + +func TestHeartbeatPublisher_SendHeartbeat(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + nodeInfo := CreateTestNodeInfo("hb-send-test", string(RoleWorker), []string{"test"}) + + // Use connectTestBridge which properly configures the URL + bridge := connectTestBridge(t, tn.url, nodeInfo) + defer bridge.Stop() + + pub := NewHeartbeatPublisher(bridge, nodeInfo, nil) + + // Subscribe to heartbeats + received := make(chan *Heartbeat, 1) + sub, err := bridge.SubscribeAllHeartbeats(func(hb *Heartbeat) { + select { + case received <- hb: + default: + } + }) + require.NoError(t, err) + defer sub.Unsubscribe() + + // Start publisher + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + require.NoError(t, pub.Start(ctx)) + defer pub.Stop() + + // Wait for heartbeat + select { + case hb := <-received: + assert.Equal(t, "hb-send-test", hb.NodeID) + assert.NotZero(t, hb.Timestamp) + case <-time.After(100 * time.Millisecond): + t.Fatal("Did not receive heartbeat in time") + case <-ctx.Done(): + t.Fatal("Context canceled while waiting for heartbeat") + } + }) +} + +func TestHeartbeatMonitor_TrackHeartbeats(t *testing.T) { + // Create a simple discovery for the monitor + nodeInfo := CreateTestNodeInfo("monitor-test", string(RoleCoordinator), nil) + swarmCfg := &config.SwarmConfig{ + Enabled: true, + MaxConcurrent: 2, + NATS: config.NATSConfig{HeartbeatInterval: "50ms", NodeTimeout: "200ms"}, + } + + // Create a mock discovery - we need a real one but can use nil bridge + bridge := NewNATSBridge(swarmCfg, nil, nodeInfo) + discovery := NewDiscovery(bridge, nodeInfo, swarmCfg) + + // Add the test node to discovery + testNode := CreateTestNodeInfo("test-node", string(RoleWorker), []string{"test"}) + discovery.handleNodeJoin(testNode) + + monitor := NewHeartbeatMonitor(discovery, nil) + 
+ hb := &Heartbeat{ + NodeID: "test-node", + Timestamp: time.Now().UnixMilli(), + Status: StatusOnline, + } + + // Before any heartbeat + assert.Zero(t, monitor.GetLastHeartbeat("test-node")) + + // Record heartbeat + monitor.UpdateHeartbeat(hb) + + // After heartbeat + lastHB := monitor.GetLastHeartbeat("test-node") + assert.False(t, lastHB.IsZero()) + assert.True(t, time.Since(lastHB) < time.Second) +} + +func TestHeartbeatMonitor_OfflineDetection(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + nodeInfo := CreateTestNodeInfo("coord-main", string(RoleCoordinator), nil) + + // Use connectTestBridge which properly configures the URL + bridge := connectTestBridge(t, tn.url, nodeInfo) + defer bridge.Stop() + + // Create a minimal swarm config for discovery + swarmCfg := &config.SwarmConfig{ + Enabled: true, + MaxConcurrent: 2, + NATS: config.NATSConfig{HeartbeatInterval: "50ms", NodeTimeout: "200ms"}, + } + discovery := NewDiscovery(bridge, nodeInfo, swarmCfg) + + // Short timeout for testing + cfg := &HeartbeatConfig{ + Interval: 10 * time.Millisecond, + SuspiciousTimeout: 25 * time.Millisecond, + OfflineTimeout: 50 * time.Millisecond, + } + + monitor := NewHeartbeatMonitor(discovery, cfg) + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + require.NoError(t, monitor.Start(ctx)) + defer monitor.Stop() + + // Add a node to discovery + testNode := CreateTestNodeInfo("test-offline-node", string(RoleWorker), []string{"test"}) + discovery.handleNodeJoin(testNode) + + // Send initial heartbeat + hb := &Heartbeat{ + NodeID: "test-offline-node", + Timestamp: time.Now().UnixMilli(), + Status: StatusOnline, + } + monitor.UpdateHeartbeat(hb) + + // Node should be online + node, ok := discovery.GetNode("test-offline-node") + require.True(t, ok) + assert.Equal(t, StatusOnline, node.Status) + + // Wait for suspicious threshold + time.Sleep(30 * time.Millisecond) + + // Check heartbeats manually (since checker runs 
every 5s in tests, we'll trigger it) + monitor.checkHeartbeats() + + node, ok = discovery.GetNode("test-offline-node") + require.True(t, ok) + assert.Equal(t, StatusSuspicious, node.Status, "Node should be marked suspicious") + + // Wait for offline threshold + time.Sleep(30 * time.Millisecond) + monitor.checkHeartbeats() + + node, ok = discovery.GetNode("test-offline-node") + require.True(t, ok) + assert.Equal(t, StatusOffline, node.Status, "Node should be marked offline") + }) +} + +func TestHeartbeat_MessageFields(t *testing.T) { + hb := &Heartbeat{ + NodeID: "test-node", + Timestamp: 1234567890, + Load: 0.75, + TasksRunning: 3, + Status: StatusBusy, + Capabilities: []string{"code", "write"}, + } + + assert.Equal(t, "test-node", hb.NodeID) + assert.Equal(t, int64(1234567890), hb.Timestamp) + assert.Equal(t, 0.75, hb.Load) + assert.Equal(t, 3, hb.TasksRunning) + assert.Equal(t, StatusBusy, hb.Status) + assert.Equal(t, []string{"code", "write"}, hb.Capabilities) +} + +func TestHeartbeatPublisher_IsRunning(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + nodeInfo := CreateTestNodeInfo("hb-running-test", string(RoleWorker), []string{"test"}) + + // Use connectTestBridge which properly configures the URL + bridge := connectTestBridge(t, tn.url, nodeInfo) + defer bridge.Stop() + + pub := NewHeartbeatPublisher(bridge, nodeInfo, nil) + + // Not running initially + assert.False(t, pub.isRunning()) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + pub.Start(ctx) + assert.True(t, pub.isRunning()) + + pub.Stop() + assert.False(t, pub.isRunning()) + }) +} + +// Helper method for testing +func (hp *HeartbeatPublisher) IsRunning() bool { + hp.mu.RLock() + defer hp.mu.RUnlock() + return hp.running +} + +func (hp *HeartbeatPublisher) isRunning() bool { + return hp.IsRunning() +} diff --git a/pkg/swarm/helpers_test.go b/pkg/swarm/helpers_test.go new file mode 100644 index 000000000..8f87b2f09 --- /dev/null +++ 
b/pkg/swarm/helpers_test.go @@ -0,0 +1,189 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/nats-io/nats.go" + "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// mockLLMProvider implements providers.LLMProvider for testing. +type mockLLMProvider struct { + chatFn func(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) + model string +} + +func (m *mockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + if m.chatFn != nil { + return m.chatFn(ctx, messages, tools, model, options) + } + return &providers.LLMResponse{Content: "mock response", FinishReason: "stop"}, nil +} + +func (m *mockLLMProvider) GetDefaultModel() string { + if m.model != "" { + return m.model + } + return "test-model" +} + +// mockManager is a minimal implementation of the manager interface used for testing +type mockManager struct{} + +func (m *mockManager) GetNodeInfo() *NodeInfo { + return &NodeInfo{ + ID: "mock-manager", + Role: RoleCoordinator, + Status: StatusOnline, + } +} + +func (m *mockManager) PromoteToCoordinator() error { + return nil +} + +func (m *mockManager) DemoteToWorker() error { + return nil +} + +// startTestNATS starts an embedded NATS server on a random available port. +// It returns the embedded server, the client URL, and a cleanup function. 
+func startTestNATS(t *testing.T) (*EmbeddedNATS, string, func()) { + t.Helper() + port := freePort(t) + cfg := &config.NATSConfig{EmbeddedPort: port} + e := NewEmbeddedNATS(cfg) + if err := e.Start(); err != nil { + t.Fatalf("startTestNATS: failed to start: %v", err) + } + url := e.ClientURL() + return e, url, func() { e.Stop() } +} + +// newTestSwarmConfig returns a SwarmConfig with short timeouts suitable for fast tests. +func newTestSwarmConfig(port int) *config.SwarmConfig { + return &config.SwarmConfig{ + Enabled: true, + Role: "worker", + Capabilities: []string{"general"}, + MaxConcurrent: 2, + NATS: config.NATSConfig{ + URLs: []string{fmt.Sprintf("nats://127.0.0.1:%d", port)}, + HeartbeatInterval: "50ms", + NodeTimeout: "200ms", + Embedded: false, + EmbeddedPort: port, + }, + Temporal: config.TemporalConfig{ + TaskQueue: "test-queue", + }, + } +} + +// newTestConfig returns a full config.Config wrapping a test SwarmConfig. +func newTestConfig(t *testing.T, embeddedPort int) *config.Config { + t.Helper() + workspace := t.TempDir() + return &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: workspace, + RestrictToWorkspace: true, + Model: "test-model", + MaxTokens: 1024, + Temperature: 0.0, + MaxToolIterations: 5, + }, + }, + Swarm: *newTestSwarmConfig(embeddedPort), + } +} + +// newTestNodeInfo creates a NodeInfo with configurable fields. +func newTestNodeInfo(id string, role NodeRole, capabilities []string, maxTasks int) *NodeInfo { + return &NodeInfo{ + ID: id, + Role: role, + Capabilities: capabilities, + Model: "test-model", + Status: StatusOnline, + MaxTasks: maxTasks, + StartedAt: time.Now().UnixMilli(), + Metadata: make(map[string]string), + } +} + +// connectTestBridge creates a NATSBridge, connects it to the given URL, and returns it. +// The bridge is configured with a SwarmConfig pointing at the given URL. 
+func connectTestBridge(t *testing.T, url string, nodeInfo *NodeInfo) *NATSBridge { + t.Helper() + cfg := &config.SwarmConfig{ + Enabled: true, + MaxConcurrent: 2, + NATS: config.NATSConfig{ + URLs: []string{url}, + }, + } + msgBus := bus.NewMessageBus() + bridge := NewNATSBridge(cfg, msgBus, nodeInfo) + if err := bridge.Connect(context.Background()); err != nil { + t.Fatalf("connectTestBridge: failed to connect: %v", err) + } + return bridge +} + +// connectTestNATS connects a raw nats.Conn to the given URL. Returns the connection. +func connectTestNATS(t *testing.T, url string) *nats.Conn { + t.Helper() + nc, err := nats.Connect(url) + if err != nil { + t.Fatalf("connectTestNATS: %v", err) + } + return nc +} + +// waitFor polls a condition function at 10ms intervals up to the given timeout. +// Returns true if the condition was met before timeout, false otherwise. +func waitFor(t *testing.T, timeout time.Duration, condition func() bool) bool { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if condition() { + return true + } + time.Sleep(10 * time.Millisecond) + } + return false +} + +// newTestAgentLoop creates an AgentLoop backed by a mock LLM provider. +// The provider returns chatResponse on success, or chatErr if non-nil. 
+func newTestAgentLoop(t *testing.T, chatResponse string, chatErr error) *agent.AgentLoop { + t.Helper() + cfg := newTestConfig(t, 0) + msgBus := bus.NewMessageBus() + provider := &mockLLMProvider{ + model: "test-model", + chatFn: func(ctx context.Context, msgs []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) { + if chatErr != nil { + return nil, chatErr + } + return &providers.LLMResponse{ + Content: chatResponse, + FinishReason: "stop", + }, nil + }, + } + return agent.NewAgentLoop(cfg, msgBus, provider) +} diff --git a/pkg/swarm/integration_test.go b/pkg/swarm/integration_test.go new file mode 100644 index 000000000..3e086741a --- /dev/null +++ b/pkg/swarm/integration_test.go @@ -0,0 +1,325 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestIntegration_CoordinatorWorkerRoundTrip(t *testing.T) { + _, url, cleanup := startTestNATS(t) + defer cleanup() + + tests := []struct { + name string + chatResponse string + chatErr error + wantStatus string + wantContains string + }{ + { + name: "full round-trip success", + chatResponse: "analysis complete", + chatErr: nil, + wantStatus: string(TaskDone), + wantContains: "analysis complete", + }, + { + name: "error round-trip", + chatResponse: "", + chatErr: fmt.Errorf("agent crashed"), + wantStatus: string(TaskFailed), + wantContains: "agent crashed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // --- Worker side --- + workerNode := newTestNodeInfo("integ-worker", RoleWorker, []string{"code"}, 4) + workerBridge := connectTestBridge(t, url, workerNode) + defer workerBridge.Stop() + + if err := workerBridge.Start(context.Background()); err != 
nil { + t.Fatalf("worker bridge Start() error: %v", err) + } + + workerCfg := newTestSwarmConfig(0) + workerCfg.MaxConcurrent = 2 + workerTemporal := NewTemporalClient(&config.TemporalConfig{TaskQueue: "test"}) + workerAgentLoop := newTestAgentLoop(t, tt.chatResponse, tt.chatErr) + + worker := NewWorker(workerCfg, workerBridge, workerTemporal, workerAgentLoop, &mockLLMProvider{}, workerNode) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := worker.Start(ctx); err != nil { + t.Fatalf("worker Start() error: %v", err) + } + + // --- Coordinator side --- + coordNode := newTestNodeInfo("integ-coord", RoleCoordinator, nil, 1) + coordBridge := connectTestBridge(t, url, coordNode) + defer coordBridge.Stop() + + if err := coordBridge.Start(context.Background()); err != nil { + t.Fatalf("coord bridge Start() error: %v", err) + } + + coordCfg := newTestSwarmConfig(0) + discovery := NewDiscovery(coordBridge, coordNode, coordCfg) + // Register the worker in coordinator's discovery + discovery.handleNodeJoin(workerNode) + + coordTemporal := NewTemporalClient(&config.TemporalConfig{TaskQueue: "test"}) + coordAgentLoop := newTestAgentLoop(t, "local fallback", nil) + localBus := bus.NewMessageBus() + + coordinator := NewCoordinator(coordCfg, coordBridge, coordTemporal, discovery, coordAgentLoop, &mockLLMProvider{}, localBus) + + // Give all subscriptions time to propagate + time.Sleep(100 * time.Millisecond) + + // --- Dispatch --- + task := &SwarmTask{ + ID: fmt.Sprintf("integ-task-%d", time.Now().UnixNano()%100000), + Type: TaskTypeDirect, + Capability: "code", + Prompt: "integration test prompt", + Status: TaskPending, + Timeout: 5000, + } + + result, err := coordinator.DispatchTask(ctx, task) + if err != nil { + t.Fatalf("DispatchTask() error: %v", err) + } + if result == nil { + t.Fatal("DispatchTask() returned nil result") + } + if result.Status != tt.wantStatus { + t.Errorf("Status = %q, want %q", result.Status, 
tt.wantStatus) + } + combined := result.Result + result.Error + if !strings.Contains(combined, tt.wantContains) { + t.Errorf("Result+Error = %q, want it to contain %q", combined, tt.wantContains) + } + }) + } +} + +func TestIntegration_MultiNodeDiscovery(t *testing.T) { + _, url, cleanup := startTestNATS(t) + defer cleanup() + + type nodeSetup struct { + id string + role NodeRole + caps []string + bridge *NATSBridge + discovery *Discovery + } + + nodes := []struct { + id string + role NodeRole + caps []string + }{ + {"disc-node-a", RoleWorker, []string{"code"}}, + {"disc-node-b", RoleWorker, []string{"research"}}, + {"disc-node-c", RoleSpecialist, []string{"ml"}}, + } + + setups := make([]*nodeSetup, len(nodes)) + + // Create all nodes + for i, n := range nodes { + nodeInfo := newTestNodeInfo(n.id, n.role, n.caps, 4) + bridge := connectTestBridge(t, url, nodeInfo) + defer bridge.Stop() + + cfg := newTestSwarmConfig(0) + cfg.NATS.HeartbeatInterval = "50ms" + cfg.NATS.NodeTimeout = "5s" + + disc := NewDiscovery(bridge, nodeInfo, cfg) + + setups[i] = &nodeSetup{ + id: n.id, + role: n.role, + caps: n.caps, + bridge: bridge, + discovery: disc, + } + } + + // Start all bridges and discoveries + ctx := context.Background() + for _, s := range setups { + if err := s.bridge.Start(ctx); err != nil { + t.Fatalf("bridge Start(%s) error: %v", s.id, err) + } + } + + // Stagger discovery starts slightly to let announce messages propagate + for _, s := range setups { + if err := s.discovery.Start(ctx); err != nil { + t.Fatalf("discovery Start(%s) error: %v", s.id, err) + } + time.Sleep(50 * time.Millisecond) + } + defer func() { + for _, s := range setups { + s.discovery.Stop() + } + }() + + // Wait for all nodes to discover each other + ok := waitFor(t, 5*time.Second, func() bool { + for _, s := range setups { + if s.discovery.NodeCount() < 2 { + return false + } + } + return true + }) + + if !ok { + for _, s := range setups { + t.Errorf("node %s discovered %d other nodes, want 
2", s.id, s.discovery.NodeCount()) + } + t.Fatal("timed out waiting for multi-node discovery") + } + + // Verify each node knows about the other two + for _, s := range setups { + count := s.discovery.NodeCount() + if count != 2 { + t.Errorf("node %s: NodeCount() = %d, want 2", s.id, count) + } + allNodes := s.discovery.GetNodes("", "") + for _, other := range setups { + if other.id == s.id { + continue + } + found := false + for _, n := range allNodes { + if n.ID == other.id { + found = true + break + } + } + if !found { + t.Errorf("node %s: missing discovery of node %s", s.id, other.id) + } + } + } +} + +func TestIntegration_CapabilityRouting(t *testing.T) { + _, url, cleanup := startTestNATS(t) + defer cleanup() + + tests := []struct { + name string + taskCapability string + expectWorkerA bool // Worker A has "code" + expectWorkerB bool // Worker B has "research" + }{ + { + name: "code task to code worker", + taskCapability: "code", + expectWorkerA: true, + expectWorkerB: false, + }, + { + name: "research task to research worker", + taskCapability: "research", + expectWorkerA: false, + expectWorkerB: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Worker A: capability "code" + nodeA := newTestNodeInfo("cap-worker-a", RoleWorker, []string{"code"}, 4) + bridgeA := connectTestBridge(t, url, nodeA) + defer bridgeA.Stop() + + var receivedA atomic.Int32 + var mu sync.Mutex + var resultsA []*SwarmTask + bridgeA.SetOnTaskReceived(func(task *SwarmTask) { + receivedA.Add(1) + mu.Lock() + resultsA = append(resultsA, task) + mu.Unlock() + }) + + if err := bridgeA.Start(context.Background()); err != nil { + t.Fatalf("bridgeA Start() error: %v", err) + } + + // Worker B: capability "research" + nodeB := newTestNodeInfo("cap-worker-b", RoleWorker, []string{"research"}, 4) + bridgeB := connectTestBridge(t, url, nodeB) + defer bridgeB.Stop() + + var receivedB atomic.Int32 + bridgeB.SetOnTaskReceived(func(task *SwarmTask) { + 
receivedB.Add(1) + }) + + if err := bridgeB.Start(context.Background()); err != nil { + t.Fatalf("bridgeB Start() error: %v", err) + } + + // Give subscriptions time to propagate + time.Sleep(100 * time.Millisecond) + + // Coordinator publishes broadcast task + coordNode := newTestNodeInfo("cap-coord", RoleCoordinator, nil, 1) + coordBridge := connectTestBridge(t, url, coordNode) + defer coordBridge.Stop() + + task := &SwarmTask{ + ID: fmt.Sprintf("cap-task-%s", tt.taskCapability), + Type: TaskTypeBroadcast, + Capability: tt.taskCapability, + Prompt: "capability routing test", + Timeout: 5000, + } + // Broadcast: no AssignedTo, publish to capability subject + if err := coordBridge.PublishTask(task); err != nil { + t.Fatalf("PublishTask() error: %v", err) + } + + // Wait for delivery + time.Sleep(500 * time.Millisecond) + + gotA := receivedA.Load() > 0 + gotB := receivedB.Load() > 0 + + if gotA != tt.expectWorkerA { + t.Errorf("Worker A (code) received = %v, want %v", gotA, tt.expectWorkerA) + } + if gotB != tt.expectWorkerB { + t.Errorf("Worker B (research) received = %v, want %v", gotB, tt.expectWorkerB) + } + }) + } +} diff --git a/pkg/swarm/lifecycle.go b/pkg/swarm/lifecycle.go new file mode 100644 index 000000000..ef931ec0c --- /dev/null +++ b/pkg/swarm/lifecycle.go @@ -0,0 +1,383 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/nats-io/nats.go" + "github.com/sipeed/picoclaw/pkg/logger" +) + +const ( + // TaskStreamName is the JetStream stream for task events + TaskStreamName = "PICOCLAW_TASKS" + // TaskStatusSubject is the subject pattern for task status updates + // Note: NATS doesn't allow '.' 
in durable stream names, so we use underscore + TaskStatusSubject = "picoclaw_tasks_status" +) + +// TaskLifecycleStore manages task state persistence using JetStream +type TaskLifecycleStore struct { + js nats.JetStreamContext + mu sync.RWMutex + cfg *lifecycleConfig +} + +// lifecycleConfig holds configuration for the lifecycle store +type lifecycleConfig struct { + streamMaxAge time.Duration + streamMaxBytes int64 +} + +// NewTaskLifecycleStore creates a new task lifecycle store +func NewTaskLifecycleStore(js nats.JetStreamContext) *TaskLifecycleStore { + return &TaskLifecycleStore{ + js: js, + cfg: &lifecycleConfig{ + streamMaxAge: 24 * time.Hour * 7, // Keep task history for 7 days + streamMaxBytes: 1024 * 1024 * 100, // 100MB per stream + }, + } +} + +// Initialize creates the JetStream stream if it doesn't exist +func (s *TaskLifecycleStore) Initialize(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Create stream for task events + stream, err := s.js.StreamInfo(TaskStreamName) + if err != nil { + // Stream doesn't exist, create it + _, err = s.js.AddStream(&nats.StreamConfig{ + Name: TaskStreamName, + Subjects: []string{TaskStatusSubject + ".>"}, + MaxAge: s.cfg.streamMaxAge, + MaxBytes: s.cfg.streamMaxBytes, + Storage: nats.FileStorage, + Discard: nats.DiscardOld, + Replicas: 1, + }) + if err != nil { + return fmt.Errorf("failed to create task stream: %w", err) + } + logger.InfoC("swarm", fmt.Sprintf("Created task stream: %s", TaskStreamName)) + } else { + logger.DebugC("swarm", fmt.Sprintf("Task stream exists: %s", stream.Config.Name)) + } + + return nil +} + +// SaveTaskStatus persists a task status event to JetStream +func (s *TaskLifecycleStore) SaveTaskStatus(task *SwarmTask, eventType TaskEventType, message string) error { + return s.SaveTaskStatusWithMetadata(task, eventType, message, nil) +} + +// SaveTaskStatusWithMetadata persists a task status event with additional metadata +func (s *TaskLifecycleStore) 
SaveTaskStatusWithMetadata(task *SwarmTask, eventType TaskEventType, message string, metadata map[string]interface{}) error { + event := &TaskEvent{ + EventID: fmt.Sprintf("evt-%d-%s", time.Now().UnixNano(), task.ID), + TaskID: task.ID, + EventType: eventType, + Timestamp: time.Now().UnixMilli(), + Status: task.Status, + Message: message, + Metadata: metadata, + } + + if task.AssignedTo != "" { + event.NodeID = task.AssignedTo + } + + // Marshal to JSON + data, err := json.Marshal(event) + if err != nil { + return fmt.Errorf("failed to marshal task event: %w", err) + } + + // Publish to JetStream + subject := fmt.Sprintf("%s.%s", TaskStatusSubject, task.ID) + _, err = s.js.Publish(subject, data) + if err != nil { + return fmt.Errorf("failed to publish task event: %w", err) + } + + logger.DebugCF("swarm", "Saved task status event", map[string]interface{}{ + "task_id": task.ID, + "event_type": string(eventType), + "status": string(task.Status), + }) + + return nil +} + +// GetTaskHistory retrieves the complete event history for a task +func (s *TaskLifecycleStore) GetTaskHistory(ctx context.Context, taskID string) ([]TaskEvent, error) { + subject := fmt.Sprintf("%s.%s", TaskStatusSubject, taskID) + + // Create ephemeral subscription without durable consumer + queueName := fmt.Sprintf("task-history-%s", taskID) + sub, err := s.js.PullSubscribe(subject, queueName, nats.AckExplicit()) + if err != nil { + return nil, fmt.Errorf("failed to subscribe: %w", err) + } + defer sub.Unsubscribe() + + // Fetch messages with timeout + events := []TaskEvent{} + fetchDeadline := time.Now().Add(5 * time.Second) + + for time.Now().Before(fetchDeadline) { + msgs, err := sub.Fetch(100) + if err != nil { + if err == nats.ErrTimeout { + break + } + continue + } + + for _, msg := range msgs { + var event TaskEvent + if err := json.Unmarshal(msg.Data, &event); err != nil { + msg.Ack() + continue + } + events = append(events, event) + msg.Ack() + } + + if len(msgs) < 100 { + // No more 
messages + break + } + } + + return events, nil +} + +// GetActiveTasks retrieves all currently active (running/pending/assigned) tasks +// from the recent event history +func (s *TaskLifecycleStore) GetActiveTasks(ctx context.Context) ([]*SwarmTask, error) { + // Get stream info to check if stream exists + _, err := s.js.StreamInfo(TaskStreamName) + if err != nil { + return nil, fmt.Errorf("failed to get stream info: %w", err) + } + + // Use stream's state to understand active tasks + // Since we're using subject-based filtering, we need to scan recent messages + activeTasks := make(map[string]*SwarmTask) + + // Create a durable consumer for scanning all task events + consumerName := "active-tasks-scan" + _, err = s.js.ConsumerInfo(TaskStreamName, consumerName) + if err != nil { + // Create consumer for scanning all task status events + // FilterSubject must match the stream's subject pattern + _, err = s.js.AddConsumer(TaskStreamName, &nats.ConsumerConfig{ + Durable: consumerName, + DeliverPolicy: nats.DeliverAllPolicy, + AckPolicy: nats.AckExplicitPolicy, + FilterSubject: TaskStatusSubject + ".>", + }) + if err != nil { + return nil, fmt.Errorf("failed to create scan consumer: %w", err) + } + } + + // When binding to an existing consumer, subject must match the filter subject + sub, err := s.js.PullSubscribe(TaskStatusSubject+".>", "", nats.AckExplicit(), nats.Bind(TaskStreamName, consumerName)) + if err != nil { + return nil, fmt.Errorf("failed to subscribe for active tasks: %w", err) + } + defer sub.Unsubscribe() + + // Fetch recent messages + msgs, err := sub.Fetch(1000, nats.MaxWait(2*time.Second)) + if err != nil && err != nats.ErrTimeout { + return nil, fmt.Errorf("failed to fetch messages: %w", err) + } + + for _, msg := range msgs { + var event TaskEvent + if err := json.Unmarshal(msg.Data, &event); err != nil { + msg.Ack() + continue + } + + // Check if task is still active + if event.Status == TaskPending || event.Status == TaskAssigned || event.Status 
== TaskRunning { + // Create a minimal SwarmTask from the event + task := &SwarmTask{ + ID: event.TaskID, + Status: event.Status, + AssignedTo: event.NodeID, + } + activeTasks[event.TaskID] = task + } + msg.Ack() + } + + // Convert map to slice + result := make([]*SwarmTask, 0, len(activeTasks)) + for _, task := range activeTasks { + result = append(result, task) + } + + return result, nil +} + +// GetTasksByNode retrieves all tasks assigned to a specific node +func (s *TaskLifecycleStore) GetTasksByNode(ctx context.Context, nodeID string) ([]*SwarmTask, error) { + // Get stream info to check if stream exists + _, err := s.js.StreamInfo(TaskStreamName) + if err != nil { + return nil, fmt.Errorf("failed to get stream info: %w", err) + } + + tasks := make(map[string]*SwarmTask) + + // Create a consumer for scanning all tasks + consumerName := "node-tasks-scan" + _, err = s.js.ConsumerInfo(TaskStreamName, consumerName) + if err != nil { + // Create consumer for scanning + _, err = s.js.AddConsumer(TaskStreamName, &nats.ConsumerConfig{ + Durable: consumerName, + DeliverPolicy: nats.DeliverAllPolicy, + AckPolicy: nats.AckExplicitPolicy, + FilterSubject: TaskStatusSubject + ".>", + }) + if err != nil { + return nil, fmt.Errorf("failed to create scan consumer: %w", err) + } + } + + // When binding to an existing consumer, subject must match the filter subject + sub, err := s.js.PullSubscribe(TaskStatusSubject+".>", "", nats.AckExplicit(), nats.Bind(TaskStreamName, consumerName)) + if err != nil { + return nil, fmt.Errorf("failed to subscribe for node tasks: %w", err) + } + defer sub.Unsubscribe() + + // Fetch recent messages + msgs, err := sub.Fetch(1000, nats.MaxWait(2*time.Second)) + if err != nil && err != nats.ErrTimeout { + return nil, fmt.Errorf("failed to fetch messages: %w", err) + } + + for _, msg := range msgs { + var event TaskEvent + if err := json.Unmarshal(msg.Data, &event); err != nil { + msg.Ack() + continue + } + + // Only process events for this node + 
if event.NodeID == nodeID { + if task, exists := tasks[event.TaskID]; exists { + // Update existing task with latest status + task.Status = event.Status + task.AssignedTo = event.NodeID + } else { + // Create new task entry + task := &SwarmTask{ + ID: event.TaskID, + Status: event.Status, + AssignedTo: event.NodeID, + } + tasks[event.TaskID] = task + } + } + msg.Ack() + } + + result := make([]*SwarmTask, 0, len(tasks)) + for _, task := range tasks { + result = append(result, task) + } + + return result, nil +} + +// DeleteTaskHistory removes event history for a specific task +func (s *TaskLifecycleStore) DeleteTaskHistory(ctx context.Context, taskID string) error { + subject := fmt.Sprintf("%s.%s", TaskStatusSubject, taskID) + + // Create an ephemeral pull consumer to get messages for deletion + sub, err := s.js.PullSubscribe(subject, "", nats.AckExplicit()) + if err != nil { + return fmt.Errorf("failed to create delete consumer: %w", err) + } + defer sub.Unsubscribe() + + // Fetch and delete messages for this task + deletedCount := 0 + for { + msgs, err := sub.Fetch(100, nats.MaxWait(1*time.Second)) + if err == nats.ErrTimeout { + break + } + if err != nil || len(msgs) == 0 { + break + } + + for _, msg := range msgs { + // Get message metadata to find sequence number + meta, err := msg.Metadata() + if err != nil { + msg.Ack() + continue + } + + // Delete the message using the JetStream API + // The API uses stream name and sequence number + err = s.js.DeleteMsg(TaskStreamName, meta.Sequence.Stream) + if err != nil { + // Message might have been deleted already + msg.Ack() + continue + } + deletedCount++ + } + } + + logger.DebugCF("swarm", "Deleted task history", map[string]interface{}{ + "task_id": taskID, + "deleted_msgs": deletedCount, + }) + + return nil +} + +// GetLatestTaskState retrieves the latest state of a task from its history +func (s *TaskLifecycleStore) GetLatestTaskState(ctx context.Context, taskID string) (*SwarmTask, error) { + events, err := 
s.GetTaskHistory(ctx, taskID) + if err != nil { + return nil, err + } + + if len(events) == 0 { + return nil, fmt.Errorf("no history found for task %s", taskID) + } + + // Get the most recent event + latestEvent := events[len(events)-1] + + task := &SwarmTask{ + ID: latestEvent.TaskID, + Status: latestEvent.Status, + AssignedTo: latestEvent.NodeID, + } + + return task, nil +} diff --git a/pkg/swarm/lifecycle_test.go b/pkg/swarm/lifecycle_test.go new file mode 100644 index 000000000..4ce5bb6b6 --- /dev/null +++ b/pkg/swarm/lifecycle_test.go @@ -0,0 +1,285 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewTaskLifecycleStore(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + store := NewTaskLifecycleStore(tn.JS()) + + err := store.Initialize(ctx) + require.NoError(t, err) + + // Verify stream was created + stream := GetStreamInfo(t, tn, TaskStreamName) + assert.NotNil(t, stream) + assert.Equal(t, TaskStreamName, stream.Config.Name) + }) +} + +func TestTaskLifecycleStore_SaveAndGetTaskHistory(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + store := NewTaskLifecycleStore(tn.JS()) + + err := store.Initialize(ctx) + require.NoError(t, err) + + // Create a test task + task := &SwarmTask{ + ID: "test-task-1", + Type: TaskTypeDirect, + Prompt: "Test prompt", + Capability: "test", + Status: TaskPending, + } + + // Save task status + err = store.SaveTaskStatus(task, TaskEventCreated, "Task created") + require.NoError(t, err) + + // Get task history + history, err := store.GetTaskHistory(ctx, task.ID) + require.NoError(t, err) + assert.NotEmpty(t, history) + + // Verify the event + event := history[0] + assert.Equal(t, task.ID, event.TaskID) + assert.Equal(t, 
TaskEventCreated, event.EventType) + assert.Equal(t, "Task created", event.Message) + }) +} + +func TestTaskLifecycleStore_TaskTransitions(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + store := NewTaskLifecycleStore(tn.JS()) + + err := store.Initialize(ctx) + require.NoError(t, err) + + // Purge the stream to ensure clean state + PurgeTestStream(t, tn, TaskStreamName) + + task := CreateTestTask("task-transitions", "direct", "Test task", "test") + + // Simulate task lifecycle + transitions := []struct { + event TaskEventType + status SwarmTaskStatus + message string + }{ + {TaskEventCreated, TaskPending, "Task created"}, + {TaskEventAssigned, TaskAssigned, "Assigned to node-1"}, + {TaskEventStarted, TaskRunning, "Task started"}, + {TaskEventProgress, TaskRunning, "50% complete"}, + {TaskEventCompleted, TaskDone, "Task completed"}, + } + + for _, tt := range transitions { + task.Status = tt.status + err = store.SaveTaskStatus(task, tt.event, tt.message) + require.NoError(t, err) + } + + // Get full history + history, err := store.GetTaskHistory(ctx, task.ID) + require.NoError(t, err) + assert.Len(t, history, 5) + + // Verify order + assert.Equal(t, TaskEventCreated, history[0].EventType) + assert.Equal(t, TaskEventCompleted, history[4].EventType) + }) +} + +func TestTaskLifecycleStore_GetActiveTasks(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + store := NewTaskLifecycleStore(tn.JS()) + + err := store.Initialize(ctx) + require.NoError(t, err) + + // Create tasks with different statuses + tasks := []*SwarmTask{ + CreateTestTask("task-active-1", "direct", "Active task 1", "test"), + CreateTestTask("task-active-2", "direct", "Active task 2", "test"), + CreateTestTask("task-completed", "direct", "Completed task", "test"), + } + + // Save statuses + tasks[0].Status = TaskRunning + err = store.SaveTaskStatus(tasks[0], TaskEventStarted, "Started") + require.NoError(t, err) + + 
tasks[1].Status = TaskAssigned + err = store.SaveTaskStatus(tasks[1], TaskEventAssigned, "Assigned") + require.NoError(t, err) + + tasks[2].Status = TaskDone + err = store.SaveTaskStatus(tasks[2], TaskEventCompleted, "Completed") + require.NoError(t, err) + + // Get active tasks + active, err := store.GetActiveTasks(ctx) + require.NoError(t, err) + + // Should have 2 active tasks + assert.GreaterOrEqual(t, len(active), 2) + }) +} + +func TestTaskLifecycleStore_GetTasksByNode(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + store := NewTaskLifecycleStore(tn.JS()) + + err := store.Initialize(ctx) + require.NoError(t, err) + + // Create tasks for different nodes + node1Tasks := []string{"task-node1-1", "task-node1-2"} + + for _, taskID := range node1Tasks { + task := CreateTestTask(taskID, "direct", "Task for node-1", "test") + task.Status = TaskRunning + task.AssignedTo = "node-1" + err = store.SaveTaskStatus(task, TaskEventStarted, "Started on node-1") + require.NoError(t, err) + } + + // Create task for node-2 + task := CreateTestTask("task-node2-1", "direct", "Task for node-2", "test") + task.Status = TaskRunning + task.AssignedTo = "node-2" + err = store.SaveTaskStatus(task, TaskEventStarted, "Started on node-2") + require.NoError(t, err) + + // Get tasks for node-1 + tasks, err := store.GetTasksByNode(ctx, "node-1") + require.NoError(t, err) + assert.GreaterOrEqual(t, len(tasks), 2) + }) +} + +func TestTaskLifecycleStore_DeleteTaskHistory(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + store := NewTaskLifecycleStore(tn.JS()) + + err := store.Initialize(ctx) + require.NoError(t, err) + + task := CreateTestTask("task-delete-history", "direct", "Task to delete", "test") + + // Save some history + err = store.SaveTaskStatus(task, TaskEventCreated, "Created") + require.NoError(t, err) + + task.Status = TaskRunning + err = store.SaveTaskStatus(task, TaskEventStarted, "Started") + 
require.NoError(t, err) + + // Verify history exists + history, err := store.GetTaskHistory(ctx, task.ID) + require.NoError(t, err) + assert.NotEmpty(t, history) + + // Delete history + err = store.DeleteTaskHistory(ctx, task.ID) + require.NoError(t, err) + + // History should be cleared or stream should be empty for this task + history, err = store.GetTaskHistory(ctx, task.ID) + require.NoError(t, err) + // After delete, history should be empty or not exist + assert.Empty(t, history) + }) +} + +func TestTaskLifecycleStore_GetLatestTaskState(t *testing.T) { + RunTestWithNATS(t, func(tn *TestNATS) { + ctx := context.Background() + store := NewTaskLifecycleStore(tn.JS()) + + err := store.Initialize(ctx) + require.NoError(t, err) + + task := CreateTestTask("task-latest", "direct", "Get latest state", "test") + + // Save multiple status updates + transitions := []struct { + event TaskEventType + status SwarmTaskStatus + }{ + {TaskEventCreated, TaskPending}, + {TaskEventAssigned, TaskAssigned}, + {TaskEventStarted, TaskRunning}, + {TaskEventCompleted, TaskDone}, + } + for _, tt := range transitions { + task.Status = tt.status + err = store.SaveTaskStatus(task, tt.event, fmt.Sprintf("State %s", tt.status)) + require.NoError(t, err) + } + + // Get latest state + latest, err := store.GetLatestTaskState(ctx, task.ID) + require.NoError(t, err) + assert.NotNil(t, latest) + assert.Equal(t, TaskDone, latest.Status) + }) +} + +func TestTaskEventTypes(t *testing.T) { + events := []TaskEventType{ + TaskEventCreated, + TaskEventAssigned, + TaskEventStarted, + TaskEventProgress, + TaskEventCompleted, + TaskEventFailed, + TaskEventRetry, + TaskEventCheckpoint, + } + + for _, event := range events { + assert.NotEmpty(t, string(event), "Event type should not be empty") + } +} + +func TestTaskEvent(t *testing.T) { + event := &TaskEvent{ + EventID: "test-event-1", + TaskID: "test-task-1", + EventType: TaskEventCreated, + Timestamp: time.Now().UnixMilli(), + NodeID: "node-1", + Status: 
TaskPending, + Message: "Test message", + Progress: 0.0, + } + + assert.Equal(t, "test-event-1", event.EventID) + assert.Equal(t, "test-task-1", event.TaskID) + assert.Equal(t, TaskEventCreated, event.EventType) + assert.Equal(t, "node-1", event.NodeID) + assert.Equal(t, TaskPending, event.Status) + assert.Equal(t, "Test message", event.Message) + assert.Equal(t, 0.0, event.Progress) +} diff --git a/pkg/swarm/manager.go b/pkg/swarm/manager.go new file mode 100644 index 000000000..4b1ebc015 --- /dev/null +++ b/pkg/swarm/manager.go @@ -0,0 +1,450 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// Manager orchestrates all swarm components +type Manager struct { + cfg *config.Config + provider providers.LLMProvider + embeddedNATS *EmbeddedNATS + bridge *NATSBridge + temporal *TemporalClient + discovery *Discovery + coordinator *Coordinator + worker *Worker + specialist *SpecialistNode + activities *Activities + lifecycle *TaskLifecycleStore + checkpointStore *CheckpointStore + failoverManager *FailoverManager + contextPool *ContextPool + electionMgr *ElectionManager + roleSwitcher *RoleSwitcher + identity *identity.LoadedIdentity + nodeInfo *NodeInfo + agentLoop *agent.AgentLoop + localBus *bus.MessageBus + enableElection bool // Enable leader election for dynamic role switching +} + +// NewManager creates a new swarm manager +func NewManager(cfg *config.Config, agentLoop *agent.AgentLoop, provider providers.LLMProvider, localBus *bus.MessageBus) *Manager { + swarmCfg := &cfg.Swarm + + // Validate configuration + if err := swarmCfg.Validate(); err != nil { + 
logger.ErrorCF("swarm", "Invalid configuration", map[string]interface{}{"error": err.Error()}) + return nil + } + + // Load or generate identity + identityLoader := identity.NewLoader() + identityLoader.SetConfig(swarmCfg.HID, swarmCfg.SID) + loadedIdentity := identityLoader.LoadOrGenerate() + hid := loadedIdentity.HID + sid := loadedIdentity.SID + + // Set identity on agent loop for cross-instance communication + agentLoop.SetIdentity(hid, sid) + + // Generate node ID if not set + nodeID := swarmCfg.NodeID + if nodeID == "" { + nodeID = fmt.Sprintf("claw-%s", uuid.New().String()[:8]) + } + + // Create node info + nodeInfo := &NodeInfo{ + ID: nodeID, + Role: NodeRole(swarmCfg.Role), + Capabilities: swarmCfg.Capabilities, + Model: cfg.Agents.Defaults.Model, + Status: StatusOnline, + MaxTasks: swarmCfg.MaxConcurrent, + StartedAt: time.Now().UnixMilli(), + Metadata: make(map[string]string), + } + // Store identity in node metadata for discovery + nodeInfo.Metadata["hid"] = hid + nodeInfo.Metadata["sid"] = sid + + m := &Manager{ + cfg: cfg, + provider: provider, + identity: loadedIdentity, + nodeInfo: nodeInfo, + agentLoop: agentLoop, + localBus: localBus, + } + + // Create components + m.bridge = NewNATSBridge(swarmCfg, localBus, nodeInfo) + m.temporal = NewTemporalClient(&swarmCfg.Temporal) + m.discovery = NewDiscovery(m.bridge, nodeInfo, swarmCfg) + + // Create role-specific components + if nodeInfo.Role == RoleCoordinator { + m.coordinator = NewCoordinator(swarmCfg, m.bridge, m.temporal, m.discovery, agentLoop, provider, localBus) + } + if nodeInfo.Role == RoleWorker || nodeInfo.Role == RoleSpecialist { + m.worker = NewWorker(swarmCfg, m.bridge, m.temporal, agentLoop, provider, nodeInfo) + } + if nodeInfo.Role == RoleSpecialist { + // Create specialist node for capability-based routing + m.specialist = NewSpecialistNode(swarmCfg, m.bridge, m.temporal, agentLoop, provider, nodeInfo, m.bridge.js, m.bridge.nc, "") + } + + logger.InfoCF("swarm", "Swarm manager 
initialized with identity", map[string]interface{}{ + "hid": hid, + "sid": sid, + "source": loadedIdentity.Source.String(), + "node_id": nodeID, + "role": string(nodeInfo.Role), + }) + + return m +} + +// Start initializes and starts all swarm components +func (m *Manager) Start(ctx context.Context) error { + swarmCfg := &m.cfg.Swarm + + // Start embedded NATS if configured + if swarmCfg.NATS.Embedded { + m.embeddedNATS = NewEmbeddedNATS(&swarmCfg.NATS) + if err := m.embeddedNATS.Start(); err != nil { + return fmt.Errorf("failed to start embedded NATS: %w", err) + } + // Override URLs to connect to embedded server + swarmCfg.NATS.URLs = []string{m.embeddedNATS.ClientURL()} + } + + // Connect to NATS + if err := m.bridge.Connect(ctx); err != nil { + return fmt.Errorf("failed to connect to NATS: %w", err) + } + + // Start NATS bridge + if err := m.bridge.Start(ctx); err != nil { + return fmt.Errorf("failed to start NATS bridge: %w", err) + } + + // Connect to Temporal (non-fatal if unavailable) + if err := m.temporal.Connect(ctx); err != nil { + logger.WarnCF("swarm", "Temporal connection failed (workflows disabled)", map[string]interface{}{ + "error": err.Error(), + }) + } + + // Create activities instance for LLM-driven task operations + m.activities = NewActivities(m.provider, m.agentLoop, &m.cfg.Swarm, m.nodeInfo) + + // Start Temporal worker with workflow registrations if connected + if m.temporal.IsConnected() { + wfs := []interface{}{SwarmWorkflow} + if err := m.temporal.StartWorker(ctx, wfs, m.activities); err != nil { + logger.WarnCF("swarm", "Failed to start Temporal worker", map[string]interface{}{ + "error": err.Error(), + }) + } + } + + // Start discovery + if err := m.discovery.Start(ctx); err != nil { + return fmt.Errorf("failed to start discovery: %w", err) + } + + // Initialize lifecycle store + m.lifecycle = NewTaskLifecycleStore(m.bridge.js) + if err := m.lifecycle.Initialize(ctx); err != nil { + logger.WarnCF("swarm", "Failed to initialize 
lifecycle store", map[string]interface{}{ + "error": err.Error(), + }) + } + + // Initialize checkpoint store + var err error + m.checkpointStore, err = NewCheckpointStore(m.bridge.js) + if err != nil { + logger.WarnCF("swarm", "Failed to initialize checkpoint store", map[string]interface{}{ + "error": err.Error(), + }) + } + + // Initialize failover manager + if m.lifecycle != nil && m.checkpointStore != nil { + m.failoverManager = NewFailoverManager(m.discovery, m.lifecycle, m.checkpointStore, m.bridge, m.nodeInfo, m.bridge.js) + if err := m.failoverManager.Start(ctx); err != nil { + logger.WarnCF("swarm", "Failed to start failover manager", map[string]interface{}{ + "error": err.Error(), + }) + } + } + + // Initialize shared context pool + m.contextPool = NewContextPool(m.bridge.js, m.nodeInfo.ID, m.identity.HID, m.identity.SID) + if err := m.contextPool.Start(ctx); err != nil { + logger.WarnCF("swarm", "Failed to start context pool", map[string]interface{}{ + "error": err.Error(), + }) + } else { + logger.InfoCF("swarm", "Shared context pool started", map[string]interface{}{ + "hid": m.identity.HID, + "sid": m.identity.SID, + "node_id": m.nodeInfo.ID, + }) + } + + // Initialize leader election if enabled + if m.enableElection { + m.electionMgr = NewElectionManager(m.bridge.nc, m.bridge.js, m.nodeInfo.ID, m.identity.HID, m.identity.SID) + m.roleSwitcher = NewRoleSwitcher(m.electionMgr, m.nodeInfo, m) + + electionCfg := &ElectionConfig{ + ElectionSubject: fmt.Sprintf("picoclaw.election.%s", m.identity.HID), + LeaseDuration: 10 * time.Second, + ElectionInterval: 3 * time.Second, + PreVoteDelay: time.Duration(0), + } + + if err := m.electionMgr.Start(ctx, electionCfg); err != nil { + logger.WarnCF("swarm", "Failed to start election manager", map[string]interface{}{ + "error": err.Error(), + }) + } else { + logger.InfoCF("swarm", "Leader election enabled", map[string]interface{}{ + "node_id": m.nodeInfo.ID, + "hid": m.identity.HID, + }) + } + } + + // Start 
role-specific components + if m.coordinator != nil { + if err := m.coordinator.Start(ctx); err != nil { + return fmt.Errorf("failed to start coordinator: %w", err) + } + } + + if m.worker != nil { + if err := m.worker.Start(ctx); err != nil { + return fmt.Errorf("failed to start worker: %w", err) + } + } + + if m.specialist != nil { + if err := m.specialist.Start(ctx); err != nil { + return fmt.Errorf("failed to start specialist: %w", err) + } + } + + logger.InfoCF("swarm", "Swarm manager started", map[string]interface{}{ + "node_id": m.nodeInfo.ID, + "role": string(m.nodeInfo.Role), + "capabilities": fmt.Sprintf("%v", m.nodeInfo.Capabilities), + "nats": m.bridge.IsConnected(), + "temporal": m.temporal.IsConnected(), + }) + + return nil +} + +// Stop gracefully stops all swarm components +func (m *Manager) Stop() { + logger.InfoC("swarm", "Stopping swarm manager") + + if m.specialist != nil { + m.specialist.Stop() + } + + if m.worker != nil { + m.worker.Stop() + } + + if m.coordinator != nil { + m.coordinator.Stop() + } + + if m.electionMgr != nil { + m.electionMgr.Stop() + } + + if m.failoverManager != nil { + m.failoverManager.Stop() + } + + if m.contextPool != nil { + m.contextPool.Stop() + } + + m.discovery.Stop() + m.temporal.Stop() + if err := m.bridge.Stop(); err != nil { + logger.WarnCF("swarm", "Error stopping NATS bridge", map[string]interface{}{ + "error": err.Error(), + }) + } + + if m.embeddedNATS != nil { + m.embeddedNATS.Stop() + } + + logger.InfoC("swarm", "Swarm manager stopped") +} + +// GetNodeInfo returns this node's information +func (m *Manager) GetNodeInfo() *NodeInfo { + return m.nodeInfo +} + +// GetDiscoveredNodes returns all discovered nodes +func (m *Manager) GetDiscoveredNodes() []*NodeInfo { + return m.discovery.GetNodes("", "") +} + +// IsNATSConnected returns true if connected to NATS +func (m *Manager) IsNATSConnected() bool { + return m.bridge.IsConnected() +} + +// IsTemporalConnected returns true if connected to Temporal +func (m 
*Manager) IsTemporalConnected() bool { + return m.temporal.IsConnected() +} + +// DiscoveredNodeCount returns the count of discovered nodes +func (m *Manager) DiscoveredNodeCount() int { + return m.discovery.NodeCount() +} + +// GetContextPool returns the shared context pool +func (m *Manager) GetContextPool() *ContextPool { + return m.contextPool +} + +// GetIdentity returns the node's identity +func (m *Manager) GetIdentity() *identity.LoadedIdentity { + return m.identity +} + +// IsLeader returns true if this node is the leader (via election) +func (m *Manager) IsLeader() bool { + if m.electionMgr != nil { + return m.electionMgr.IsLeader() + } + return false +} + +// GetLeaderID returns the current leader's node ID +func (m *Manager) GetLeaderID() string { + if m.electionMgr != nil { + return m.electionMgr.GetLeaderID() + } + return "" +} + +// SetElectionEnabled enables or disables leader election +func (m *Manager) SetElectionEnabled(enabled bool) { + m.enableElection = enabled +} + +// handleRoleChange handles dynamic role changes +func (m *Manager) handleRoleChange(newRole NodeRole) { + logger.InfoCF("swarm", "Handling role change", map[string]interface{}{ + "node_id": m.nodeInfo.ID, + "new_role": string(newRole), + "old_role": string(m.nodeInfo.Role), + }) + + // Stop old role components + switch m.nodeInfo.Role { + case RoleCoordinator: + if m.coordinator != nil { + m.coordinator.Stop() + m.coordinator = nil + } + case RoleWorker: + if m.worker != nil { + m.worker.Stop() + m.worker = nil + } + case RoleSpecialist: + if m.specialist != nil { + m.specialist.Stop() + m.specialist = nil + } + } + + // Update node role + m.nodeInfo.Role = newRole + + // Start new role components + swarmCfg := &m.cfg.Swarm + switch newRole { + case RoleCoordinator: + m.coordinator = NewCoordinator(swarmCfg, m.bridge, m.temporal, m.discovery, m.agentLoop, m.provider, m.localBus) + if m.coordinator != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 
+ defer cancel() + if err := m.coordinator.Start(ctx); err != nil { + logger.ErrorCF("swarm", "Failed to start coordinator after role change", map[string]interface{}{ + "error": err.Error(), + }) + } + } + case RoleWorker: + m.worker = NewWorker(swarmCfg, m.bridge, m.temporal, m.agentLoop, m.provider, m.nodeInfo) + if m.worker != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := m.worker.Start(ctx); err != nil { + logger.ErrorCF("swarm", "Failed to start worker after role change", map[string]interface{}{ + "error": err.Error(), + }) + } + } + case RoleSpecialist: + m.worker = NewWorker(swarmCfg, m.bridge, m.temporal, m.agentLoop, m.provider, m.nodeInfo) + m.specialist = NewSpecialistNode(swarmCfg, m.bridge, m.temporal, m.agentLoop, m.provider, m.nodeInfo, m.bridge.js, m.bridge.nc, "") + if m.worker != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := m.worker.Start(ctx); err != nil { + logger.ErrorCF("swarm", "Failed to start worker after role change", map[string]interface{}{ + "error": err.Error(), + }) + } + } + if m.specialist != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := m.specialist.Start(ctx); err != nil { + logger.ErrorCF("swarm", "Failed to start specialist after role change", map[string]interface{}{ + "error": err.Error(), + }) + } + } + } + + logger.InfoCF("swarm", "Role change completed", map[string]interface{}{ + "node_id": m.nodeInfo.ID, + "role": string(newRole), + }) +} diff --git a/pkg/swarm/messaging.go b/pkg/swarm/messaging.go new file mode 100644 index 000000000..ca9a1cca7 --- /dev/null +++ b/pkg/swarm/messaging.go @@ -0,0 +1,412 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/nats-io/nats.go" + 
"github.com/sipeed/picoclaw/pkg/logger" +) + +// MessagingAPI provides a high-level API for inter-shrimp communication +// It handles both same-H-id (within tenant) and cross-H-id (cross-tenant) messaging +type MessagingAPI struct { + bridge *NATSBridge + localHID string + localSID string + nodeID string + mu sync.RWMutex + + // Message handlers by type + handlers map[string][]func(*InterShrimpMessage) +} + +// InterShrimpMessage represents a message between shrimp instances +type InterShrimpMessage struct { + // From identifies the sender + FromHID string `json:"from_hid"` + FromSID string `json:"from_sid"` + FromNodeID string `json:"from_node_id"` + + // To identifies the recipient (empty for broadcast within same H-id) + ToHID string `json:"to_hid,omitempty"` + ToSID string `json:"to_sid,omitempty"` + ToNodeID string `json:"to_node_id,omitempty"` + + // Type is the message type for routing + Type string `json:"type"` + + // Payload contains the message data + Payload map[string]interface{} `json:"payload"` + + // Timestamp when message was sent + Timestamp int64 `json:"timestamp"` + + // ID uniquely identifies this message + ID string `json:"id"` + + // InResponseTo links this message to a previous message + InResponseTo string `json:"in_response_to,omitempty"` +} + +// MessagingConfig configures the messaging API +type MessagingConfig struct { + // AllowCrossHID enables cross-H-id communication + AllowCrossHID bool + + // RequireAuth requires authorization for cross-H-id messages + RequireAuth bool + + // AllowedHIDs lists H-ids allowed to communicate with us + AllowedHIDs []string +} + +// NewMessagingAPI creates a new messaging API +func NewMessagingAPI(bridge *NATSBridge, hid, sid, nodeID string) *MessagingAPI { + api := &MessagingAPI{ + bridge: bridge, + localHID: hid, + localSID: sid, + nodeID: nodeID, + handlers: make(map[string][]func(*InterShrimpMessage)), + } + + // Subscribe to messages for this node + go api.subscribeToMessages() + + return api +} + 
// Subscribe registers a handler for a specific message type
func (m *MessagingAPI) Subscribe(messageType string, handler func(*InterShrimpMessage)) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.handlers[messageType] == nil {
		m.handlers[messageType] = make([]func(*InterShrimpMessage), 0)
	}
	m.handlers[messageType] = append(m.handlers[messageType], handler)

	logger.DebugCF("swarm", "Registered message handler", map[string]interface{}{
		"type": messageType,
	})
}

// Unsubscribe removes ALL handlers registered for the given message type.
// (The original comment claimed this was "not implemented"; the code below
// does delete the handlers. Individual handlers cannot be removed.)
func (m *MessagingAPI) Unsubscribe(messageType string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.handlers, messageType)
}

// SendBroadcast sends a message to all nodes in the same H-id.
// NOTE(review): the "*" below is published as a literal subject token (NATS
// does not expand wildcards on publish); it matches the 5-token broadcast
// subscription pattern in subscribeToMessages by token count — confirm that
// pairing if subject layouts change.
func (m *MessagingAPI) SendBroadcast(ctx context.Context, messageType string, payload map[string]interface{}) error {
	msg := &InterShrimpMessage{
		FromHID:    m.localHID,
		FromSID:    m.localSID,
		FromNodeID: m.nodeID,
		Type:       messageType,
		Payload:    payload,
		Timestamp:  time.Now().UnixMilli(),
		ID:         generateMessageID(),
	}

	// For same-H-id broadcast, use a special subject
	subject := fmt.Sprintf("picoclaw.msg.%s.*.%s", m.localHID, messageType)

	data, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	if err := m.bridge.conn.Publish(subject, data); err != nil {
		return fmt.Errorf("failed to publish broadcast: %w", err)
	}

	logger.DebugCF("swarm", "Sent broadcast message", map[string]interface{}{
		"type":   messageType,
		"msg_id": msg.ID,
	})

	return nil
}

// SendToNode sends a message to a specific node (same or different H-id).
// The subject narrows with each provided identifier: node-directed when
// targetNodeID is set, S-id-scoped when only targetSID is set, otherwise
// H-id-scoped (literal "*" tokens stand in for the unspecified levels).
func (m *MessagingAPI) SendToNode(ctx context.Context, targetHID, targetSID, targetNodeID, messageType string, payload map[string]interface{}) error {
	msg := &InterShrimpMessage{
		FromHID:    m.localHID,
		FromSID:    m.localSID,
		FromNodeID: m.nodeID,
		ToHID:      targetHID,
		ToSID:      targetSID,
		ToNodeID:   targetNodeID,
		Type:       messageType,
		Payload:    payload,
		Timestamp:  time.Now().UnixMilli(),
		ID:         generateMessageID(),
	}

	var subject string
	if targetNodeID != "" {
		// Direct to node
		subject = fmt.Sprintf("picoclaw.msg.%s.%s.node.%s.%s", targetHID, targetSID, targetNodeID, messageType)
	} else if targetSID != "" {
		// To any node with this S-id
		subject = fmt.Sprintf("picoclaw.msg.%s.%s.*.%s", targetHID, targetSID, messageType)
	} else {
		// To any node in this H-id
		subject = fmt.Sprintf("picoclaw.msg.%s.*.%s", targetHID, messageType)
	}

	data, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	if err := m.bridge.conn.Publish(subject, data); err != nil {
		return fmt.Errorf("failed to publish message: %w", err)
	}

	logger.DebugCF("swarm", "Sent direct message", map[string]interface{}{
		"target": fmt.Sprintf("%s/%s/%s", targetHID, targetSID, targetNodeID),
		"type":   messageType,
		"msg_id": msg.ID,
	})

	return nil
}

// SendReply sends a reply to a previous message, addressed back to the
// original sender's node subject with type "<original>.reply".
func (m *MessagingAPI) SendReply(ctx context.Context, originalMsg *InterShrimpMessage, payload map[string]interface{}) error {
	msg := &InterShrimpMessage{
		FromHID:      m.localHID,
		FromSID:      m.localSID,
		FromNodeID:   m.nodeID,
		ToHID:        originalMsg.FromHID,
		ToSID:        originalMsg.FromSID,
		ToNodeID:     originalMsg.FromNodeID,
		Type:         originalMsg.Type + ".reply",
		Payload:      payload,
		Timestamp:    time.Now().UnixMilli(),
		ID:           generateMessageID(),
		InResponseTo: originalMsg.ID,
	}

	subject := fmt.Sprintf("picoclaw.msg.%s.%s.node.%s.%s",
		msg.ToHID, msg.ToSID, msg.ToNodeID, msg.Type)

	data, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	if err := m.bridge.conn.Publish(subject, data); err != nil {
		return fmt.Errorf("failed to publish reply: %w", err)
	}

	return nil
}

// Request sends a message and waits for a response on a private inbox.
//
// NOTE(review): the responder only learns the inbox via the NATS reply subject
// (PublishRequest), but the node-side subscription in subscribeToMessages does
// not expose msg.Reply to handlers, and SendReply publishes to the node
// subject rather than the inbox — so as written, Request will time out unless
// some responder calls msg.Respond directly. Confirm the intended pairing.
func (m *MessagingAPI) Request(ctx context.Context, targetHID, targetSID, targetNodeID, messageType string, payload map[string]interface{}, timeout time.Duration) (*InterShrimpMessage, error) {
	// Create inbox for response
	inbox := m.bridge.conn.NewRespInbox()
	responseCh := make(chan *InterShrimpMessage, 1)

	// Subscribe to responses
	sub, err := m.bridge.conn.Subscribe(inbox, func(msg *nats.Msg) {
		var response InterShrimpMessage
		if err := json.Unmarshal(msg.Data, &response); err != nil {
			return
		}
		select {
		case responseCh <- &response:
		default:
		}
	})
	if err != nil {
		return nil, fmt.Errorf("failed to subscribe to response: %w", err)
	}
	defer sub.Unsubscribe()

	// Send request with reply-to inbox
	requestMsg := &InterShrimpMessage{
		FromHID:    m.localHID,
		FromSID:    m.localSID,
		FromNodeID: m.nodeID,
		ToHID:      targetHID,
		ToSID:      targetSID,
		ToNodeID:   targetNodeID,
		Type:       messageType,
		Payload:    payload,
		Timestamp:  time.Now().UnixMilli(),
		ID:         generateMessageID(),
	}

	subject := fmt.Sprintf("picoclaw.msg.%s.%s.node.%s.%s", targetHID, targetSID, targetNodeID, messageType)

	data, err := json.Marshal(requestMsg)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	if err := m.bridge.conn.PublishRequest(subject, inbox, data); err != nil {
		return nil, fmt.Errorf("failed to publish request: %w", err)
	}

	// Wait for response
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	select {
	case response := <-responseCh:
		return response, nil
	case <-ctx.Done():
		return nil, fmt.Errorf("request timeout")
	}
}

// subscribeToMessages subscribes to messages for this node: one node-directed
// subscription and one H-id-wide broadcast subscription.
// NOTE(review): both Subscribe errors are discarded — a failed subscription
// leaves this node silently deaf to inter-shrimp messages.
func (m *MessagingAPI) subscribeToMessages() {
	// Subscribe to messages for this specific node
	nodeSubject := fmt.Sprintf("picoclaw.msg.%s.%s.node.%s.*", m.localHID, m.localSID, m.nodeID)
	m.bridge.conn.Subscribe(nodeSubject, func(msg *nats.Msg) {
		var message InterShrimpMessage
		if err := json.Unmarshal(msg.Data, &message); err != nil {
			logger.WarnCF("swarm", "Failed to unmarshal message", map[string]interface{}{
				"error": err.Error(),
			})
			return
		}
		m.dispatchMessage(&message)
	})

	// Subscribe to broadcast messages for our H-id
	broadcastSubject := fmt.Sprintf("picoclaw.msg.%s.*.*", m.localHID)
	m.bridge.conn.Subscribe(broadcastSubject, func(msg *nats.Msg) {
		var message InterShrimpMessage
		if err := json.Unmarshal(msg.Data, &message); err != nil {
			return
		}
		// Skip messages from ourselves
		if message.FromNodeID == m.nodeID {
			return
		}
		m.dispatchMessage(&message)
	})

	logger.InfoCF("swarm", "Messaging API subscribed", map[string]interface{}{
		"hid":     m.localHID,
		"sid":     m.localSID,
		"node_id": m.nodeID,
	})
}

// dispatchMessage dispatches a message to registered handlers.
// Each handler runs in its own goroutine with panic recovery, so handler
// ordering and completion are not guaranteed relative to the caller.
func (m *MessagingAPI) dispatchMessage(msg *InterShrimpMessage) {
	m.mu.RLock()
	handlers := m.handlers[msg.Type]
	m.mu.RUnlock()

	logger.DebugCF("swarm", "Received inter-shrimp message", map[string]interface{}{
		"from":     fmt.Sprintf("%s/%s", msg.FromHID, msg.FromSID),
		"type":     msg.Type,
		"handlers": len(handlers),
	})

	for _, handler := range handlers {
		go func(h func(*InterShrimpMessage)) {
			defer func() {
				if r := recover(); r != nil {
					logger.ErrorCF("swarm", "Message handler panic", map[string]interface{}{
						"error": fmt.Sprintf("%v", r),
					})
				}
			}()
			h(msg)
		}(handler)
	}
}

// Standard message types for MessagingAPI
const (
	// MessageTypeTask is for task-related messages
	MessageTypeTask = "task"

	// MessageTypeStatus is for status updates
	MessageTypeStatus = "status"

	// MessageTypeContext is for shared context updates
	MessageTypeContext = "context"

	// MessageTypeSync is for synchronization
	MessageTypeSync = "sync"
)

// TaskMessage is a task-related message payload
type TaskMessage struct {
	TaskID   string                 `json:"task_id"`
	Status   string                 `json:"status"`
	Result   string                 `json:"result,omitempty"`
	Error    string                 `json:"error,omitempty"`
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}

// StatusMessage is a status update payload
type StatusMessage struct {
	Load         float64 `json:"load"`
	TasksRunning int     `json:"tasks_running"`
	Status       string  `json:"status"`
}

// ContextMessage is a context update payload
type ContextMessage struct {
	TaskID string `json:"task_id"`
	Key    string `json:"key"`
	Value  string `json:"value"`
	Action string `json:"action"` // set, delete, merge
}

// BroadcastTaskStatus broadcasts a task status update
func (m *MessagingAPI) BroadcastTaskStatus(ctx context.Context, taskID, status string, result string, resultErr error) error {
	payload := map[string]interface{}{
		"task_id": taskID,
		"status":  status,
	}
	if result != "" {
		payload["result"] = result
	}
	if resultErr != nil {
		payload["error"] = resultErr.Error()
	}
	return m.SendBroadcast(ctx, MessageTypeTask, payload)
}

// BroadcastStatus broadcasts this node's status
func (m *MessagingAPI) BroadcastStatus(ctx context.Context, load float64, tasksRunning int, status string) error {
	payload := map[string]interface{}{
		"load":          load,
		"tasks_running": tasksRunning,
		"status":        status,
	}
	return m.SendBroadcast(ctx, MessageTypeStatus, payload)
}

// PublishContextUpdate publishes a context update to the swarm
func (m *MessagingAPI) PublishContextUpdate(ctx context.Context, taskID, key, value, action string) error {
	payload := map[string]interface{}{
		"task_id": taskID,
		"key":     key,
		"value":   value,
		"action":  action,
	}
	return m.SendBroadcast(ctx, MessageTypeContext, payload)
}

// ---- file boundary: pkg/swarm/nats.go (reconstructed from mangled diff) ----

// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors

package swarm

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"sync"
	"time"

	"github.com/nats-io/nats.go"
	"github.com/sipeed/picoclaw/pkg/bus"
	"github.com/sipeed/picoclaw/pkg/config"
	"github.com/sipeed/picoclaw/pkg/logger"
)

// NATS subject patterns
const (
	SubjectHeartbeat         = "picoclaw.swarm.heartbeat.%s"      // {node_id}
	SubjectDiscoveryAnnounce = "picoclaw.swarm.discovery.announce"
	SubjectDiscoveryQuery    = "picoclaw.swarm.discovery.query"
	SubjectTaskAssign        = "picoclaw.swarm.task.assign.%s"    // {node_id}
	SubjectTaskBroadcast     = "picoclaw.swarm.task.broadcast.%s" // {capability}
	SubjectTaskResult        = "picoclaw.swarm.task.result.%s"    // {task_id}
	SubjectTaskProgress      = "picoclaw.swarm.task.progress.%s"  // {task_id}
	SubjectSystemShutdown    = "picoclaw.swarm.system.shutdown.%s" // {node_id}
)

// NATSBridge connects local MessageBus to NATS for swarm communication
type NATSBridge struct {
	// NOTE(review): conn and nc are set to the same *nats.Conn in Connect();
	// consider collapsing the duplicate handles.
	conn     *nats.Conn
	js       nats.JetStreamContext
	nc       *nats.Conn
	localBus *bus.MessageBus
	nodeInfo *NodeInfo
	cfg      *config.SwarmConfig
	subs     []*nats.Subscription
	mu       sync.RWMutex
	running  bool

	// Callbacks
	onTaskReceived func(*SwarmTask)
	onNodeJoin     func(*NodeInfo)
	onNodeLeave    func(nodeID string)
}

// NewNATSBridge creates a new NATS bridge
func NewNATSBridge(cfg *config.SwarmConfig, localBus *bus.MessageBus, nodeInfo *NodeInfo) *NATSBridge {
	return &NATSBridge{
		localBus: localBus,
		nodeInfo: nodeInfo,
		cfg:      cfg,
		subs:     make([]*nats.Subscription, 0),
	}
}

// Connect establishes connection to NATS server(s).
// NOTE(review): ctx is accepted but unused — nats.Connect here does not honor
// caller cancellation.
func (nb *NATSBridge) Connect(ctx context.Context) error {
	opts := []nats.Option{
		nats.Name(fmt.Sprintf("picoclaw-%s", nb.nodeInfo.ID)),
		nats.ReconnectWait(2 * time.Second),
		nats.MaxReconnects(-1), // Unlimited reconnects
		nats.DisconnectErrHandler(func(nc *nats.Conn, err error) {
			logger.WarnCF("swarm", "NATS disconnected", map[string]interface{}{
				"error": fmt.Sprintf("%v", err),
			})
		}),
		nats.ReconnectHandler(func(nc *nats.Conn) {
			logger.InfoCF("swarm", "NATS reconnected", map[string]interface{}{
				"url": nc.ConnectedUrl(),
			})
		}),
	}

	if nb.cfg.NATS.Credentials != "" {
		opts = append(opts, nats.UserCredentials(nb.cfg.NATS.Credentials))
	}

	urls := nats.DefaultURL
	if len(nb.cfg.NATS.URLs) > 0 {
		urls = strings.Join(nb.cfg.NATS.URLs, ",")
	}

	conn, err := nats.Connect(urls, opts...)
	if err != nil {
		return fmt.Errorf("failed to connect to NATS: %w", err)
	}

	nb.conn = conn
	nb.nc = conn

	// Create JetStream context
	js, err := conn.JetStream()
	if err != nil {
		return fmt.Errorf("failed to create JetStream context: %w", err)
	}
	nb.js = js
	logger.InfoCF("swarm", "Connected to NATS", map[string]interface{}{
		"url": conn.ConnectedUrl(),
	})

	return nil
}

// Start begins listening for swarm messages: per-node task assignments,
// capability queue-group broadcasts, and discovery query/announce traffic,
// then announces this node's presence.
func (nb *NATSBridge) Start(ctx context.Context) error {
	nb.mu.Lock()
	nb.running = true
	nb.mu.Unlock()

	// Subscribe to task assignments for this node
	taskSub, err := nb.conn.Subscribe(
		fmt.Sprintf(SubjectTaskAssign, nb.nodeInfo.ID),
		nb.handleTaskAssignment,
	)
	if err != nil {
		return fmt.Errorf("failed to subscribe to task assignments: %w", err)
	}
	nb.subs = append(nb.subs, taskSub)

	// Subscribe to capability-based broadcast tasks using queue groups for load balancing
	for _, cap := range nb.nodeInfo.Capabilities {
		broadcastSub, err := nb.conn.QueueSubscribe(
			fmt.Sprintf(SubjectTaskBroadcast, cap),
			"workers", // Queue group for load balancing
			nb.handleTaskBroadcast,
		)
		if err != nil {
			return fmt.Errorf("failed to subscribe to broadcast %s: %w", cap, err)
		}
		nb.subs = append(nb.subs, broadcastSub)
	}

	// Subscribe to discovery queries
	discoverySub, err := nb.conn.Subscribe(SubjectDiscoveryQuery, nb.handleDiscoveryQuery)
	if err != nil {
		return fmt.Errorf("failed to subscribe to discovery: %w", err)
	}
	nb.subs = append(nb.subs, discoverySub)

	// Subscribe to discovery announcements
	announceSub, err := nb.conn.Subscribe(SubjectDiscoveryAnnounce, nb.handleDiscoveryAnnounce)
	if err != nil {
		return fmt.Errorf("failed to subscribe to announcements: %w", err)
	}
	nb.subs = append(nb.subs, announceSub)

	// Announce our presence
	if err := nb.AnnouncePresence(); err != nil {
		logger.WarnCF("swarm", "Failed to announce presence", map[string]interface{}{
			"error": err.Error(),
		})
	}

	logger.InfoCF("swarm", "NATS bridge started", map[string]interface{}{
		"node_id":      nb.nodeInfo.ID,
		"capabilities": fmt.Sprintf("%v", nb.nodeInfo.Capabilities),
	})

	return nil
}

// Stop gracefully stops the bridge: unsubscribes everything, publishes a
// shutdown notice, then drains the connection.
func (nb *NATSBridge) Stop() error {
	nb.mu.Lock()
	nb.running = false
	nb.mu.Unlock()

	// Unsubscribe all
	for _, sub := range nb.subs {
		if err := sub.Unsubscribe(); err != nil {
			logger.WarnCF("swarm", "Failed to unsubscribe", map[string]interface{}{
				"error": err.Error(),
			})
		}
	}

	if nb.conn != nil && !nb.conn.IsClosed() {
		// Announce shutdown
		shutdownSubject := fmt.Sprintf(SubjectSystemShutdown, nb.nodeInfo.ID)
		_ = nb.conn.Publish(shutdownSubject, []byte(nb.nodeInfo.ID))

		// Drain and close
		return nb.conn.Drain()
	}
	return nil
}

// IsConnected returns true if connected to NATS
func (nb *NATSBridge) IsConnected() bool {
	return nb.conn != nil && nb.conn.IsConnected()
}

// AnnouncePresence broadcasts this node's presence
func (nb *NATSBridge) AnnouncePresence() error {
	announce := DiscoveryAnnounce{
		Node:      *nb.nodeInfo,
		Timestamp: time.Now().UnixMilli(),
	}
	data, err := json.Marshal(announce)
	if err != nil {
		return fmt.Errorf("failed to marshal announcement: %w", err)
	}
	return nb.conn.Publish(SubjectDiscoveryAnnounce, data)
}

// PublishTask sends a task to a specific node or broadcasts by capability
func (nb *NATSBridge) PublishTask(task *SwarmTask) error {
	data, err := json.Marshal(task)
	if err != nil {
		return fmt.Errorf("failed to marshal task: %w", err)
	}

	var subject string
	if task.AssignedTo != "" {
		subject = fmt.Sprintf(SubjectTaskAssign, task.AssignedTo)
	} else {
		subject = fmt.Sprintf(SubjectTaskBroadcast, task.Capability)
	}

	return nb.conn.Publish(subject, data)
}

// PublishTaskResult publishes the result of a completed task
func (nb *NATSBridge) PublishTaskResult(result *TaskResult) error {
	data, err := json.Marshal(result)
	if err != nil {
		return fmt.Errorf("failed to marshal task result: %w", err)
	}
	subject := fmt.Sprintf(SubjectTaskResult, result.TaskID)
	return nb.conn.Publish(subject, data)
}

// PublishTaskProgress publishes progress update for a task
func (nb *NATSBridge) PublishTaskProgress(progress *TaskProgress) error {
	data, err := json.Marshal(progress)
	if err != nil {
		return fmt.Errorf("failed to marshal progress: %w", err)
	}
	subject := fmt.Sprintf(SubjectTaskProgress, progress.TaskID)
	return nb.conn.Publish(subject, data)
}

// PublishHeartbeat publishes a heartbeat message
func (nb *NATSBridge) PublishHeartbeat(hb *Heartbeat) error {
	data, err := json.Marshal(hb)
	if err != nil {
		return fmt.Errorf("failed to marshal heartbeat: %w", err)
	}
	subject := fmt.Sprintf(SubjectHeartbeat, hb.NodeID)
	return nb.conn.Publish(subject, data)
}

// SubscribeTaskResult subscribes to results for a specific task.
// Malformed payloads are silently dropped.
func (nb *NATSBridge) SubscribeTaskResult(taskID string, handler func(*TaskResult)) (*nats.Subscription, error) {
	subject := fmt.Sprintf(SubjectTaskResult, taskID)
	return nb.conn.Subscribe(subject, func(msg *nats.Msg) {
		var result TaskResult
		if err := json.Unmarshal(msg.Data, &result); err == nil {
			handler(&result)
		}
	})
}

// SubscribeHeartbeat subscribes to heartbeats from a specific node
func (nb *NATSBridge) SubscribeHeartbeat(nodeID string, handler func(*Heartbeat)) (*nats.Subscription, error) {
	subject := fmt.Sprintf(SubjectHeartbeat, nodeID)
	return nb.conn.Subscribe(subject, func(msg *nats.Msg) {
		var hb Heartbeat
		if err := json.Unmarshal(msg.Data, &hb); err == nil {
			handler(&hb)
		}
	})
}

// SubscribeAllHeartbeats subscribes to all heartbeat messages
func (nb *NATSBridge) SubscribeAllHeartbeats(handler func(*Heartbeat)) (*nats.Subscription, error) {
	return nb.conn.Subscribe("picoclaw.swarm.heartbeat.*", func(msg *nats.Msg) {
		var hb Heartbeat
		if err := json.Unmarshal(msg.Data, &hb); err == nil {
			handler(&hb)
		}
	})
}

// SubscribeShutdown subscribes to shutdown notices from a specific node.
// The payload is the raw node ID string (see Stop()).
func (nb *NATSBridge) SubscribeShutdown(handler func(nodeID string)) (*nats.Subscription, error) {
	return nb.conn.Subscribe("picoclaw.swarm.system.shutdown.*", func(msg *nats.Msg) {
		handler(string(msg.Data))
	})
}

// RequestDiscovery sends a discovery query and collects responses.
// NOTE(review): collection is a fixed time.Sleep(timeout) — the call always
// blocks for the full timeout even if all nodes respond immediately.
func (nb *NATSBridge) RequestDiscovery(query *DiscoveryQuery, timeout time.Duration) ([]*NodeInfo, error) {
	data, err := json.Marshal(query)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal discovery query: %w", err)
	}

	var nodes []*NodeInfo
	var mu sync.Mutex

	inbox := nb.conn.NewRespInbox()
	sub, err := nb.conn.Subscribe(inbox, func(msg *nats.Msg) {
		var node NodeInfo
		if err := json.Unmarshal(msg.Data, &node); err == nil {
			mu.Lock()
			nodes = append(nodes, &node)
			mu.Unlock()
		}
	})
	if err != nil {
		return nil, fmt.Errorf("failed to subscribe to discovery inbox: %w", err)
	}

	if err := nb.conn.PublishRequest(SubjectDiscoveryQuery, inbox, data); err != nil {
		sub.Unsubscribe()
		return nil, fmt.Errorf("failed to publish discovery query: %w", err)
	}

	time.Sleep(timeout)
	sub.Unsubscribe()

	return nodes, nil
}

// SetOnTaskReceived sets the callback for when a task is received
func (nb *NATSBridge) SetOnTaskReceived(handler func(*SwarmTask)) {
	nb.mu.Lock()
	defer nb.mu.Unlock()
	nb.onTaskReceived = handler
}

// SetOnNodeJoin sets the callback for when a node joins
func (nb *NATSBridge) SetOnNodeJoin(handler func(*NodeInfo)) {
	nb.mu.Lock()
	defer nb.mu.Unlock()
	nb.onNodeJoin = handler
}

// SetOnNodeLeave sets the callback for when a node leaves.
// NOTE(review): onNodeLeave is stored but no handler in this file invokes it —
// confirm the shutdown subscription is wired to it elsewhere.
func (nb *NATSBridge) SetOnNodeLeave(handler func(nodeID string)) {
	nb.mu.Lock()
	defer nb.mu.Unlock()
	nb.onNodeLeave = handler
}

// Message handlers

// handleTaskAssignment forwards a node-directed task to the onTaskReceived callback.
func (nb *NATSBridge) handleTaskAssignment(msg *nats.Msg) {
	var task SwarmTask
	if err := json.Unmarshal(msg.Data, &task); err != nil {
		logger.ErrorCF("swarm", "Failed to unmarshal task", map[string]interface{}{
			"error": err.Error(),
		})
		return
	}

	logger.InfoCF("swarm", "Received task assignment", map[string]interface{}{
		"task_id":    task.ID,
		"capability": task.Capability,
	})

	nb.mu.RLock()
	handler := nb.onTaskReceived
	nb.mu.RUnlock()

	if handler != nil {
		handler(&task)
	}
}

// handleTaskBroadcast forwards a capability-broadcast task to the onTaskReceived callback.
func (nb *NATSBridge) handleTaskBroadcast(msg *nats.Msg) {
	var task SwarmTask
	if err := json.Unmarshal(msg.Data, &task); err != nil {
		return
	}

	logger.InfoCF("swarm", "Received broadcast task", map[string]interface{}{
		"task_id":    task.ID,
		"capability": task.Capability,
	})

	nb.mu.RLock()
	handler := nb.onTaskReceived
	nb.mu.RUnlock()

	if handler != nil {
		handler(&task)
	}
}

// handleDiscoveryQuery answers a discovery query with this node's info when
// the query's role/capability filters match.
func (nb *NATSBridge) handleDiscoveryQuery(msg *nats.Msg) {
	var query DiscoveryQuery
	if err := json.Unmarshal(msg.Data, &query); err != nil {
		return
	}

	// Check if we match the query criteria
	if query.Role != "" && query.Role != nb.nodeInfo.Role {
		return
	}
	if query.Capability != "" && !containsCapability(nb.nodeInfo.Capabilities, query.Capability) {
		return
	}

	// Reply with our node info
	response, err := json.Marshal(nb.nodeInfo)
	if err != nil {
		return
	}
	if err := msg.Respond(response); err != nil {
		logger.WarnCF("swarm", "Failed to respond to discovery query", map[string]interface{}{
			"error": err.Error(),
		})
	}
}

// handleDiscoveryAnnounce records a peer's announcement via the onNodeJoin
// callback, ignoring our own announcements.
func (nb *NATSBridge) handleDiscoveryAnnounce(msg *nats.Msg) {
	var announce DiscoveryAnnounce
	if err := json.Unmarshal(msg.Data, &announce); err != nil {
		return
	}

	// Skip our own announcements
	if announce.Node.ID == nb.nodeInfo.ID {
		return
	}

	logger.InfoCF("swarm", "Node joined swarm", map[string]interface{}{
		"node_id":      announce.Node.ID,
		"role":         string(announce.Node.Role),
		"capabilities": fmt.Sprintf("%v", announce.Node.Capabilities),
	})

	nb.mu.RLock()
	handler := nb.onNodeJoin
	nb.mu.RUnlock()

	if handler != nil {
		handler(&announce.Node)
	}
}

// containsCapability checks if a capability is in the list
func containsCapability(caps []string, target string) bool {
	for _, c := range caps {
		if c == target {
			return true
		}
	}
	return false
}

// ---- file boundary: pkg/swarm/nats_test.go (reconstructed from mangled diff; truncated at chunk end) ----

// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors

package swarm

import (
	"context"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)

func TestNATSBridge_Connect(t *testing.T) {
	_, url, cleanup := startTestNATS(t)
	defer cleanup()

	tests := []struct {
		name string
		fn   func(t *testing.T)
	}{
		{
			name: "successful connect",
			fn: func(t *testing.T) {
				node := newTestNodeInfo("connect-test", RoleWorker, []string{"general"}, 4)
				bridge := connectTestBridge(t, url, node)
				defer bridge.Stop()

				if !bridge.IsConnected() {
					t.Error("IsConnected() = false, want true")
				}
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, tt.fn)
	}
}

func TestNATSBridge_PublishSubscribeTaskAssignment(t *testing.T) {
	_, url, cleanup := startTestNATS(t)
	defer cleanup()

	tests := []struct {
		name       string
		task       *SwarmTask
		workerID   string
		workerCaps []string
	}{
		{
			name: "direct task to specific node",
			task: &SwarmTask{
				ID:         "task-direct01",
				Type:       TaskTypeDirect,
				Capability: "code",
				Prompt:     "write a function",
Status: TaskPending, + Priority: 1, + Timeout: 5000, + }, + workerID: "worker-1", + workerCaps: []string{"code"}, + }, + { + name: "task with context data", + task: &SwarmTask{ + ID: "task-ctx00001", + Type: TaskTypeDirect, + Capability: "research", + Prompt: "find information", + Status: TaskPending, + Context: map[string]interface{}{"key": "value", "num": float64(42)}, + Timeout: 5000, + }, + workerID: "worker-2", + workerCaps: []string{"research"}, + }, + { + name: "task with all fields set", + task: &SwarmTask{ + ID: "task-full0001", + WorkflowID: "wf-1", + ParentID: "task-parent", + Type: TaskTypeDirect, + Priority: 3, + Capability: "code", + Prompt: "complex task", + Context: map[string]interface{}{"a": "b"}, + Status: TaskAssigned, + CreatedAt: 1000000, + Timeout: 30000, + }, + workerID: "worker-3", + workerCaps: []string{"code"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create subscriber (worker) + workerNode := newTestNodeInfo(tt.workerID, RoleWorker, tt.workerCaps, 4) + workerBridge := connectTestBridge(t, url, workerNode) + defer workerBridge.Stop() + + var received atomic.Value + workerBridge.SetOnTaskReceived(func(task *SwarmTask) { + received.Store(task) + }) + + if err := workerBridge.Start(context.Background()); err != nil { + t.Fatalf("worker Start() error: %v", err) + } + + // Create publisher (coordinator) + coordNode := newTestNodeInfo("coord-pub", RoleCoordinator, nil, 1) + coordBridge := connectTestBridge(t, url, coordNode) + defer coordBridge.Stop() + + // Set AssignedTo and publish + tt.task.AssignedTo = tt.workerID + if err := coordBridge.PublishTask(tt.task); err != nil { + t.Fatalf("PublishTask() error: %v", err) + } + + // Wait for delivery + ok := waitFor(t, 2*time.Second, func() bool { + return received.Load() != nil + }) + if !ok { + t.Fatal("timed out waiting for task delivery") + } + + got := received.Load().(*SwarmTask) + if got.ID != tt.task.ID { + t.Errorf("received ID = %q, want %q", 
got.ID, tt.task.ID) + } + if got.Prompt != tt.task.Prompt { + t.Errorf("received Prompt = %q, want %q", got.Prompt, tt.task.Prompt) + } + if got.Capability != tt.task.Capability { + t.Errorf("received Capability = %q, want %q", got.Capability, tt.task.Capability) + } + if got.Priority != tt.task.Priority { + t.Errorf("received Priority = %d, want %d", got.Priority, tt.task.Priority) + } + }) + } +} + +func TestNATSBridge_PublishSubscribeBroadcast(t *testing.T) { + _, url, cleanup := startTestNATS(t) + defer cleanup() + + tests := []struct { + name string + capability string + workerACaps []string + workerBCaps []string + expectReceived bool + }{ + { + name: "broadcast by capability", + capability: "code", + workerACaps: []string{"code"}, + workerBCaps: []string{"code"}, + expectReceived: true, + }, + { + name: "broadcast to unmatched capability", + capability: "ml", + workerACaps: []string{"code"}, + workerBCaps: []string{"research"}, + expectReceived: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create two workers + nodeA := newTestNodeInfo("bcast-a", RoleWorker, tt.workerACaps, 4) + bridgeA := connectTestBridge(t, url, nodeA) + defer bridgeA.Stop() + + nodeB := newTestNodeInfo("bcast-b", RoleWorker, tt.workerBCaps, 4) + bridgeB := connectTestBridge(t, url, nodeB) + defer bridgeB.Stop() + + var receivedCount atomic.Int32 + handler := func(task *SwarmTask) { + receivedCount.Add(1) + } + bridgeA.SetOnTaskReceived(handler) + bridgeB.SetOnTaskReceived(handler) + + if err := bridgeA.Start(context.Background()); err != nil { + t.Fatalf("bridgeA Start() error: %v", err) + } + if err := bridgeB.Start(context.Background()); err != nil { + t.Fatalf("bridgeB Start() error: %v", err) + } + + // Give subscriptions time to propagate + time.Sleep(50 * time.Millisecond) + + // Publish broadcast task (no AssignedTo) + coordNode := newTestNodeInfo("bcast-coord", RoleCoordinator, nil, 1) + coordBridge := connectTestBridge(t, url, 
coordNode)
			defer coordBridge.Stop()

			task := &SwarmTask{
				ID:         "task-bcast001",
				Type:       TaskTypeBroadcast,
				Capability: tt.capability,
				Prompt:     "broadcast test",
				Timeout:    5000,
			}
			if err := coordBridge.PublishTask(task); err != nil {
				t.Fatalf("PublishTask() error: %v", err)
			}

			if tt.expectReceived {
				// At least one worker should receive it (queue group delivers to one)
				ok := waitFor(t, 2*time.Second, func() bool {
					return receivedCount.Load() >= 1
				})
				if !ok {
					t.Error("expected at least one worker to receive broadcast, got 0")
				}
			} else {
				// No one should receive
				time.Sleep(200 * time.Millisecond)
				if receivedCount.Load() > 0 {
					t.Errorf("expected no workers to receive broadcast, got %d", receivedCount.Load())
				}
			}
		})
	}
}

func TestNATSBridge_DiscoveryRoundTrip(t *testing.T) {
	_, url, cleanup := startTestNATS(t)
	defer cleanup()

	tests := []struct {
		name          string
		responderRole NodeRole
		responderCaps []string
		queryRole     NodeRole
		queryCap      string
		expectFound   bool
	}{
		{
			name:          "query all nodes",
			responderRole: RoleWorker,
			responderCaps: []string{"code"},
			queryRole:     "",
			queryCap:      "",
			expectFound:   true,
		},
		{
			name:          "query by role worker",
			responderRole: RoleWorker,
			responderCaps: []string{"code"},
			queryRole:     RoleWorker,
			queryCap:      "",
			expectFound:   true,
		},
		{
			name:          "query by capability",
			responderRole: RoleWorker,
			responderCaps: []string{"code", "research"},
			queryRole:     "",
			queryCap:      "research",
			expectFound:   true,
		},
		{
			name:          "query with no matches - role mismatch",
			responderRole: RoleWorker,
			responderCaps: []string{"code"},
			queryRole:     RoleSpecialist,
			queryCap:      "",
			expectFound:   false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Responder bridge
			respNode := newTestNodeInfo("disc-resp", tt.responderRole, tt.responderCaps, 4)
			respBridge := connectTestBridge(t, url, respNode)
			defer respBridge.Stop()

			if err :=
respBridge.Start(context.Background()); err != nil { + t.Fatalf("responder Start() error: %v", err) + } + + // Requester bridge + reqNode := newTestNodeInfo("disc-req", RoleCoordinator, nil, 1) + reqBridge := connectTestBridge(t, url, reqNode) + defer reqBridge.Stop() + + // Give subscriptions time to propagate + time.Sleep(50 * time.Millisecond) + + query := &DiscoveryQuery{ + RequesterID: "disc-req", + Role: tt.queryRole, + Capability: tt.queryCap, + } + + nodes, err := reqBridge.RequestDiscovery(query, 500*time.Millisecond) + if err != nil { + t.Fatalf("RequestDiscovery() error: %v", err) + } + + if tt.expectFound { + if len(nodes) == 0 { + t.Error("expected at least 1 node, got 0") + } else { + if nodes[0].ID != "disc-resp" { + t.Errorf("node ID = %q, want %q", nodes[0].ID, "disc-resp") + } + } + } else { + if len(nodes) != 0 { + t.Errorf("expected 0 nodes, got %d", len(nodes)) + } + } + }) + } +} + +func TestNATSBridge_HeartbeatPublishSubscribe(t *testing.T) { + _, url, cleanup := startTestNATS(t) + defer cleanup() + + tests := []struct { + name string + fn func(t *testing.T) + }{ + { + name: "heartbeat received", + fn: func(t *testing.T) { + pubNode := newTestNodeInfo("hb-pub", RoleWorker, []string{"code"}, 4) + pubBridge := connectTestBridge(t, url, pubNode) + defer pubBridge.Stop() + + subNode := newTestNodeInfo("hb-sub", RoleCoordinator, nil, 1) + subBridge := connectTestBridge(t, url, subNode) + defer subBridge.Stop() + + var received atomic.Value + _, err := subBridge.SubscribeHeartbeat("hb-pub", func(hb *Heartbeat) { + received.Store(hb) + }) + if err != nil { + t.Fatalf("SubscribeHeartbeat() error: %v", err) + } + + // Flush to ensure subscription is registered on server before publishing + subBridge.conn.Flush() + time.Sleep(50 * time.Millisecond) + + hb := &Heartbeat{ + NodeID: "hb-pub", + Status: StatusOnline, + Load: 0.5, + TasksRunning: 2, + Timestamp: time.Now().UnixMilli(), + } + if err := pubBridge.PublishHeartbeat(hb); err != nil { + 
t.Fatalf("PublishHeartbeat() error: %v", err) + } + + ok := waitFor(t, 2*time.Second, func() bool { + return received.Load() != nil + }) + if !ok { + t.Fatal("timed out waiting for heartbeat") + } + + got := received.Load().(*Heartbeat) + if got.NodeID != hb.NodeID { + t.Errorf("NodeID = %q, want %q", got.NodeID, hb.NodeID) + } + if got.Status != hb.Status { + t.Errorf("Status = %q, want %q", got.Status, hb.Status) + } + if got.Load != hb.Load { + t.Errorf("Load = %f, want %f", got.Load, hb.Load) + } + if got.TasksRunning != hb.TasksRunning { + t.Errorf("TasksRunning = %d, want %d", got.TasksRunning, hb.TasksRunning) + } + }, + }, + { + name: "wildcard heartbeat subscription", + fn: func(t *testing.T) { + subNode := newTestNodeInfo("hb-wild-sub", RoleCoordinator, nil, 1) + subBridge := connectTestBridge(t, url, subNode) + defer subBridge.Stop() + + var mu sync.Mutex + received := make(map[string]*Heartbeat) + _, err := subBridge.SubscribeAllHeartbeats(func(hb *Heartbeat) { + mu.Lock() + received[hb.NodeID] = hb + mu.Unlock() + }) + if err != nil { + t.Fatalf("SubscribeAllHeartbeats() error: %v", err) + } + + // Flush to ensure subscription is registered on server + subBridge.conn.Flush() + time.Sleep(50 * time.Millisecond) + + // Publish from two different nodes + for _, nodeID := range []string{"hb-wild-a", "hb-wild-b"} { + node := newTestNodeInfo(nodeID, RoleWorker, []string{"code"}, 4) + bridge := connectTestBridge(t, url, node) + defer bridge.Stop() + + hb := &Heartbeat{ + NodeID: nodeID, + Status: StatusOnline, + Load: 0.1, + Timestamp: time.Now().UnixMilli(), + } + if err := bridge.PublishHeartbeat(hb); err != nil { + t.Fatalf("PublishHeartbeat(%s) error: %v", nodeID, err) + } + } + + ok := waitFor(t, 2*time.Second, func() bool { + mu.Lock() + defer mu.Unlock() + return len(received) >= 2 + }) + if !ok { + mu.Lock() + t.Fatalf("timed out waiting for heartbeats, got %d", len(received)) + mu.Unlock() + } + + mu.Lock() + defer mu.Unlock() + if _, ok := 
received["hb-wild-a"]; !ok { + t.Error("missing heartbeat from hb-wild-a") + } + if _, ok := received["hb-wild-b"]; !ok { + t.Error("missing heartbeat from hb-wild-b") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, tt.fn) + } +} + +func TestNATSBridge_TaskResultPublishSubscribe(t *testing.T) { + _, url, cleanup := startTestNATS(t) + defer cleanup() + + tests := []struct { + name string + result TaskResult + }{ + { + name: "success result", + result: TaskResult{ + TaskID: "task-res00001", + NodeID: "worker-1", + Status: "done", + Result: "completed successfully", + CompletedAt: time.Now().UnixMilli(), + }, + }, + { + name: "failure result", + result: TaskResult{ + TaskID: "task-res00002", + NodeID: "worker-2", + Status: "failed", + Error: "out of memory", + CompletedAt: time.Now().UnixMilli(), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Subscriber + subNode := newTestNodeInfo("result-sub", RoleCoordinator, nil, 1) + subBridge := connectTestBridge(t, url, subNode) + defer subBridge.Stop() + + var received atomic.Value + _, err := subBridge.SubscribeTaskResult(tt.result.TaskID, func(r *TaskResult) { + received.Store(r) + }) + if err != nil { + t.Fatalf("SubscribeTaskResult() error: %v", err) + } + + // Flush to ensure subscription is registered on server + subBridge.conn.Flush() + time.Sleep(50 * time.Millisecond) + + // Publisher + pubNode := newTestNodeInfo("result-pub", RoleWorker, []string{"code"}, 4) + pubBridge := connectTestBridge(t, url, pubNode) + defer pubBridge.Stop() + + if err := pubBridge.PublishTaskResult(&tt.result); err != nil { + t.Fatalf("PublishTaskResult() error: %v", err) + } + + ok := waitFor(t, 2*time.Second, func() bool { + return received.Load() != nil + }) + if !ok { + t.Fatal("timed out waiting for task result") + } + + got := received.Load().(*TaskResult) + if got.TaskID != tt.result.TaskID { + t.Errorf("TaskID = %q, want %q", got.TaskID, tt.result.TaskID) + } + if got.Status 
!= tt.result.Status { + t.Errorf("Status = %q, want %q", got.Status, tt.result.Status) + } + if got.Result != tt.result.Result { + t.Errorf("Result = %q, want %q", got.Result, tt.result.Result) + } + if got.Error != tt.result.Error { + t.Errorf("Error = %q, want %q", got.Error, tt.result.Error) + } + }) + } +} + +func TestNATSBridge_ShutdownAnnouncement(t *testing.T) { + _, url, cleanup := startTestNATS(t) + defer cleanup() + + // Subscriber + subNode := newTestNodeInfo("shutdown-sub", RoleCoordinator, nil, 1) + subBridge := connectTestBridge(t, url, subNode) + defer subBridge.Stop() + + var receivedID atomic.Value + _, err := subBridge.SubscribeShutdown(func(nodeID string) { + receivedID.Store(nodeID) + }) + if err != nil { + t.Fatalf("SubscribeShutdown() error: %v", err) + } + + // Flush to ensure subscription is registered on server + subBridge.conn.Flush() + time.Sleep(50 * time.Millisecond) + + // Bridge that will shut down + shutdownNode := newTestNodeInfo("shutdown-node", RoleWorker, []string{"code"}, 4) + shutdownBridge := connectTestBridge(t, url, shutdownNode) + + if err := shutdownBridge.Start(context.Background()); err != nil { + t.Fatalf("Start() error: %v", err) + } + + // Stop the bridge - this should publish shutdown + if err := shutdownBridge.Stop(); err != nil { + t.Fatalf("Stop() error: %v", err) + } + + ok := waitFor(t, 2*time.Second, func() bool { + return receivedID.Load() != nil + }) + if !ok { + t.Fatal("timed out waiting for shutdown announcement") + } + + got := receivedID.Load().(string) + if got != "shutdown-node" { + t.Errorf("received nodeID = %q, want %q", got, "shutdown-node") + } +} + +func TestNATSBridge_Stop(t *testing.T) { + _, url, cleanup := startTestNATS(t) + defer cleanup() + + tests := []struct { + name string + fn func(t *testing.T) + }{ + { + name: "graceful stop drains connection", + fn: func(t *testing.T) { + node := newTestNodeInfo("stop-test", RoleWorker, []string{"code"}, 4) + bridge := connectTestBridge(t, url, node) 
+ + if err := bridge.Start(context.Background()); err != nil { + t.Fatalf("Start() error: %v", err) + } + + if !bridge.IsConnected() { + t.Error("IsConnected() = false before Stop") + } + + if err := bridge.Stop(); err != nil { + t.Fatalf("Stop() error: %v", err) + } + + // After drain, connection should eventually close + ok := waitFor(t, 2*time.Second, func() bool { + return !bridge.IsConnected() + }) + if !ok { + t.Error("bridge still connected after Stop()") + } + }, + }, + { + name: "stop publishes shutdown", + fn: func(t *testing.T) { + // Listener bridge + listenerNode := newTestNodeInfo("stop-listener", RoleCoordinator, nil, 1) + listenerBridge := connectTestBridge(t, url, listenerNode) + defer listenerBridge.Stop() + + var gotShutdown atomic.Value + _, err := listenerBridge.SubscribeShutdown(func(nodeID string) { + gotShutdown.Store(nodeID) + }) + if err != nil { + t.Fatalf("SubscribeShutdown() error: %v", err) + } + + // Bridge to stop + stopNode := newTestNodeInfo("stop-sender", RoleWorker, []string{"code"}, 4) + stopBridge := connectTestBridge(t, url, stopNode) + + if err := stopBridge.Start(context.Background()); err != nil { + t.Fatalf("Start() error: %v", err) + } + + if err := stopBridge.Stop(); err != nil { + t.Fatalf("Stop() error: %v", err) + } + + ok := waitFor(t, 2*time.Second, func() bool { + return gotShutdown.Load() != nil + }) + if !ok { + t.Fatal("timed out waiting for shutdown from Stop()") + } + + if gotShutdown.Load().(string) != "stop-sender" { + t.Errorf("shutdown nodeID = %q, want %q", gotShutdown.Load().(string), "stop-sender") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, tt.fn) + } +} diff --git a/pkg/swarm/partition.go b/pkg/swarm/partition.go new file mode 100644 index 000000000..d21d69553 --- /dev/null +++ b/pkg/swarm/partition.go @@ -0,0 +1,381 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "strings" + "sync" + + 
"github.com/sipeed/picoclaw/pkg/identity" +) + +// Partition manages H-id based partitioning for NATS subjects +type Partition struct { + hid string + subjects map[string]string // original subject -> partitioned subject + mu sync.RWMutex +} + +// NewPartition creates a new partition manager +func NewPartition(hid string) *Partition { + return &Partition{ + hid: hid, + subjects: make(map[string]string), + } +} + +// NewPartitionWithIdentity creates a partition manager from an identity +func NewPartitionWithIdentity(id *identity.Identity) *Partition { + if id == nil { + return NewPartition("") + } + return NewPartition(id.HID) +} + +// Partitionize adds H-id partitioning to a subject +func (p *Partition) Partitionize(subject string) string { + if p.hid == "" { + return subject // No partitioning + } + + p.mu.RLock() + if cached, ok := p.subjects[subject]; ok { + p.mu.RUnlock() + return cached + } + p.mu.RUnlock() + + // Parse and add H-id + parsed := ParseSubject(subject) + if parsed == nil { + return subject + } + + // Skip if already partitioned or is cross-domain + if parsed.HID != "" || parsed.IsCrossDomain() { + return subject + } + + // Add H-id to domain part + builder := NewSubjectBuilder().WithHID(p.hid) + partitioned := builder.Build(parsed.Domain, parsed.Parts...) 
+ + p.mu.Lock() + p.subjects[subject] = partitioned + p.mu.Unlock() + + return partitioned +} + +// Departitionize removes H-id partitioning from a subject +func (p *Partition) Departitionize(subject string) string { + parsed := ParseSubject(subject) + if parsed == nil { + return subject + } + + // Not a partitioned subject + if parsed.HID == "" { + return subject + } + + // Remove H-id from parts + parsed.HID = "" + + return parsed.String() +} + +// IsInPartition checks if a subject belongs to this partition's H-id +func (p *Partition) IsInPartition(subject string) bool { + if p.hid == "" { + return true // No partitioning means all subjects match + } + + parsed := ParseSubject(subject) + if parsed == nil { + return false + } + + // Check cross-domain subjects + if parsed.IsCrossDomain() { + return parsed.ToHID == p.hid + } + + // Check regular subjects + return parsed.HID == p.hid || parsed.HID == "" +} + +// GetHID returns the H-id for this partition +func (p *Partition) GetHID() string { + return p.hid +} + +// SetHID sets the H-id for this partition and clears the cache +func (p *Partition) SetHID(hid string) { + p.mu.Lock() + defer p.mu.Unlock() + p.hid = hid + p.subjects = make(map[string]string) +} + +// SubscribeFilter creates a wildcard subscription subject for this partition +func (p *Partition) SubscribeFilter(domain SubjectDomain, suffix string) string { + builder := NewSubjectBuilder().WithHID(p.hid) + return builder.BuildWildcard(domain, suffix) +} + +// GetAllHIDSubjects gets all subject variants for a given base subject +// This is useful for publishing to multiple partitions +func GetAllHIDSubjects(subject string, hids []string) []string { + parsed := ParseSubject(subject) + if parsed == nil { + return []string{subject} + } + + if parsed.HID != "" || parsed.IsCrossDomain() { + return []string{subject} // Already has H-id or is cross-domain + } + + subjects := make([]string, 0, len(hids)) + for _, hid := range hids { + builder := 
NewSubjectBuilder().WithHID(hid)
		subjects = append(subjects, builder.Build(parsed.Domain, parsed.Parts...))
	}

	return subjects
}

// IsSubjectForHID reports whether a subject is intended for the given H-id.
func IsSubjectForHID(subject, hid string) bool {
	parsed := ParseSubject(subject)
	if parsed == nil {
		return false
	}

	if parsed.IsCrossDomain() {
		return parsed.ToHID == hid
	}

	return parsed.HID == hid
}

// ExtractHIDFromSubject extracts the H-id from a subject; for
// cross-domain subjects the originating (From) H-id is returned.
func ExtractHIDFromSubject(subject string) string {
	parsed := ParseSubject(subject)
	if parsed == nil {
		return ""
	}

	if parsed.IsCrossDomain() {
		return parsed.FromHID
	}

	return parsed.HID
}

// SubjectWithHID creates a new subject carrying the given H-id,
// replacing any H-id already present.
func SubjectWithHID(subject, hid string) string {
	parsed := ParseSubject(subject)
	if parsed == nil {
		return subject
	}

	if parsed.IsCrossDomain() {
		// For cross-domain, modify from HID
		parsed.FromHID = hid
		return parsed.String()
	}

	if parsed.HID != "" {
		// Already has H-id, replace it
		parsed.HID = hid
		return parsed.String()
	}

	// Add H-id
	builder := NewSubjectBuilder().WithHID(hid)
	return builder.Build(parsed.Domain, parsed.Parts...)
+} + +// SubjectRouter routes subjects to appropriate H-id partitions +type SubjectRouter struct { + partitions map[string]*Partition // hid -> partition + mu sync.RWMutex + localHID string +} + +// NewSubjectRouter creates a new subject router +func NewSubjectRouter(localHID string) *SubjectRouter { + return &SubjectRouter{ + partitions: make(map[string]*Partition), + localHID: localHID, + } +} + +// AddPartition adds a partition for an H-id +func (r *SubjectRouter) AddPartition(hid string) *Partition { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.partitions[hid]; !exists { + r.partitions[hid] = NewPartition(hid) + } + + return r.partitions[hid] +} + +// RemovePartition removes a partition +func (r *SubjectRouter) RemovePartition(hid string) { + r.mu.Lock() + defer r.mu.Unlock() + delete(r.partitions, hid) +} + +// GetPartition gets a partition for an H-id +func (r *SubjectRouter) GetPartition(hid string) *Partition { + r.mu.RLock() + defer r.mu.RUnlock() + return r.partitions[hid] +} + +// GetLocalPartition gets the local partition +func (r *SubjectRouter) GetLocalPartition() *Partition { + r.mu.RLock() + defer r.mu.RUnlock() + return r.partitions[r.localHID] +} + +// Route determines which partition(s) a subject should be routed to +func (r *SubjectRouter) Route(subject string) []string { + parsed := ParseSubject(subject) + if parsed == nil { + return []string{r.localHID} + } + + // Cross-domain subjects route to specific H-id + if parsed.IsCrossDomain() { + return []string{parsed.ToHID} + } + + // Subjects with H-id route to that H-id + if parsed.HID != "" { + return []string{parsed.HID} + } + + // Broadcast subjects route to all partitions + hids := make([]string, 0, len(r.partitions)) + r.mu.RLock() + for hid := range r.partitions { + hids = append(hids, hid) + } + r.mu.RUnlock() + + return hids +} + +// IsLocal checks if a subject is for the local partition +func (r *SubjectRouter) IsLocal(subject string) bool { + hids := r.Route(subject) + for 
_, hid := range hids {
		if hid == r.localHID {
			return true
		}
	}
	return false
}

// TransformForPartition transforms a subject for a specific target partition.
func (r *SubjectRouter) TransformForPartition(subject, targetHID string) string {
	parsed := ParseSubject(subject)
	if parsed == nil {
		return subject
	}

	// If it's a cross-domain subject, check if it's for us
	if parsed.IsCrossDomain() {
		if parsed.ToHID == r.localHID {
			// This is for us, transform to local
			builder := NewSubjectBuilder().WithHID(r.localHID)
			return builder.Build(parsed.Domain, parsed.Parts...)
		}
		return subject // Not for us, leave as-is
	}

	// Regular subject - add target H-id if needed
	if parsed.HID == "" {
		builder := NewSubjectBuilder().WithHID(targetHID)
		return builder.Build(parsed.Domain, parsed.Parts...)
	}

	return subject // Already has H-id
}

// CreateCrossSubject creates a cross-domain subject from local to remote.
func (r *SubjectRouter) CreateCrossSubject(localSubject, remoteHID string) string {
	parsed := ParseSubject(localSubject)
	if parsed == nil {
		return localSubject
	}

	// Build cross-domain subject
	builder := NewSubjectBuilder()
	parts := append([]string{string(parsed.Domain)}, parsed.Parts...)
	return builder.BuildCross(r.localHID, remoteHID, parts...)
+} + +// ValidateSubject validates that a subject is properly formatted +func ValidateSubject(subject string) bool { + if subject == "" { + return false + } + + parts := strings.Split(subject, ".") + if len(parts) < 2 { + return false + } + + // Check for valid prefix + if parts[0] != "picoclaw" { + return false + } + + // Check for valid domain + validDomains := map[SubjectDomain]bool{ + DomainSwarm: true, + DomainMemory: true, + DomainTask: true, + DomainDiscovery: true, + DomainSystem: true, + DomainCross: true, + } + + if !validDomains[SubjectDomain(parts[1])] { + return false + } + + return true +} + +// SanitizeSubject sanitizes a subject for safe use +func SanitizeSubject(subject string) string { + // Remove any whitespace + subject = strings.TrimSpace(subject) + + // Ensure it starts with picoclaw. + if !strings.HasPrefix(subject, "picoclaw.") { + return subject + } + + // Convert to lowercase (H-ids are case-insensitive) + parts := strings.Split(subject, ".") + for i, part := range parts { + if i > 1 && !strings.HasPrefix(part, ">") && !strings.HasPrefix(part, "*") { + parts[i] = strings.ToLower(part) + } + } + + return strings.Join(parts, ".") +} diff --git a/pkg/swarm/registry.go b/pkg/swarm/registry.go new file mode 100644 index 000000000..80d80a47a --- /dev/null +++ b/pkg/swarm/registry.go @@ -0,0 +1,389 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/nats-io/nats.go" + "github.com/sipeed/picoclaw/pkg/logger" +) + +const ( + // CapabilitySubject is the base subject for capability discovery + CapabilitySubject = "picoclaw.swarm.capability" + // CapabilityKVBucket is the KV bucket for storing capabilities + CapabilityKVBucket = "PICOCLAW_CAPABILITIES" +) + +// CapabilityRegistry manages dynamic capability registration and discovery +type CapabilityRegistry struct { + caps sync.Map // 
map[string]*Capability + nodeInfo *NodeInfo + js nats.JetStreamContext + nc *nats.Conn + mu sync.RWMutex +} + +// NewCapabilityRegistry creates a new capability registry +func NewCapabilityRegistry(nodeInfo *NodeInfo, js nats.JetStreamContext, nc *nats.Conn) *CapabilityRegistry { + return &CapabilityRegistry{ + nodeInfo: nodeInfo, + js: js, + nc: nc, + } +} + +// Initialize sets up the capability registry +func (r *CapabilityRegistry) Initialize(ctx context.Context) error { + // Create KV bucket for capabilities + _, err := r.js.KeyValue(CapabilityKVBucket) + if err != nil { + // Bucket doesn't exist, create it + _, err = r.js.CreateKeyValue(&nats.KeyValueConfig{ + Bucket: CapabilityKVBucket, + Description: "PicoClaw swarm capability registry", + MaxBytes: 1024 * 1024 * 10, // 10MB + Storage: nats.FileStorage, + Replicas: 1, + }) + if err != nil { + return fmt.Errorf("failed to create capabilities bucket: %w", err) + } + logger.InfoC("swarm", fmt.Sprintf("Created capability bucket: %s", CapabilityKVBucket)) + } + + // Subscribe to capability announcements for discovery + _, err = r.nc.Subscribe(CapabilitySubject+".announce", func(msg *nats.Msg) { + r.handleCapabilityAnnouncement(msg) + }) + if err != nil { + return fmt.Errorf("failed to subscribe to capability announcements: %w", err) + } + + // Subscribe to capability queries + _, err = r.nc.Subscribe(CapabilitySubject+".query", func(msg *nats.Msg) { + r.handleCapabilityQuery(msg) + }) + if err != nil { + return fmt.Errorf("failed to subscribe to capability queries: %w", err) + } + + logger.InfoC("swarm", "Capability registry initialized") + + return nil +} + +// Register registers a capability for this node +func (r *CapabilityRegistry) Register(name, description, version string, metadata map[string]interface{}) error { + r.mu.Lock() + defer r.mu.Unlock() + + cap := &Capability{ + Name: name, + Description: description, + Version: version, + Metadata: metadata, + NodeID: r.nodeInfo.ID, + RegisteredAt: 
time.Now().UnixMilli(),
	}

	// Store locally
	r.caps.Store(name, cap)

	// Store in KV for cross-node discovery
	kv, err := r.js.KeyValue(CapabilityKVBucket)
	if err != nil {
		return fmt.Errorf("failed to get capabilities bucket: %w", err)
	}

	key := fmt.Sprintf("%s:%s", r.nodeInfo.ID, name)
	data, err := json.Marshal(cap)
	if err != nil {
		return fmt.Errorf("failed to marshal capability: %w", err)
	}

	_, err = kv.Put(key, data)
	if err != nil {
		return fmt.Errorf("failed to store capability: %w", err)
	}

	// Announce to swarm
	announcement := map[string]interface{}{
		"type":       "register",
		"capability": cap,
		"timestamp":  time.Now().UnixMilli(),
	}
	announcementData, _ := json.Marshal(announcement)

	err = r.nc.Publish(CapabilitySubject+".announce", announcementData)
	if err != nil {
		logger.WarnCF("swarm", "Failed to announce capability", map[string]interface{}{
			"name":  name,
			"error": err.Error(),
		})
	}

	logger.InfoCF("swarm", "Registered capability", map[string]interface{}{
		"name":    name,
		"version": version,
	})

	return nil
}

// Unregister removes a capability locally and from the KV bucket, then
// announces the removal (announce failure is logged, not fatal).
func (r *CapabilityRegistry) Unregister(name string) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Remove from local store
	r.caps.Delete(name)

	// Remove from KV
	kv, err := r.js.KeyValue(CapabilityKVBucket)
	if err != nil {
		return fmt.Errorf("failed to get capabilities bucket: %w", err)
	}

	key := fmt.Sprintf("%s:%s", r.nodeInfo.ID, name)
	err = kv.Delete(key)
	if err != nil && err != nats.ErrKeyNotFound {
		return fmt.Errorf("failed to delete capability: %w", err)
	}

	// Announce removal
	announcement := map[string]interface{}{
		"type":      "unregister",
		"node_id":   r.nodeInfo.ID,
		"name":      name,
		"timestamp": time.Now().UnixMilli(),
	}
	announcementData, err := json.Marshal(announcement)
	if err != nil {
		return fmt.Errorf("failed to marshal announcement: %w", err)
	}

	if err :=
r.nc.Publish(CapabilitySubject+".announce", announcementData); err != nil { + logger.WarnCF("swarm", "Failed to publish unregister announcement", map[string]interface{}{ + "error": err.Error(), + }) + } + + logger.InfoCF("swarm", "Unregistered capability", map[string]interface{}{ + "name": name, + }) + + return nil +} + +// Get retrieves a local capability by name +func (r *CapabilityRegistry) Get(name string) (*Capability, bool) { + val, ok := r.caps.Load(name) + if !ok { + return nil, false + } + cap, ok := val.(*Capability) + return cap, ok +} + +// List returns all local capabilities +func (r *CapabilityRegistry) List() []Capability { + caps := make([]Capability, 0) + r.caps.Range(func(key, value interface{}) bool { + if cap, ok := value.(*Capability); ok { + caps = append(caps, *cap) + } + return true + }) + return caps +} + +// Discover finds capabilities across the swarm +func (r *CapabilityRegistry) Discover(ctx context.Context, name, version string) ([]Capability, error) { + // Query KV store for capabilities + kv, err := r.js.KeyValue(CapabilityKVBucket) + if err != nil { + return nil, fmt.Errorf("failed to get capabilities bucket: %w", err) + } + + // List all keys + watcher, err := kv.WatchAll(nats.Context(ctx)) + if err != nil { + return nil, fmt.Errorf("failed to watch capabilities: %w", err) + } + defer watcher.Stop() + + capabilities := make([]Capability, 0) + timeout := time.After(5 * time.Second) + +collect: + for { + select { + case entry := <-watcher.Updates(): + if entry == nil { + break collect + } + + var cap Capability + if err := json.Unmarshal(entry.Value(), &cap); err != nil { + continue + } + + // Filter by name if specified + if name != "" && cap.Name != name { + continue + } + + // Filter by version if specified + if version != "" && cap.Version != version { + continue + } + + capabilities = append(capabilities, cap) + + case <-timeout: + break collect + } + } + + return capabilities, nil +} + +// DiscoverByNode finds all capabilities 
provided by a specific node +func (r *CapabilityRegistry) DiscoverByNode(ctx context.Context, nodeID string) ([]Capability, error) { + return r.Discover(ctx, "", "") +} + +// handleCapabilityAnnouncement processes capability announcements from other nodes +func (r *CapabilityRegistry) handleCapabilityAnnouncement(msg *nats.Msg) { + var announcement struct { + Type string `json:"type"` + Capability *Capability `json:"capability,omitempty"` + NodeID string `json:"node_id,omitempty"` + Name string `json:"name,omitempty"` + } + + if err := json.Unmarshal(msg.Data, &announcement); err != nil { + return + } + + // Ignore our own announcements + switch announcement.Type { + case "register": + if announcement.Capability != nil && announcement.Capability.NodeID != r.nodeInfo.ID { + logger.DebugCF("swarm", "Capability announced by peer", map[string]interface{}{ + "node_id": announcement.Capability.NodeID, + "name": announcement.Capability.Name, + }) + // Store peer capability + r.caps.Store(announcement.Capability.NodeID+":"+announcement.Capability.Name, announcement.Capability) + } + case "unregister": + if announcement.NodeID != r.nodeInfo.ID { + logger.DebugCF("swarm", "Capability unregistered by peer", map[string]interface{}{ + "node_id": announcement.NodeID, + "name": announcement.Name, + }) + r.caps.Delete(announcement.NodeID + ":" + announcement.Name) + } + } +} + +// handleCapabilityQuery responds to capability discovery queries +func (r *CapabilityRegistry) handleCapabilityQuery(msg *nats.Msg) { + var query CapabilityRequest + if err := json.Unmarshal(msg.Data, &query); err != nil { + return + } + + // Don't respond to our own queries + if query.RequesterID == r.nodeInfo.ID { + return + } + + // Get our capabilities + caps := r.List() + + // Filter if requested + if query.Capability != "" { + filtered := make([]Capability, 0) + for _, cap := range caps { + if cap.Name == query.Capability { + if query.Version == "" || cap.Version == query.Version { + filtered = 
append(filtered, cap) + } + } + } + caps = filtered + } + + // Send response + response := CapabilityResponse{ + Capabilities: caps, + RequestID: query.RequesterID, + Timestamp: time.Now().UnixMilli(), + } + + data, _ := json.Marshal(response) + _ = msg.Respond(data) +} + +// QueryCapabilities sends a query to discover capabilities from other nodes +func (r *CapabilityRegistry) QueryCapabilities(ctx context.Context, capability, version string) ([]Capability, error) { + // Send query + query := CapabilityRequest{ + RequesterID: r.nodeInfo.ID, + Capability: capability, + Version: version, + } + + queryData, _ := json.Marshal(query) + + // Create inbox for responses + inbox := nats.NewInbox() + sub, err := r.nc.SubscribeSync(inbox) + if err != nil { + return nil, fmt.Errorf("failed to create subscription: %w", err) + } + defer sub.Unsubscribe() + + // Publish query + err = r.nc.PublishRequest(CapabilitySubject+".query", inbox, queryData) + if err != nil { + return nil, fmt.Errorf("failed to publish query: %w", err) + } + + // Collect responses with timeout + capabilities := make([]Capability, 0) + timeout := time.After(3 * time.Second) + + for { + select { + case <-timeout: + return capabilities, nil + default: + msg, err := sub.NextMsg(500 * time.Millisecond) + if err != nil { + if err == nats.ErrTimeout { + continue + } + return capabilities, nil + } + + var response CapabilityResponse + if err := json.Unmarshal(msg.Data, &response); err != nil { + continue + } + + capabilities = append(capabilities, response.Capabilities...) 
+ } + } +} diff --git a/pkg/swarm/registry_test.go b/pkg/swarm/registry_test.go new file mode 100644 index 000000000..3bac835b4 --- /dev/null +++ b/pkg/swarm/registry_test.go @@ -0,0 +1,102 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewCapabilityRegistry(t *testing.T) { + nodeInfo := &NodeInfo{ + ID: "test-node-1", + Role: RoleWorker, + } + + registry := NewCapabilityRegistry(nodeInfo, nil, nil) + assert.NotNil(t, registry) + assert.Equal(t, nodeInfo, registry.nodeInfo) +} + +func TestCapabilityRegistry_Register(t *testing.T) { + t.Skip("Requires NATS connection") + + /* + nodeInfo := &NodeInfo{ + ID: "test-node-1", + Role: RoleWorker, + } + + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + js, _ := nc.JetStream() + + registry := NewCapabilityRegistry(nodeInfo, js, nc) + err := registry.Initialize(context.Background()) + require.NoError(t, err) + + // Register a capability + err = registry.Register("test-cap", "A test capability", "1.0.0", nil) + require.NoError(t, err) + + // Check local storage + cap, ok := registry.Get("test-cap") + assert.True(t, ok) + assert.Equal(t, "test-cap", cap.Name) + assert.Equal(t, "A test capability", cap.Description) + */ +} + +func TestCapabilityRegistry_List(t *testing.T) { + nodeInfo := &NodeInfo{ + ID: "test-node-1", + Role: RoleWorker, + } + + registry := NewCapabilityRegistry(nodeInfo, nil, nil) + + // Add capabilities to local store + registry.caps.Store("cap-1", &Capability{Name: "cap-1"}) + registry.caps.Store("cap-2", &Capability{Name: "cap-2"}) + + caps := registry.List() + assert.Len(t, caps, 2) +} + +func TestCapabilityRegistry_Get(t *testing.T) { + nodeInfo := &NodeInfo{ + ID: "test-node-1", + Role: RoleWorker, + } + + registry := NewCapabilityRegistry(nodeInfo, nil, nil) + + expectedCap := &Capability{ + Name: "test-cap", + 
Description: "Test", + Version: "1.0.0", + NodeID: "test-node-1", + } + + registry.caps.Store("test-cap", expectedCap) + + cap, ok := registry.Get("test-cap") + assert.True(t, ok) + assert.Equal(t, expectedCap, cap) + + _, ok = registry.Get("non-existent") + assert.False(t, ok) +} + +func TestCapabilitySubject(t *testing.T) { + assert.Equal(t, "picoclaw.swarm.capability", CapabilitySubject) +} + +func TestCapabilityKVBucket(t *testing.T) { + assert.Equal(t, "PICOCLAW_CAPABILITIES", CapabilityKVBucket) +} diff --git a/pkg/swarm/specialist.go b/pkg/swarm/specialist.go new file mode 100644 index 000000000..a579f01cb --- /dev/null +++ b/pkg/swarm/specialist.go @@ -0,0 +1,302 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/nats-io/nats.go" + "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// SpecialistNode is a worker that specializes in specific capabilities +type SpecialistNode struct { + *Worker + registry *CapabilityRegistry + skillsDir string +} + +// NewSpecialistNode creates a new specialist node +func NewSpecialistNode( + cfg *config.SwarmConfig, + bridge *NATSBridge, + temporal *TemporalClient, + agentLoop *agent.AgentLoop, + provider providers.LLMProvider, + nodeInfo *NodeInfo, + js nats.JetStreamContext, + nc *nats.Conn, + skillsDir string, +) *SpecialistNode { + registry := NewCapabilityRegistry(nodeInfo, js, nc) + + // Create base worker + worker := NewWorker(cfg, bridge, temporal, agentLoop, provider, nodeInfo) + + return &SpecialistNode{ + Worker: worker, + registry: registry, + skillsDir: skillsDir, + } +} + +// Start initializes and starts the specialist node +func (s *SpecialistNode) Start(ctx context.Context) error { + // Initialize capability 
registry + if err := s.registry.Initialize(ctx); err != nil { + return fmt.Errorf("failed to initialize capability registry: %w", err) + } + + // Perform dynamic capability discovery + if err := s.DynamicCapabilityDiscovery(); err != nil { + logger.WarnCF("swarm", "Dynamic capability discovery failed", map[string]interface{}{ + "error": err.Error(), + }) + } + + // Start base worker + if err := s.Worker.Start(ctx); err != nil { + return fmt.Errorf("failed to start worker: %w", err) + } + + logger.InfoCF("swarm", "Specialist node started", map[string]interface{}{ + "node_id": s.nodeInfo.ID, + "capabilities": fmt.Sprintf("%v", s.nodeInfo.Capabilities), + }) + + return nil +} + +// DynamicCapabilityDiscovery scans for skills and registers them as capabilities +func (s *SpecialistNode) DynamicCapabilityDiscovery() error { + if s.skillsDir == "" { + // Use default skills directory + s.skillsDir = filepath.Join(os.Getenv("HOME"), ".picoclaw", "skills") + } + + // Check if directory exists + if _, err := os.Stat(s.skillsDir); os.IsNotExist(err) { + logger.DebugC("swarm", "Skills directory does not exist, skipping discovery") + return nil + } + + logger.InfoCF("swarm", "Starting dynamic capability discovery", map[string]interface{}{ + "directory": s.skillsDir, + }) + + // Walk the skills directory + err := filepath.Walk(s.skillsDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return nil // Continue on error + } + + // Look for .md files which typically define skills + if !info.IsDir() && strings.HasSuffix(path, ".md") { + if err := s.discoverSkillFromFile(path); err != nil { + logger.DebugCF("swarm", "Failed to discover skill", map[string]interface{}{ + "path": path, + "error": err.Error(), + }) + } + } + + return nil + }) + + if err != nil { + return fmt.Errorf("failed to walk skills directory: %w", err) + } + + return nil +} + +// discoverSkillFromFile parses a skill file and registers it as a capability +func (s *SpecialistNode) 
discoverSkillFromFile(path string) error { + // Read file content + content, err := os.ReadFile(path) + if err != nil { + return err + } + + // Extract skill name from filename + skillName := strings.TrimSuffix(filepath.Base(path), ".md") + + // Parse metadata from content + // Skills typically have frontmatter or specific patterns + metadata := s.extractSkillMetadata(string(content)) + + // Register as capability + description := metadata["description"] + if description == "" { + description = fmt.Sprintf("Skill: %s", skillName) + } + + version := metadata["version"] + if version == "" { + version = "1.0.0" + } + + // Build metadata map + metaMap := make(map[string]interface{}) + metaMap["file"] = path + for k, v := range metadata { + metaMap[k] = v + } + + if err := s.registry.Register(skillName, description, version, metaMap); err != nil { + return fmt.Errorf("failed to register skill %s: %w", skillName, err) + } + + logger.InfoCF("swarm", "Discovered and registered skill", map[string]interface{}{ + "skill": skillName, + "description": description, + "version": version, + }) + + return nil +} + +// extractSkillMetadata extracts metadata from skill content +func (s *SpecialistNode) extractSkillMetadata(content string) map[string]string { + metadata := make(map[string]string) + + lines := strings.Split(content, "\n") + + // Look for metadata patterns + // Pattern 1: Key: Value at the start + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.Contains(line, ":") && !strings.HasPrefix(line, "#") { + parts := strings.SplitN(line, ":", 2) + if len(parts) == 2 { + key := strings.TrimSpace(strings.ToLower(parts[0])) + value := strings.TrimSpace(parts[1]) + + // Recognize common metadata keys + switch key { + case "description", "desc", "summary": + metadata["description"] = value + case "version", "ver": + metadata["version"] = value + case "category", "type": + metadata["category"] = value + case "author": + metadata["author"] = value + case 
"tags": + metadata["tags"] = value + } + } + } + + // Only check first 20 lines for metadata + if strings.HasPrefix(line, "#") && len(metadata) > 0 { + // Found a heading, metadata section likely ended + break + } + } + + // Extract description from first heading if not found + if metadata["description"] == "" { + for _, line := range lines { + if strings.HasPrefix(line, "#") { + metadata["description"] = strings.TrimSpace(strings.TrimPrefix(line, "#")) + break + } + } + } + + return metadata +} + +// RegisterCapability manually registers a capability +func (s *SpecialistNode) RegisterCapability(name, description, version string, metadata map[string]interface{}) error { + return s.registry.Register(name, description, version, metadata) +} + +// GetCapabilities returns all registered capabilities +func (s *SpecialistNode) GetCapabilities() []Capability { + return s.registry.List() +} + +// HasCapability checks if the specialist has a specific capability +func (s *SpecialistNode) HasCapability(name string) bool { + _, ok := s.registry.Get(name) + return ok +} + +// DiscoverSwarmCapabilities finds capabilities across the swarm +func (s *SpecialistNode) DiscoverSwarmCapabilities(ctx context.Context, name, version string) ([]Capability, error) { + return s.registry.Discover(ctx, name, version) +} + +// ExecuteSpecializedTask executes a task that requires this specialist's capabilities +func (s *SpecialistNode) ExecuteSpecializedTask(ctx context.Context, task *SwarmTask) (string, error) { + logger.InfoCF("swarm", "Executing specialized task", map[string]interface{}{ + "task_id": task.ID, + "capability": task.Capability, + }) + + // Check if we have this capability + if !s.HasCapability(task.Capability) { + return "", fmt.Errorf("specialist does not have capability: %s", task.Capability) + } + + // Execute the task using the agent loop + result, err := s.agentLoop.ProcessDirect(ctx, task.Prompt, "swarm:specialist:"+task.ID) + if err != nil { + return "", 
fmt.Errorf("specialized task execution failed: %w", err) + } + + return result, nil +} + +// HeartbeatWithCapabilities sends a heartbeat with capability information +func (s *SpecialistNode) HeartbeatWithCapabilities() error { + // This would be called periodically to announce capabilities + caps := s.GetCapabilities() + + capabilityNames := make([]string, len(caps)) + for i, cap := range caps { + capabilityNames[i] = cap.Name + } + + // Update node info with current capabilities + s.nodeInfo.Capabilities = capabilityNames + + logger.DebugCF("swarm", "Heartbeat with capabilities", map[string]interface{}{ + "node_id": s.nodeInfo.ID, + "capabilities": capabilityNames, + }) + + return nil +} + +// CapabilityHeartbeatLoop sends periodic capability announcements +func (s *SpecialistNode) CapabilityHeartbeatLoop(ctx context.Context) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.HeartbeatWithCapabilities(); err != nil { + logger.WarnCF("swarm", "Failed to send capability heartbeat", map[string]interface{}{ + "error": err.Error(), + }) + } + } + } +} diff --git a/pkg/swarm/specialist_test.go b/pkg/swarm/specialist_test.go new file mode 100644 index 000000000..c12a6a638 --- /dev/null +++ b/pkg/swarm/specialist_test.go @@ -0,0 +1,129 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSpecialistNode_extractSkillMetadata(t *testing.T) { + specialist := &SpecialistNode{} + + t.Run("metadata with description and version", func(t *testing.T) { + content := `Description: This is a test skill +Version: 1.0.0 +Author: Test Author + +# Skill Name + +Some content here` + + metadata := specialist.extractSkillMetadata(content) + assert.Equal(t, "This is a test skill", metadata["description"]) + assert.Equal(t, 
"1.0.0", metadata["version"]) + assert.Equal(t, "Test Author", metadata["author"]) + }) + + t.Run("metadata from heading", func(t *testing.T) { + content := `# Test Skill Heading + +Some content here` + + metadata := specialist.extractSkillMetadata(content) + assert.Equal(t, "Test Skill Heading", metadata["description"]) + }) + + t.Run("empty content", func(t *testing.T) { + content := `` + + metadata := specialist.extractSkillMetadata(content) + assert.Empty(t, metadata["description"]) + }) +} + +func TestCapability(t *testing.T) { + cap := &Capability{ + Name: "test-capability", + Description: "A test capability", + Version: "1.0.0", + NodeID: "node-1", + Metadata: map[string]interface{}{ + "key": "value", + }, + } + + assert.Equal(t, "test-capability", cap.Name) + assert.Equal(t, "A test capability", cap.Description) + assert.Equal(t, "1.0.0", cap.Version) + assert.Equal(t, "node-1", cap.NodeID) + assert.NotNil(t, cap.Metadata) +} + +func TestCapabilityRequest(t *testing.T) { + req := CapabilityRequest{ + RequesterID: "node-1", + Capability: "test-capability", + Version: "1.0.0", + } + + assert.Equal(t, "node-1", req.RequesterID) + assert.Equal(t, "test-capability", req.Capability) + assert.Equal(t, "1.0.0", req.Version) +} + +func TestCapabilityResponse(t *testing.T) { + caps := []Capability{ + { + Name: "cap-1", + Description: "Capability 1", + Version: "1.0.0", + NodeID: "node-1", + }, + } + + resp := CapabilityResponse{ + Capabilities: caps, + RequestID: "req-1", + Timestamp: time.Now().UnixMilli(), + } + + assert.Equal(t, caps, resp.Capabilities) + assert.Equal(t, "req-1", resp.RequestID) + assert.NotZero(t, resp.Timestamp) +} + +func TestDAGNodeCreation(t *testing.T) { + node := &DAGNode{ + ID: "node-1", + Task: &SwarmTask{ + ID: "task-1", + Prompt: "Test task", + }, + Status: DAGNodePending, + } + + assert.Equal(t, "node-1", node.ID) + assert.Equal(t, "task-1", node.Task.ID) + assert.Equal(t, DAGNodePending, node.Status) +} + +func TestDAGNodeStatuses(t 
*testing.T) { + statuses := []DAGNodeStatus{ + DAGNodePending, + DAGNodeReady, + DAGNodeRunning, + DAGNodeCompleted, + DAGNodeFailed, + DAGNodeSkipped, + } + + for _, status := range statuses { + assert.NotEmpty(t, string(status)) + } +} diff --git a/pkg/swarm/subject.go b/pkg/swarm/subject.go new file mode 100644 index 000000000..2c7ef1c2d --- /dev/null +++ b/pkg/swarm/subject.go @@ -0,0 +1,304 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/identity" +) + +// SubjectDomain represents different domains in the swarm +type SubjectDomain string + +const ( + // DomainSwarm is for general swarm messages + DomainSwarm SubjectDomain = "swarm" + // DomainMemory is for memory-related messages + DomainMemory SubjectDomain = "memory" + // DomainTask is for task messages + DomainTask SubjectDomain = "task" + // DomainDiscovery is for node discovery + DomainDiscovery SubjectDomain = "discovery" + // DomainSystem is for system messages + DomainSystem SubjectDomain = "system" + // DomainCross is for cross H-id communication + DomainCross SubjectDomain = "x" +) + +// SubjectBuilder builds NATS subjects with H-id partitioning +type SubjectBuilder struct { + // hid is the H-id for this instance (nil = no partitioning) + hid string + + // prefix is the subject prefix (default: "picoclaw") + prefix string +} + +// NewSubjectBuilder creates a new subject builder +func NewSubjectBuilder() *SubjectBuilder { + return &SubjectBuilder{ + prefix: "picoclaw", + } +} + +// WithHID sets the H-id for partitioning +func (b *SubjectBuilder) WithHID(hid string) *SubjectBuilder { + b.hid = hid + return b +} + +// WithIdentity sets the H-id from an identity +func (b *SubjectBuilder) WithIdentity(id *identity.Identity) *SubjectBuilder { + if id != nil { + b.hid = id.HID + } + return b +} + +// WithPrefix sets a custom prefix +func (b 
*SubjectBuilder) WithPrefix(prefix string) *SubjectBuilder { + b.prefix = prefix + return b + +} + +// Build builds a subject with the given domain and parts +// Format: {prefix}.{domain}[.{hid}].{parts...} +func (b *SubjectBuilder) Build(domain SubjectDomain, parts ...string) string { + sb := strings.Builder{} + sb.WriteString(b.prefix) + sb.WriteString(".") + sb.WriteString(string(domain)) + + if b.hid != "" { + sb.WriteString(".") + sb.WriteString(b.hid) + } + + for _, part := range parts { + sb.WriteString(".") + sb.WriteString(part) + } + + return sb.String() +} + +// BuildCross builds a subject for cross H-id communication +// Format: {prefix}.x.{from_hid}.{to_hid}.{parts...} +func (b *SubjectBuilder) BuildCross(fromHID, toHID string, parts ...string) string { + sb := strings.Builder{} + sb.WriteString(b.prefix) + sb.WriteString(".") + sb.WriteString(string(DomainCross)) + sb.WriteString(".") + sb.WriteString(fromHID) + sb.WriteString(".") + sb.WriteString(toHID) + + for _, part := range parts { + sb.WriteString(".") + sb.WriteString(part) + } + + return sb.String() +} + +// BuildWildcard builds a wildcard subject for subscribing +// If HID is set, creates a HID-specific wildcard, otherwise creates a global wildcard +func (b *SubjectBuilder) BuildWildcard(domain SubjectDomain, suffix string) string { + if b.hid != "" { + return fmt.Sprintf("%s.%s.%s.%s", b.prefix, domain, b.hid, suffix) + } + return fmt.Sprintf("%s.%s.%s", b.prefix, domain, suffix) +} + +// ParseSubject parses a NATS subject into its components +func ParseSubject(s string) *ParsedSubject { + parts := strings.Split(s, ".") + + if len(parts) < 2 { + return nil + } + + ps := &ParsedSubject{ + Prefix: parts[0], + } + + if len(parts) < 3 { + return ps + } + + ps.Domain = SubjectDomain(parts[1]) + + // Check if this is a cross-domain subject + if ps.Domain == DomainCross && len(parts) >= 4 { + ps.FromHID = parts[2] + ps.ToHID = parts[3] + if len(parts) > 4 { + ps.Parts = parts[4:] + } + return ps + 
} + + // Regular subject - check if third part is an H-id + if len(parts) >= 3 { + // Heuristic: if third part looks like an H-id (contains "user-", "org-", "group-") + // treat it as H-id, otherwise treat it as regular parts + if isLikelyHID(parts[2]) { + ps.HID = parts[2] + if len(parts) > 3 { + ps.Parts = parts[3:] + } + } else { + ps.Parts = parts[2:] + } + } + + return ps +} + +// isLikelyHID checks if a string is likely an H-id +func isLikelyHID(s string) bool { + return strings.HasPrefix(s, "user-") || + strings.HasPrefix(s, "org-") || + strings.HasPrefix(s, "group-") || + strings.HasPrefix(s, "tenant-") +} + +// ParsedSubject represents a parsed NATS subject +type ParsedSubject struct { + Prefix string + Domain SubjectDomain + HID string + FromHID string // For cross-domain subjects + ToHID string // For cross-domain subjects + Parts []string +} + +// String reconstructs the subject string +func (ps *ParsedSubject) String() string { + sb := strings.Builder{} + sb.WriteString(ps.Prefix) + sb.WriteString(".") + sb.WriteString(string(ps.Domain)) + + if ps.FromHID != "" && ps.ToHID != "" { + sb.WriteString(".") + sb.WriteString(ps.FromHID) + sb.WriteString(".") + sb.WriteString(ps.ToHID) + for _, part := range ps.Parts { + sb.WriteString(".") + sb.WriteString(part) + } + return sb.String() + } + + if ps.HID != "" { + sb.WriteString(".") + sb.WriteString(ps.HID) + } + + for _, part := range ps.Parts { + sb.WriteString(".") + sb.WriteString(part) + } + + return sb.String() +} + +// IsCrossDomain returns true if this is a cross-domain subject +func (ps *ParsedSubject) IsCrossDomain() bool { + return ps.Domain == DomainCross +} + +// GetHID returns the H-id from the subject (either FromHID or HID) +func (ps *ParsedSubject) GetHID() string { + if ps.IsCrossDomain() { + return ps.FromHID + } + return ps.HID +} + +// Helper functions for common subjects + +// HeartbeatSubject builds a heartbeat subject for a node +func HeartbeatSubject(hid, nodeID string) string { + b 
:= NewSubjectBuilder() + if hid != "" { + b = b.WithHID(hid) + } + return b.Build(DomainSwarm, "heartbeat", nodeID) +} + +// DiscoverySubject builds a discovery subject +func DiscoverySubject(hid string) string { + b := NewSubjectBuilder() + if hid != "" { + b = b.WithHID(hid) + } + return b.Build(DomainDiscovery, "announce") +} + +// TaskAssignSubject builds a task assignment subject +func TaskAssignSubject(hid, nodeID string) string { + b := NewSubjectBuilder() + if hid != "" { + b = b.WithHID(hid) + } + return b.Build(DomainTask, "assign", nodeID) +} + +// TaskResultSubject builds a task result subject +func TaskResultSubject(hid, taskID string) string { + b := NewSubjectBuilder() + if hid != "" { + b = b.WithHID(hid) + } + return b.Build(DomainTask, "result", taskID) +} + +// MemoryStoreSubject builds a memory store subject +func MemoryStoreSubject(hid, memoryID string) string { + b := NewSubjectBuilder() + if hid != "" { + b = b.WithHID(hid) + } + return b.Build(DomainMemory, "store", memoryID) +} + +// MemoryQuerySubject builds a memory query subject +func MemoryQuerySubject(hid string) string { + b := NewSubjectBuilder() + if hid != "" { + b = b.WithHID(hid) + } + return b.Build(DomainMemory, "query") +} + +// SystemShutdownSubject builds a system shutdown subject +func SystemShutdownSubject(hid, nodeID string) string { + b := NewSubjectBuilder() + if hid != "" { + b = b.WithHID(hid) + } + return b.Build(DomainSystem, "shutdown", nodeID) +} + +// CrossHIDSubject builds a cross H-id communication subject +func CrossHIDSubject(fromHID, toHID string, action string) string { + b := NewSubjectBuilder() + return b.BuildCross(fromHID, toHID, action) +} + +// LegacySubjectBuilder converts new-style subjects to legacy format +func LegacySubject(subject string) string { + // For backward compatibility, return subject as-is + // When legacy clients are removed, this can be deleted + return subject +} diff --git a/pkg/swarm/temporal.go b/pkg/swarm/temporal.go new file 
mode 100644 index 000000000..70074a87c --- /dev/null +++ b/pkg/swarm/temporal.go @@ -0,0 +1,160 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "time" + + "go.temporal.io/sdk/client" + "go.temporal.io/sdk/worker" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// TemporalClient wraps the Temporal SDK client +type TemporalClient struct { + client client.Client + worker worker.Worker + cfg *config.TemporalConfig + taskQueue string + connected bool +} + +// NewTemporalClient creates a new Temporal client +func NewTemporalClient(cfg *config.TemporalConfig) *TemporalClient { + return &TemporalClient{ + cfg: cfg, + taskQueue: cfg.TaskQueue, + } +} + +// Connect establishes connection to Temporal server +func (tc *TemporalClient) Connect(ctx context.Context) error { + c, err := client.Dial(client.Options{ + HostPort: tc.cfg.Host, + Namespace: tc.cfg.Namespace, + }) + if err != nil { + // Temporal is optional - log warning but don't fail + logger.WarnCF("swarm", "Failed to connect to Temporal (workflows disabled)", map[string]interface{}{ + "host": tc.cfg.Host, + "error": err.Error(), + }) + return nil + } + + tc.client = c + tc.connected = true + logger.InfoCF("swarm", "Connected to Temporal", map[string]interface{}{ + "host": tc.cfg.Host, + "namespace": tc.cfg.Namespace, + }) + return nil +} + +// IsConnected returns true if connected to Temporal +func (tc *TemporalClient) IsConnected() bool { + return tc.connected +} + +// StartWorker starts a Temporal worker that processes workflows and activities +func (tc *TemporalClient) StartWorker(ctx context.Context, workflows []interface{}, activities *Activities) error { + if !tc.connected { + logger.WarnC("swarm", "Temporal not connected, skipping worker start") + return nil + } + + // Register activities globally for workflow access + 
RegisterActivities(activities) + + w := worker.New(tc.client, tc.taskQueue, worker.Options{}) + + // Register workflows + for _, wf := range workflows { + w.RegisterWorkflow(wf) + } + + // Register activity functions + w.RegisterActivity(DecomposeTaskActivity) + w.RegisterActivity(ExecuteDirectActivity) + w.RegisterActivity(ExecuteSubtaskActivity) + w.RegisterActivity(SynthesizeResultsActivity) + + tc.worker = w + + // Start worker in background + go func() { + if err := w.Run(worker.InterruptCh()); err != nil { + logger.ErrorCF("swarm", "Temporal worker error", map[string]interface{}{ + "error": err.Error(), + }) + } + }() + + logger.InfoCF("swarm", "Temporal worker started", map[string]interface{}{ + "task_queue": tc.taskQueue, + }) + return nil +} + +// StartWorkflow starts a new workflow execution +func (tc *TemporalClient) StartWorkflow(ctx context.Context, workflowType string, task *SwarmTask) (string, error) { + if !tc.connected { + return "", fmt.Errorf("temporal not connected") + } + + workflowTimeout, _ := time.ParseDuration(tc.cfg.WorkflowTimeout) + if workflowTimeout == 0 { + workflowTimeout = 30 * time.Minute + } + + options := client.StartWorkflowOptions{ + ID: task.ID, + TaskQueue: tc.taskQueue, + WorkflowExecutionTimeout: workflowTimeout, + } + + we, err := tc.client.ExecuteWorkflow(ctx, options, workflowType, task) + if err != nil { + return "", fmt.Errorf("failed to start workflow: %w", err) + } + + logger.InfoCF("swarm", "Workflow started", map[string]interface{}{ + "workflow_id": we.GetID(), + "run_id": we.GetRunID(), + "task_id": task.ID, + }) + + return we.GetID(), nil +} + +// GetWorkflowResult waits for and returns workflow result +func (tc *TemporalClient) GetWorkflowResult(ctx context.Context, workflowID string) (string, error) { + if !tc.connected { + return "", fmt.Errorf("temporal not connected") + } + + run := tc.client.GetWorkflow(ctx, workflowID, "") + + var result string + if err := run.Get(ctx, &result); err != nil { + return "", 
err + } + return result, nil +} + +// Stop stops the Temporal client and worker +func (tc *TemporalClient) Stop() { + if tc.worker != nil { + tc.worker.Stop() + } + if tc.client != nil { + tc.client.Close() + } + tc.connected = false +} diff --git a/pkg/swarm/testutil.go b/pkg/swarm/testutil.go new file mode 100644 index 000000000..e89e4299c --- /dev/null +++ b/pkg/swarm/testutil.go @@ -0,0 +1,246 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "net" + "testing" + "time" + + "github.com/nats-io/nats.go" + "github.com/sipeed/picoclaw/pkg/config" +) + +// TestNATS provides a test NATS environment with embedded server +type TestNATS struct { + embedded *EmbeddedNATS + nc *nats.Conn + js nats.JetStreamContext + url string +} + +// SetupTestNATS creates a new test NATS environment +// Call t.Cleanup(func() { tn.Stop() }) to ensure cleanup +func SetupTestNATS(t *testing.T) *TestNATS { + t.Helper() + + // Find a free port + port := freePort(t) + + cfg := &config.NATSConfig{ + EmbeddedPort: port, + } + + embedded := NewEmbeddedNATS(cfg) + if err := embedded.Start(); err != nil { + t.Fatalf("Failed to start embedded NATS: %v", err) + } + + // Give the server a moment to be fully ready + time.Sleep(100 * time.Millisecond) + + url := embedded.ClientURL() + + // Connect to the embedded server + nc, err := nats.Connect(url) + if err != nil { + embedded.Stop() + t.Fatalf("Failed to connect to embedded NATS: %v", err) + } + + // Create JetStream context + js, err := nc.JetStream() + if err != nil { + nc.Close() + embedded.Stop() + t.Fatalf("Failed to create JetStream context: %v", err) + } + + tn := &TestNATS{ + embedded: embedded, + nc: nc, + js: js, + url: url, + } + + // Register cleanup + t.Cleanup(func() { + tn.Stop() + }) + + return tn +} + +// Stop stops the test NATS environment +func (tn *TestNATS) Stop() { + if tn.nc != nil { + tn.nc.Close() + } + if 
tn.embedded != nil { + tn.embedded.Stop() + } +} + +// NC returns the NATS connection +func (tn *TestNATS) NC() *nats.Conn { + return tn.nc +} + +// JS returns the JetStream context +func (tn *TestNATS) JS() nats.JetStreamContext { + return tn.js +} + +// URL returns the NATS server URL +func (tn *TestNATS) URL() string { + return tn.url +} + +// PublishTestMessage publishes a test message +func (tn *TestNATS) PublishTestMessage(subject string, data []byte) error { + return tn.nc.Publish(subject, data) +} + +// CreateTestStream creates a test JetStream stream +func (tn *TestNATS) CreateTestStream(streamName string, subjects []string) error { + _, err := tn.js.AddStream(&nats.StreamConfig{ + Name: streamName, + Subjects: subjects, + MaxAge: 1 * time.Hour, + }) + return err +} + +// CreateTestConsumer creates a test JetStream consumer +func (tn *TestNATS) CreateTestConsumer(stream, consumer string) error { + _, err := tn.js.AddConsumer(stream, &nats.ConsumerConfig{ + Durable: consumer, + AckPolicy: nats.AckExplicitPolicy, + }) + return err +} + +// WaitForMessage waits for a message on the given subject +func (tn *TestNATS) WaitForMessage(ctx context.Context, subject string, timeout time.Duration) (*nats.Msg, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + sub, err := tn.nc.SubscribeSync(subject) + if err != nil { + return nil, err + } + defer sub.Unsubscribe() + + msg, err := sub.NextMsgWithContext(ctx) + if err != nil { + return nil, err + } + + return msg, nil +} + +// CreateTestNodeInfo creates a test NodeInfo +func CreateTestNodeInfo(id, role string, capabilities []string) *NodeInfo { + return &NodeInfo{ + ID: id, + Role: NodeRole(role), + Capabilities: capabilities, + Model: "test-model", + Status: StatusOnline, + MaxTasks: 5, + StartedAt: time.Now().UnixMilli(), + Metadata: make(map[string]string), + } +} + +// CreateTestTask creates a test SwarmTask +func CreateTestTask(id, taskType, prompt, capability string) *SwarmTask { + 
return &SwarmTask{ + ID: id, + Type: SwarmTaskType(taskType), + Prompt: prompt, + Capability: capability, + Status: TaskPending, + CreatedAt: time.Now().UnixMilli(), + Timeout: 5 * 60 * 1000, // 5 minutes in milliseconds + } +} + +// RunTestWithNATS runs a test function with a test NATS environment +func RunTestWithNATS(t *testing.T, testFn func(*TestNATS)) { + t.Helper() + + tn := SetupTestNATS(t) + testFn(tn) +} + +// freePort finds a free port for testing +// Note: This duplicates helpers_test.go to avoid build issues when +// testutil is used in non-test contexts +func freePort(t *testing.T) int { + t.Helper() + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to find free port: %v", err) + } + defer l.Close() + + return l.Addr().(*net.TCPAddr).Port +} + +// AssertConnected asserts that NATS is connected +func AssertConnected(t *testing.T, tn *TestNATS) { + t.Helper() + + if !tn.nc.IsConnected() { + t.Error("NATS connection is not established") + } +} + +// AssertStreamExists asserts that a stream exists in JetStream +func AssertStreamExists(t *testing.T, tn *TestNATS, streamName string) { + t.Helper() + + stream, err := tn.js.StreamInfo(streamName) + if err != nil { + t.Fatalf("Stream %s does not exist: %v", streamName, err) + } + + if stream == nil { + t.Errorf("Stream %s is nil", streamName) + } +} + +// PurgeTestStream purges a test stream +func PurgeTestStream(t *testing.T, tn *TestNATS, streamName string) { + t.Helper() + + if err := tn.js.PurgeStream(streamName, &nats.StreamPurgeRequest{}); err != nil { + t.Logf("Warning: failed to purge stream %s: %v", streamName, err) + } +} + +// DeleteTestStream deletes a test stream +func DeleteTestStream(t *testing.T, tn *TestNATS, streamName string) { + t.Helper() + + if err := tn.js.DeleteStream(streamName); err != nil { + t.Logf("Warning: failed to delete stream %s: %v", streamName, err) + } +} + +// GetStreamInfo safely gets stream info, returning nil if stream doesn't exist 
+func GetStreamInfo(t *testing.T, tn *TestNATS, streamName string) *nats.StreamInfo { + t.Helper() + + stream, err := tn.js.StreamInfo(streamName) + if err != nil { + return nil + } + return stream +} diff --git a/pkg/swarm/types.go b/pkg/swarm/types.go new file mode 100644 index 000000000..9831da04e --- /dev/null +++ b/pkg/swarm/types.go @@ -0,0 +1,251 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "fmt" + "time" + + "github.com/google/uuid" +) + +// NodeRole defines the role of a swarm node +type NodeRole string + +const ( + RoleCoordinator NodeRole = "coordinator" + RoleWorker NodeRole = "worker" + RoleSpecialist NodeRole = "specialist" +) + +// NodeStatus defines the status of a swarm node +type NodeStatus string + +const ( + StatusOnline NodeStatus = "online" + StatusBusy NodeStatus = "busy" + StatusOffline NodeStatus = "offline" + StatusSuspicious NodeStatus = "suspicious" + StatusDraining NodeStatus = "draining" +) + +// NodeInfo represents a node in the swarm +type NodeInfo struct { + ID string `json:"id"` + Role NodeRole `json:"role"` + Capabilities []string `json:"capabilities"` + Model string `json:"model"` + Status NodeStatus `json:"status"` + Load float64 `json:"load"` + TasksRunning int `json:"tasks_running"` + MaxTasks int `json:"max_tasks"` + Metadata map[string]string `json:"metadata"` + LastSeen int64 `json:"last_seen"` + StartedAt int64 `json:"started_at"` + Address string `json:"address"` // NATS address for direct messaging +} + +// SwarmTaskType defines how a task is routed +type SwarmTaskType string + +const ( + TaskTypeDirect SwarmTaskType = "direct" // Assigned to specific node + TaskTypeWorkflow SwarmTaskType = "workflow" // Temporal workflow + TaskTypeBroadcast SwarmTaskType = "broadcast" // Broadcast by capability +) + +// SwarmTaskStatus defines the status of a task +type SwarmTaskStatus string + +const ( + TaskPending SwarmTaskStatus 
= "pending" + TaskAssigned SwarmTaskStatus = "assigned" + TaskRunning SwarmTaskStatus = "running" + TaskDone SwarmTaskStatus = "done" + TaskFailed SwarmTaskStatus = "failed" +) + +// SwarmTask represents a task in the swarm +type SwarmTask struct { + ID string `json:"id"` + WorkflowID string `json:"workflow_id,omitempty"` + ParentID string `json:"parent_id,omitempty"` + Type SwarmTaskType `json:"type"` + Priority int `json:"priority"` // 0=low, 1=normal, 2=high, 3=critical + Capability string `json:"capability"` + Prompt string `json:"prompt"` + Context map[string]interface{} `json:"context"` + AssignedTo string `json:"assigned_to"` + Status SwarmTaskStatus `json:"status"` + Result string `json:"result,omitempty"` + Error string `json:"error,omitempty"` + CreatedAt int64 `json:"created_at"` + CompletedAt int64 `json:"completed_at,omitempty"` + Timeout int64 `json:"timeout"` // Timeout in milliseconds +} + +// NewSwarmTask creates a new task with default values +func NewSwarmTask(taskType SwarmTaskType, capability, prompt string) *SwarmTask { + return &SwarmTask{ + ID: generateTaskID(), + Type: taskType, + Priority: 1, // normal + Capability: capability, + Prompt: prompt, + Context: make(map[string]interface{}), + Status: TaskPending, + CreatedAt: time.Now().UnixMilli(), + Timeout: 10 * 60 * 1000, // 10 minutes default + } +} + +// TaskResult is sent when a task completes +type TaskResult struct { + TaskID string `json:"task_id"` + NodeID string `json:"node_id"` + Status string `json:"status"` + Result string `json:"result,omitempty"` + Error string `json:"error,omitempty"` + CompletedAt int64 `json:"completed_at"` +} + +// TaskProgress is sent periodically during task execution +type TaskProgress struct { + TaskID string `json:"task_id"` + NodeID string `json:"node_id"` + Progress float64 `json:"progress"` // 0.0 to 1.0 + Message string `json:"message"` +} + +// DiscoveryAnnounce is published when a node joins +type DiscoveryAnnounce struct { + Node NodeInfo 
`json:"node"` + Timestamp int64 `json:"timestamp"` +} + +// DiscoveryQuery is published to discover nodes +type DiscoveryQuery struct { + RequesterID string `json:"requester_id"` + Capability string `json:"capability,omitempty"` // Filter by capability + Role NodeRole `json:"role,omitempty"` // Filter by role +} + +// Heartbeat is published periodically by each node +type Heartbeat struct { + NodeID string `json:"node_id"` + Role NodeRole `json:"role,omitempty"` + Status NodeStatus `json:"status"` + Load float64 `json:"load"` + TasksRunning int `json:"tasks_running"` + Timestamp int64 `json:"timestamp"` + Capabilities []string `json:"capabilities,omitempty"` + HID string `json:"hid,omitempty"` + SID string `json:"sid,omitempty"` +} + +// generateTaskID generates a unique task ID +func generateTaskID() string { + return fmt.Sprintf("task-%s", uuid.New().String()[:8]) +} + +// TaskEventType represents the type of task event +type TaskEventType string + +const ( + TaskEventCreated TaskEventType = "created" + TaskEventAssigned TaskEventType = "assigned" + TaskEventStarted TaskEventType = "started" + TaskEventProgress TaskEventType = "progress" + TaskEventCompleted TaskEventType = "completed" + TaskEventFailed TaskEventType = "failed" + TaskEventRetry TaskEventType = "retry" + TaskEventCheckpoint TaskEventType = "checkpoint" +) + +// TaskEvent represents a single event in task lifecycle +type TaskEvent struct { + EventID string `json:"event_id"` + TaskID string `json:"task_id"` + EventType TaskEventType `json:"event_type"` + Timestamp int64 `json:"timestamp"` + NodeID string `json:"node_id,omitempty"` + Status SwarmTaskStatus `json:"status,omitempty"` + Message string `json:"message,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Progress float64 `json:"progress,omitempty"` +} + +// CheckpointType defines the type of checkpoint +type CheckpointType string + +const ( + CheckpointTypeProgress CheckpointType = "progress" // Periodic progress 
checkpoint + CheckpointTypeMilestone CheckpointType = "milestone" // Significant milestone reached + CheckpointTypePreFailover CheckpointType = "pre_failover" // Before potential failover + CheckpointTypeUserCheckpointType CheckpointType = "user" // User-requested checkpoint +) + +// TaskCheckpoint represents a saved state for task recovery +type TaskCheckpoint struct { + CheckpointID string `json:"checkpoint_id"` + TaskID string `json:"task_id"` + Type CheckpointType `json:"type"` + Timestamp int64 `json:"timestamp"` + NodeID string `json:"node_id"` + Progress float64 `json:"progress"` // 0.0 to 1.0 + State map[string]interface{} `json:"state"` // Arbitrary state data + PartialResult string `json:"partial_result"` // Partial output so far + Context map[string]interface{} `json:"context"` // LLM context/messages + Metadata map[string]string `json:"metadata"` // Additional metadata +} + +// DAGNode represents a single node in a DAG workflow +type DAGNode struct { + ID string `json:"id"` + Task *SwarmTask `json:"task"` + Dependencies []string `json:"dependencies"` // IDs of nodes this depends on + Status DAGNodeStatus `json:"status"` + Result string `json:"result,omitempty"` + Error string `json:"error,omitempty"` + StartedAt int64 `json:"started_at,omitempty"` + CompletedAt int64 `json:"completed_at,omitempty"` +} + +// DAGNodeStatus represents the status of a DAG node +type DAGNodeStatus string + +const ( + DAGNodePending DAGNodeStatus = "pending" + DAGNodeReady DAGNodeStatus = "ready" + DAGNodeRunning DAGNodeStatus = "running" + DAGNodeCompleted DAGNodeStatus = "completed" + DAGNodeFailed DAGNodeStatus = "failed" + DAGNodeSkipped DAGNodeStatus = "skipped" +) + +// Capability represents a specialized capability that a node can provide +type Capability struct { + Name string `json:"name"` + Description string `json:"description"` + Version string `json:"version"` + Metadata map[string]interface{} `json:"metadata"` + NodeID string `json:"node_id"` + RegisteredAt 
int64 `json:"registered_at"` +} + +// CapabilityRequest is used to discover capabilities across the swarm +type CapabilityRequest struct { + RequesterID string `json:"requester_id"` + Capability string `json:"capability,omitempty"` // Optional: filter by specific capability + Version string `json:"version,omitempty"` // Optional: filter by version +} + +// CapabilityResponse is the response to a capability discovery request +type CapabilityResponse struct { + Capabilities []Capability `json:"capabilities"` + RequestID string `json:"request_id"` + Timestamp int64 `json:"timestamp"` +} diff --git a/pkg/swarm/types_test.go b/pkg/swarm/types_test.go new file mode 100644 index 000000000..0411defad --- /dev/null +++ b/pkg/swarm/types_test.go @@ -0,0 +1,551 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "encoding/json" + "strings" + "testing" + "time" +) + +func TestNewSwarmTask(t *testing.T) { + tests := []struct { + name string + taskType SwarmTaskType + capability string + prompt string + }{ + { + name: "direct task with defaults", + taskType: TaskTypeDirect, + capability: "code", + prompt: "write code", + }, + { + name: "broadcast task", + taskType: TaskTypeBroadcast, + capability: "research", + prompt: "find info", + }, + { + name: "workflow task", + taskType: TaskTypeWorkflow, + capability: "complex", + prompt: "analyze data", + }, + { + name: "empty capability", + taskType: TaskTypeDirect, + capability: "", + prompt: "hello", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + before := time.Now().UnixMilli() + task := NewSwarmTask(tt.taskType, tt.capability, tt.prompt) + after := time.Now().UnixMilli() + + if !strings.HasPrefix(task.ID, "task-") { + t.Errorf("ID = %q, want prefix 'task-'", task.ID) + } + // "task-" (5 chars) + 8 hex chars = 13 + if len(task.ID) != 13 { + t.Errorf("len(ID) = %d, want 13", len(task.ID)) + } + if 
task.Type != tt.taskType { + t.Errorf("Type = %q, want %q", task.Type, tt.taskType) + } + if task.Priority != 1 { + t.Errorf("Priority = %d, want 1", task.Priority) + } + if task.Capability != tt.capability { + t.Errorf("Capability = %q, want %q", task.Capability, tt.capability) + } + if task.Prompt != tt.prompt { + t.Errorf("Prompt = %q, want %q", task.Prompt, tt.prompt) + } + if task.Status != TaskPending { + t.Errorf("Status = %q, want %q", task.Status, TaskPending) + } + if task.Context == nil { + t.Error("Context is nil, want non-nil empty map") + } + if len(task.Context) != 0 { + t.Errorf("len(Context) = %d, want 0", len(task.Context)) + } + if task.Timeout != 10*60*1000 { + t.Errorf("Timeout = %d, want %d", task.Timeout, 10*60*1000) + } + if task.CreatedAt < before || task.CreatedAt > after { + t.Errorf("CreatedAt = %d, want between %d and %d", task.CreatedAt, before, after) + } + }) + } +} + +func TestSwarmTask_JSONRoundTrip(t *testing.T) { + tests := []struct { + name string + task SwarmTask + }{ + { + name: "full task", + task: SwarmTask{ + ID: "task-abc12345", + WorkflowID: "wf-1", + ParentID: "task-parent", + Type: TaskTypeDirect, + Priority: 2, + Capability: "code", + Prompt: "write a function", + Context: map[string]interface{}{"key": "value", "num": float64(42)}, + AssignedTo: "node-1", + Status: TaskRunning, + Result: "done", + Error: "", + CreatedAt: 1000000, + CompletedAt: 2000000, + Timeout: 60000, + }, + }, + { + name: "minimal task", + task: SwarmTask{ + ID: "task-min00001", + Type: TaskTypeBroadcast, + Prompt: "hello", + Status: TaskPending, + }, + }, + { + name: "task with nested context", + task: SwarmTask{ + ID: "task-ctx00001", + Type: TaskTypeWorkflow, + Prompt: "complex", + Capability: "analysis", + Status: TaskAssigned, + Context: map[string]interface{}{ + "nested": map[string]interface{}{ + "deep": "value", + }, + "list": []interface{}{"a", "b"}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + 
data, err := json.Marshal(tt.task) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + var got SwarmTask + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + if got.ID != tt.task.ID { + t.Errorf("ID = %q, want %q", got.ID, tt.task.ID) + } + if got.Type != tt.task.Type { + t.Errorf("Type = %q, want %q", got.Type, tt.task.Type) + } + if got.Priority != tt.task.Priority { + t.Errorf("Priority = %d, want %d", got.Priority, tt.task.Priority) + } + if got.Capability != tt.task.Capability { + t.Errorf("Capability = %q, want %q", got.Capability, tt.task.Capability) + } + if got.Prompt != tt.task.Prompt { + t.Errorf("Prompt = %q, want %q", got.Prompt, tt.task.Prompt) + } + if got.Status != tt.task.Status { + t.Errorf("Status = %q, want %q", got.Status, tt.task.Status) + } + if got.Result != tt.task.Result { + t.Errorf("Result = %q, want %q", got.Result, tt.task.Result) + } + if got.WorkflowID != tt.task.WorkflowID { + t.Errorf("WorkflowID = %q, want %q", got.WorkflowID, tt.task.WorkflowID) + } + if got.AssignedTo != tt.task.AssignedTo { + t.Errorf("AssignedTo = %q, want %q", got.AssignedTo, tt.task.AssignedTo) + } + if got.CreatedAt != tt.task.CreatedAt { + t.Errorf("CreatedAt = %d, want %d", got.CreatedAt, tt.task.CreatedAt) + } + if got.Timeout != tt.task.Timeout { + t.Errorf("Timeout = %d, want %d", got.Timeout, tt.task.Timeout) + } + }) + } +} + +func TestNodeInfo_JSONRoundTrip(t *testing.T) { + tests := []struct { + name string + node NodeInfo + }{ + { + name: "worker with capabilities", + node: NodeInfo{ + ID: "node-1", + Role: RoleWorker, + Capabilities: []string{"code", "research"}, + Model: "gpt-4", + Status: StatusOnline, + Load: 0.5, + TasksRunning: 2, + MaxTasks: 4, + Metadata: map[string]string{"region": "us-east"}, + LastSeen: 1000000, + StartedAt: 900000, + Address: "nats://127.0.0.1:4222", + }, + }, + { + name: "coordinator", + node: NodeInfo{ + ID: "coord-1", + Role: RoleCoordinator, + Capabilities: 
[]string{}, + Status: StatusBusy, + MaxTasks: 1, + Metadata: map[string]string{}, + }, + }, + { + name: "specialist offline", + node: NodeInfo{ + ID: "spec-1", + Role: RoleSpecialist, + Capabilities: []string{"ml"}, + Status: StatusOffline, + Load: 0.0, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.node) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + var got NodeInfo + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + if got.ID != tt.node.ID { + t.Errorf("ID = %q, want %q", got.ID, tt.node.ID) + } + if got.Role != tt.node.Role { + t.Errorf("Role = %q, want %q", got.Role, tt.node.Role) + } + if got.Status != tt.node.Status { + t.Errorf("Status = %q, want %q", got.Status, tt.node.Status) + } + if got.Load != tt.node.Load { + t.Errorf("Load = %f, want %f", got.Load, tt.node.Load) + } + if got.MaxTasks != tt.node.MaxTasks { + t.Errorf("MaxTasks = %d, want %d", got.MaxTasks, tt.node.MaxTasks) + } + if got.Model != tt.node.Model { + t.Errorf("Model = %q, want %q", got.Model, tt.node.Model) + } + if got.Address != tt.node.Address { + t.Errorf("Address = %q, want %q", got.Address, tt.node.Address) + } + }) + } +} + +func TestTaskResult_JSONRoundTrip(t *testing.T) { + tests := []struct { + name string + result TaskResult + }{ + { + name: "success result", + result: TaskResult{ + TaskID: "task-abc12345", + NodeID: "node-1", + Status: "done", + Result: "completed successfully", + CompletedAt: 1000000, + }, + }, + { + name: "failure result", + result: TaskResult{ + TaskID: "task-err00001", + NodeID: "node-2", + Status: "failed", + Error: "execution timeout", + CompletedAt: 2000000, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.result) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + var got TaskResult + if err := json.Unmarshal(data, &got); err != nil { + 
t.Fatalf("Unmarshal error: %v", err) + } + + if got.TaskID != tt.result.TaskID { + t.Errorf("TaskID = %q, want %q", got.TaskID, tt.result.TaskID) + } + if got.NodeID != tt.result.NodeID { + t.Errorf("NodeID = %q, want %q", got.NodeID, tt.result.NodeID) + } + if got.Status != tt.result.Status { + t.Errorf("Status = %q, want %q", got.Status, tt.result.Status) + } + if got.Result != tt.result.Result { + t.Errorf("Result = %q, want %q", got.Result, tt.result.Result) + } + if got.Error != tt.result.Error { + t.Errorf("Error = %q, want %q", got.Error, tt.result.Error) + } + if got.CompletedAt != tt.result.CompletedAt { + t.Errorf("CompletedAt = %d, want %d", got.CompletedAt, tt.result.CompletedAt) + } + }) + } +} + +func TestHeartbeat_JSONRoundTrip(t *testing.T) { + tests := []struct { + name string + hb Heartbeat + }{ + { + name: "online node", + hb: Heartbeat{ + NodeID: "node-1", + Status: StatusOnline, + Load: 0.25, + TasksRunning: 1, + Timestamp: 1000000, + }, + }, + { + name: "busy node", + hb: Heartbeat{ + NodeID: "node-2", + Status: StatusBusy, + Load: 1.0, + TasksRunning: 4, + Timestamp: 2000000, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.hb) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + var got Heartbeat + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + if got.NodeID != tt.hb.NodeID { + t.Errorf("NodeID = %q, want %q", got.NodeID, tt.hb.NodeID) + } + if got.Status != tt.hb.Status { + t.Errorf("Status = %q, want %q", got.Status, tt.hb.Status) + } + if got.Load != tt.hb.Load { + t.Errorf("Load = %f, want %f", got.Load, tt.hb.Load) + } + if got.TasksRunning != tt.hb.TasksRunning { + t.Errorf("TasksRunning = %d, want %d", got.TasksRunning, tt.hb.TasksRunning) + } + if got.Timestamp != tt.hb.Timestamp { + t.Errorf("Timestamp = %d, want %d", got.Timestamp, tt.hb.Timestamp) + } + }) + } +} + +func 
TestDiscoveryAnnounce_JSONRoundTrip(t *testing.T) { + tests := []struct { + name string + announce DiscoveryAnnounce + }{ + { + name: "worker announcement", + announce: DiscoveryAnnounce{ + Node: NodeInfo{ + ID: "node-1", + Role: RoleWorker, + Capabilities: []string{"code"}, + Status: StatusOnline, + MaxTasks: 4, + Metadata: map[string]string{}, + }, + Timestamp: 1000000, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.announce) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + var got DiscoveryAnnounce + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + if got.Node.ID != tt.announce.Node.ID { + t.Errorf("Node.ID = %q, want %q", got.Node.ID, tt.announce.Node.ID) + } + if got.Node.Role != tt.announce.Node.Role { + t.Errorf("Node.Role = %q, want %q", got.Node.Role, tt.announce.Node.Role) + } + if got.Timestamp != tt.announce.Timestamp { + t.Errorf("Timestamp = %d, want %d", got.Timestamp, tt.announce.Timestamp) + } + }) + } +} + +func TestDiscoveryQuery_JSONRoundTrip(t *testing.T) { + tests := []struct { + name string + query DiscoveryQuery + }{ + { + name: "query all", + query: DiscoveryQuery{ + RequesterID: "node-0", + }, + }, + { + name: "query by role", + query: DiscoveryQuery{ + RequesterID: "node-0", + Role: RoleWorker, + }, + }, + { + name: "query by capability", + query: DiscoveryQuery{ + RequesterID: "node-0", + Capability: "code", + }, + }, + { + name: "query by both", + query: DiscoveryQuery{ + RequesterID: "node-0", + Role: RoleSpecialist, + Capability: "ml", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.query) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + var got DiscoveryQuery + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + if got.RequesterID != tt.query.RequesterID { + t.Errorf("RequesterID = 
%q, want %q", got.RequesterID, tt.query.RequesterID) + } + if got.Role != tt.query.Role { + t.Errorf("Role = %q, want %q", got.Role, tt.query.Role) + } + if got.Capability != tt.query.Capability { + t.Errorf("Capability = %q, want %q", got.Capability, tt.query.Capability) + } + }) + } +} + +func TestTaskProgress_JSONRoundTrip(t *testing.T) { + tests := []struct { + name string + progress TaskProgress + }{ + { + name: "half complete", + progress: TaskProgress{ + TaskID: "task-abc12345", + NodeID: "node-1", + Progress: 0.5, + Message: "processing", + }, + }, + { + name: "complete", + progress: TaskProgress{ + TaskID: "task-done0001", + NodeID: "node-2", + Progress: 1.0, + Message: "finished", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.progress) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + var got TaskProgress + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + if got.TaskID != tt.progress.TaskID { + t.Errorf("TaskID = %q, want %q", got.TaskID, tt.progress.TaskID) + } + if got.NodeID != tt.progress.NodeID { + t.Errorf("NodeID = %q, want %q", got.NodeID, tt.progress.NodeID) + } + if got.Progress != tt.progress.Progress { + t.Errorf("Progress = %f, want %f", got.Progress, tt.progress.Progress) + } + if got.Message != tt.progress.Message { + t.Errorf("Message = %q, want %q", got.Message, tt.progress.Message) + } + }) + } +} diff --git a/pkg/swarm/worker.go b/pkg/swarm/worker.go new file mode 100644 index 000000000..d02c5fe89 --- /dev/null +++ b/pkg/swarm/worker.go @@ -0,0 +1,303 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + 
"github.com/sipeed/picoclaw/pkg/providers" +) + +// Worker executes tasks received from the swarm +type Worker struct { + bridge *NATSBridge + temporal *TemporalClient + agentLoop *agent.AgentLoop + provider providers.LLMProvider + nodeInfo *NodeInfo + cfg *config.SwarmConfig + taskQueue chan *SwarmTask + activeTasks sync.Map + running atomic.Bool + tasksRunning atomic.Int32 // atomic counter for thread-safe tracking +} + +// NewWorker creates a new worker +func NewWorker( + cfg *config.SwarmConfig, + bridge *NATSBridge, + temporal *TemporalClient, + agentLoop *agent.AgentLoop, + provider providers.LLMProvider, + nodeInfo *NodeInfo, +) *Worker { + return &Worker{ + bridge: bridge, + temporal: temporal, + agentLoop: agentLoop, + provider: provider, + nodeInfo: nodeInfo, + cfg: cfg, + taskQueue: make(chan *SwarmTask, cfg.MaxConcurrent*2), + } +} + +// Start begins the worker +func (w *Worker) Start(ctx context.Context) error { + w.running.Store(true) + + // Set up task handler + w.bridge.SetOnTaskReceived(func(task *SwarmTask) { + select { + case w.taskQueue <- task: + default: + logger.WarnCF("swarm", "Task queue full, rejecting task", map[string]interface{}{ + "task_id": task.ID, + }) + } + }) + + // Start task processors + for i := 0; i < w.cfg.MaxConcurrent; i++ { + go w.processTaskLoop(ctx) + } + + logger.InfoCF("swarm", "Worker started", map[string]interface{}{ + "node_id": w.nodeInfo.ID, + "max_concurrent": w.cfg.MaxConcurrent, + "capabilities": w.nodeInfo.Capabilities, + }) + + return nil +} + +// Stop stops the worker +func (w *Worker) Stop() { + w.running.Store(false) + close(w.taskQueue) +} + +func (w *Worker) processTaskLoop(ctx context.Context) { + for { + select { + case task, ok := <-w.taskQueue: + if !ok { + return + } + w.executeTask(ctx, task) + case <-ctx.Done(): + return + } + } +} + +func (w *Worker) executeTask(ctx context.Context, task *SwarmTask) { + logger.InfoCF("swarm", "Executing task", map[string]interface{}{ + "task_id": task.ID, + 
"capability": task.Capability, + }) + + // Track active task using atomic operations for thread safety + w.activeTasks.Store(task.ID, task) + current := w.tasksRunning.Add(1) + w.nodeInfo.TasksRunning = int(current) + w.updateLoad() + + defer func() { + w.activeTasks.Delete(task.ID) + current := w.tasksRunning.Add(-1) + w.nodeInfo.TasksRunning = int(current) + w.updateLoad() + }() + + // Create timeout context + timeout := time.Duration(task.Timeout) * time.Millisecond + if timeout == 0 { + timeout = 10 * time.Minute + } + taskCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // Send progress updates periodically + progressDone := make(chan struct{}) + go w.sendProgressUpdates(taskCtx, task, progressDone) + + // Execute using local agent + result, err := w.agentLoop.ProcessDirect(taskCtx, task.Prompt, "swarm:"+task.ID) + + close(progressDone) + + // Send result + taskResult := &TaskResult{ + TaskID: task.ID, + NodeID: w.nodeInfo.ID, + CompletedAt: time.Now().UnixMilli(), + } + + if err != nil { + taskResult.Status = string(TaskFailed) + taskResult.Error = err.Error() + logger.WarnCF("swarm", "Task execution failed", map[string]interface{}{ + "task_id": task.ID, + "error": err.Error(), + }) + } else { + taskResult.Status = string(TaskDone) + taskResult.Result = result + logger.InfoCF("swarm", "Task completed", map[string]interface{}{ + "task_id": task.ID, + "result_length": len(result), + }) + } + + if err := w.bridge.PublishTaskResult(taskResult); err != nil { + logger.ErrorCF("swarm", "Failed to publish task result", map[string]interface{}{ + "task_id": task.ID, + "error": err.Error(), + }) + } +} + +// sendProgressUpdates sends periodic progress updates for a running task +// The progress is estimated based on elapsed time relative to the task timeout +func (w *Worker) sendProgressUpdates(ctx context.Context, task *SwarmTask, done chan struct{}) { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + startTime := time.Now() + + 
// Determine timeout for progress estimation + timeout := time.Duration(task.Timeout) * time.Millisecond + if timeout == 0 { + timeout = 10 * time.Minute + } + + for { + select { + case <-ticker.C: + // Estimate progress based on elapsed time + elapsed := time.Since(startTime) + progress := float64(elapsed.Milliseconds()) / float64(timeout.Milliseconds()) + + // Clamp progress between 0.1 and 0.9 (we don't know actual LLM progress) + if progress < 0.1 { + progress = 0.1 + } else if progress > 0.9 { + progress = 0.9 + } + + // Generate contextual message based on progress + message := "processing" + if progress < 0.3 { + message = "initializing" + } else if progress < 0.6 { + message = "processing" + } else if progress < 0.9 { + message = "finalizing" + } + + progressUpdate := &TaskProgress{ + TaskID: task.ID, + NodeID: w.nodeInfo.ID, + Progress: progress, + Message: message, + } + if err := w.bridge.PublishTaskProgress(progressUpdate); err != nil { + logger.DebugCF("swarm", "Failed to publish progress", map[string]interface{}{ + "error": err.Error(), + }) + } + case <-done: + return + case <-ctx.Done(): + return + } + } +} + +func (w *Worker) updateLoad() { + tasksRunning := int(w.tasksRunning.Load()) + if w.cfg.MaxConcurrent > 0 { + w.nodeInfo.Load = float64(tasksRunning) / float64(w.cfg.MaxConcurrent) + } + if tasksRunning >= w.cfg.MaxConcurrent { + w.nodeInfo.Status = StatusBusy + } else { + w.nodeInfo.Status = StatusOnline + } + w.nodeInfo.TasksRunning = tasksRunning +} + +// ActiveTaskCount returns the number of currently executing tasks +func (w *Worker) ActiveTaskCount() int { + count := 0 + w.activeTasks.Range(func(key, value interface{}) bool { + count++ + return true + }) + return count +} + +// RecoverFromCheckpoint resumes task execution from a saved checkpoint +func (w *Worker) RecoverFromCheckpoint(ctx context.Context, task *SwarmTask, checkpoint *TaskCheckpoint) (string, error) { + logger.InfoCF("swarm", "Recovering task from checkpoint", 
map[string]interface{}{ + "task_id": task.ID, + "checkpoint_id": checkpoint.CheckpointID, + "progress": checkpoint.Progress, + }) + + // Build recovery prompt with checkpoint context + recoveryPrompt := fmt.Sprintf(`[TASK RECOVERY MODE] + +You are resuming execution of a task that was interrupted. + +Original Task: %s + +Checkpoint Progress: %.0f%% + +Partial Work Completed: +%s + +Previous Context: +- Last checkpoint was taken by node: %s +- Checkpoint type: %s +- Timestamp: %s + +Continue from where the previous execution left off. Use the partial result as context and complete the remaining work.`, + task.Prompt, + checkpoint.Progress*100, + checkpoint.PartialResult, + checkpoint.NodeID, + string(checkpoint.Type), + time.UnixMilli(checkpoint.Timestamp).Format(time.RFC3339), + ) + + // Execute with the recovery prompt + result, err := w.agentLoop.ProcessDirect(ctx, recoveryPrompt, "swarm:recovery:"+task.ID) + if err != nil { + logger.WarnCF("swarm", "Checkpoint recovery failed", map[string]interface{}{ + "task_id": task.ID, + "checkpoint_id": checkpoint.CheckpointID, + "error": err.Error(), + }) + return "", err + } + + logger.InfoCF("swarm", "Task recovery completed", map[string]interface{}{ + "task_id": task.ID, + "checkpoint_id": checkpoint.CheckpointID, + "result_length": len(result), + }) + + return result, nil +} diff --git a/pkg/swarm/worker_test.go b/pkg/swarm/worker_test.go new file mode 100644 index 000000000..53c4446ba --- /dev/null +++ b/pkg/swarm/worker_test.go @@ -0,0 +1,307 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestWorker_ExecuteTask(t *testing.T) { + _, url, cleanup := startTestNATS(t) + defer cleanup() + + tests := []struct { + name string + chatResponse string + chatErr error + wantStatus string + wantContains 
string // check result or error + }{ + { + name: "successful execution", + chatResponse: "task completed successfully", + chatErr: nil, + wantStatus: string(TaskDone), + wantContains: "task completed", + }, + { + name: "execution returns error", + chatResponse: "", + chatErr: fmt.Errorf("agent processing failed"), + wantStatus: string(TaskFailed), + wantContains: "agent processing failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + workerNode := newTestNodeInfo("exec-worker", RoleWorker, []string{"code"}, 4) + workerBridge := connectTestBridge(t, url, workerNode) + defer workerBridge.Stop() + + if err := workerBridge.Start(context.Background()); err != nil { + t.Fatalf("Start() error: %v", err) + } + + swarmCfg := newTestSwarmConfig(0) + swarmCfg.MaxConcurrent = 2 + temporal := NewTemporalClient(&config.TemporalConfig{TaskQueue: "test"}) + agentLoop := newTestAgentLoop(t, tt.chatResponse, tt.chatErr) + + worker := NewWorker(swarmCfg, workerBridge, temporal, agentLoop, &mockLLMProvider{}, workerNode) + + // Subscribe to results from coordinator side + coordNode := newTestNodeInfo("exec-coord", RoleCoordinator, nil, 1) + coordBridge := connectTestBridge(t, url, coordNode) + defer coordBridge.Stop() + + taskID := fmt.Sprintf("task-exec%04d", time.Now().UnixNano()%10000) + var received atomic.Value + sub, err := coordBridge.SubscribeTaskResult(taskID, func(r *TaskResult) { + received.Store(r) + }) + if err != nil { + t.Fatalf("SubscribeTaskResult() error: %v", err) + } + defer sub.Unsubscribe() + + // Start worker + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := worker.Start(ctx); err != nil { + t.Fatalf("worker Start() error: %v", err) + } + + // Send task directly to worker's queue + task := &SwarmTask{ + ID: taskID, + Type: TaskTypeDirect, + Capability: "code", + Prompt: "execute this", + Status: TaskPending, + Timeout: 5000, + } + worker.taskQueue <- task + + // Wait for 
result + ok := waitFor(t, 5*time.Second, func() bool { + return received.Load() != nil + }) + if !ok { + t.Fatal("timed out waiting for task result") + } + + got := received.Load().(*TaskResult) + if got.Status != tt.wantStatus { + t.Errorf("Status = %q, want %q", got.Status, tt.wantStatus) + } + combined := got.Result + got.Error + if !strings.Contains(combined, tt.wantContains) { + t.Errorf("Result+Error = %q, want it to contain %q", combined, tt.wantContains) + } + if got.NodeID != "exec-worker" { + t.Errorf("NodeID = %q, want %q", got.NodeID, "exec-worker") + } + }) + } +} + +func TestWorker_LoadTracking(t *testing.T) { + tests := []struct { + name string + maxConc int + setupTasks int // number of tasks running at check time + wantBusy bool + }{ + { + name: "initial state is idle", + maxConc: 4, + setupTasks: 0, + wantBusy: false, + }, + { + name: "max concurrent reached marks busy", + maxConc: 2, + setupTasks: 2, + wantBusy: true, + }, + { + name: "below max is online", + maxConc: 4, + setupTasks: 1, + wantBusy: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + workerNode := newTestNodeInfo("load-worker", RoleWorker, []string{"code"}, tt.maxConc) + swarmCfg := newTestSwarmConfig(0) + swarmCfg.MaxConcurrent = tt.maxConc + + // Simulate load state directly on nodeInfo + workerNode.TasksRunning = tt.setupTasks + if tt.maxConc > 0 { + workerNode.Load = float64(tt.setupTasks) / float64(tt.maxConc) + } + if tt.setupTasks >= tt.maxConc { + workerNode.Status = StatusBusy + } else { + workerNode.Status = StatusOnline + } + + if tt.wantBusy { + if workerNode.Status != StatusBusy { + t.Errorf("Status = %q, want %q", workerNode.Status, StatusBusy) + } + } else { + if workerNode.Status != StatusOnline { + t.Errorf("Status = %q, want %q", workerNode.Status, StatusOnline) + } + } + + expectedLoad := float64(tt.setupTasks) / float64(tt.maxConc) + if workerNode.Load != expectedLoad { + t.Errorf("Load = %f, want %f", workerNode.Load, 
expectedLoad) + } + }) + } +} + +func TestWorker_UpdateLoad(t *testing.T) { + // Test the actual updateLoad method on Worker + tests := []struct { + name string + maxConc int + tasksRunning int + wantLoad float64 + wantStatus NodeStatus + }{ + { + name: "idle worker", + maxConc: 4, + tasksRunning: 0, + wantLoad: 0.0, + wantStatus: StatusOnline, + }, + { + name: "partially loaded", + maxConc: 4, + tasksRunning: 2, + wantLoad: 0.5, + wantStatus: StatusOnline, + }, + { + name: "fully loaded", + maxConc: 2, + tasksRunning: 2, + wantLoad: 1.0, + wantStatus: StatusBusy, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nodeInfo := newTestNodeInfo("update-load", RoleWorker, []string{"code"}, tt.maxConc) + + swarmCfg := newTestSwarmConfig(0) + swarmCfg.MaxConcurrent = tt.maxConc + + w := &Worker{ + nodeInfo: nodeInfo, + cfg: swarmCfg, + } + // Set atomic counter for thread-safe load tracking + w.tasksRunning.Store(int32(tt.tasksRunning)) + + w.updateLoad() + + if w.nodeInfo.Load != tt.wantLoad { + t.Errorf("Load = %f, want %f", w.nodeInfo.Load, tt.wantLoad) + } + if w.nodeInfo.Status != tt.wantStatus { + t.Errorf("Status = %q, want %q", w.nodeInfo.Status, tt.wantStatus) + } + }) + } +} + +func TestWorker_QueueFullRejection(t *testing.T) { + tests := []struct { + name string + maxConc int + fillCount int // fill this many tasks first + expectAccept bool + }{ + { + name: "queue accepts when space available", + maxConc: 1, + fillCount: 0, + expectAccept: true, + }, + { + name: "queue rejects when full", + maxConc: 1, + fillCount: 2, // queue size is maxConcurrent*2 = 2, fill completely + expectAccept: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + swarmCfg := newTestSwarmConfig(0) + swarmCfg.MaxConcurrent = tt.maxConc + + nodeInfo := newTestNodeInfo("queue-worker", RoleWorker, []string{"code"}, tt.maxConc) + + // Create worker with small queue + w := &Worker{ + nodeInfo: nodeInfo, + cfg: swarmCfg, + 
taskQueue: make(chan *SwarmTask, swarmCfg.MaxConcurrent*2), + } + + // Fill the queue + for i := 0; i < tt.fillCount; i++ { + task := &SwarmTask{ + ID: fmt.Sprintf("fill-task-%d", i), + Prompt: "filler", + } + select { + case w.taskQueue <- task: + default: + // Queue full already + } + } + + // Try to send one more + testTask := &SwarmTask{ + ID: "test-overflow", + Prompt: "overflow test", + } + + accepted := false + select { + case w.taskQueue <- testTask: + accepted = true + default: + accepted = false + } + + if accepted != tt.expectAccept { + t.Errorf("accepted = %v, want %v", accepted, tt.expectAccept) + } + }) + } +} diff --git a/pkg/swarm/workflows.go b/pkg/swarm/workflows.go new file mode 100644 index 000000000..a86a34cd1 --- /dev/null +++ b/pkg/swarm/workflows.go @@ -0,0 +1,142 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "time" + + "go.temporal.io/sdk/temporal" + "go.temporal.io/sdk/workflow" +) + +// SwarmWorkflow is the main workflow for task orchestration. +// It decomposes a task, runs subtasks in parallel, and synthesizes results. 
+func SwarmWorkflow(ctx workflow.Context, task *SwarmTask) (string, error) { + wfLogger := workflow.GetLogger(ctx) + wfLogger.Info("Starting swarm workflow", "task_id", task.ID) + + // Step 1: Decompose task into subtasks + ctx1 := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ + StartToCloseTimeout: 2 * time.Minute, + RetryPolicy: &temporal.RetryPolicy{ + InitialInterval: time.Second, + BackoffCoefficient: 2.0, + MaximumAttempts: 3, + }, + }) + + var subtasks []*SwarmTask + err := workflow.ExecuteActivity(ctx1, DecomposeTaskActivity, task).Get(ctx, &subtasks) + if err != nil { + return "", fmt.Errorf("failed to decompose task: %w", err) + } + + // If no subtasks, execute directly + if len(subtasks) == 0 { + var result string + err := workflow.ExecuteActivity(ctx1, ExecuteDirectActivity, task).Get(ctx, &result) + if err != nil { + return "", err + } + return result, nil + } + + // Step 2: Execute subtasks in parallel + futures := make([]workflow.Future, len(subtasks)) + for i, sub := range subtasks { + ctx2 := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ + StartToCloseTimeout: 10 * time.Minute, + HeartbeatTimeout: 30 * time.Second, + RetryPolicy: &temporal.RetryPolicy{ + InitialInterval: 5 * time.Second, + BackoffCoefficient: 2.0, + MaximumAttempts: 3, + }, + }) + futures[i] = workflow.ExecuteActivity(ctx2, ExecuteSubtaskActivity, sub) + } + + // Collect results + results := make([]string, len(futures)) + for i, f := range futures { + var result string + if err := f.Get(ctx, &result); err != nil { + wfLogger.Warn("Subtask failed", "error", err, "index", i) + results[i] = fmt.Sprintf("[FAILED] %v", err) + } else { + results[i] = result + } + } + + // Step 3: Synthesize final result + // Add retry policy for LLM synthesis failures (API errors, rate limits, etc.) 
+ ctx3 := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ + StartToCloseTimeout: 5 * time.Minute, + RetryPolicy: &temporal.RetryPolicy{ + InitialInterval: 2 * time.Second, + BackoffCoefficient: 2.0, + MaximumInterval: 30 * time.Second, + MaximumAttempts: 3, + }, + }) + + var finalResult string + err = workflow.ExecuteActivity(ctx3, SynthesizeResultsActivity, task, results).Get(ctx, &finalResult) + if err != nil { + return "", fmt.Errorf("failed to synthesize results: %w", err) + } + + return finalResult, nil +} + +// Note: Activity implementations are in activities.go as methods on the Activities struct. +// These global functions are retained for Temporal registration but delegate to the Activities struct. + +// Global activity wrappers for Temporal registration +// These are registered by the Temporal worker and dispatch to the Activities struct + +// ActivitiesRegistry holds the global activities instance for Temporal workflow execution +// This is set before starting the Temporal worker and accessed by the activity functions +var ActivitiesRegistry *Activities + +// RegisterActivities registers the activities instance for workflow use +func RegisterActivities(activities *Activities) { + ActivitiesRegistry = activities +} + +// DecomposeTaskActivity is a wrapper for the Activities method +func DecomposeTaskActivity(ctx context.Context, task *SwarmTask) ([]*SwarmTask, error) { + if ActivitiesRegistry == nil { + return nil, fmt.Errorf("activities not initialized") + } + return ActivitiesRegistry.DecomposeTaskActivity(ctx, task) +} + +// ExecuteDirectActivity is a wrapper for the Activities method +func ExecuteDirectActivity(ctx context.Context, task *SwarmTask) (string, error) { + if ActivitiesRegistry == nil { + return "", fmt.Errorf("activities not initialized") + } + return ActivitiesRegistry.ExecuteDirectActivity(ctx, task) +} + +// ExecuteSubtaskActivity is a wrapper for the Activities method +func ExecuteSubtaskActivity(ctx context.Context, task 
*SwarmTask) (string, error) { + if ActivitiesRegistry == nil { + return "", fmt.Errorf("activities not initialized") + } + return ActivitiesRegistry.ExecuteSubtaskActivity(ctx, task) +} + +// SynthesizeResultsActivity is a wrapper for the Activities method +func SynthesizeResultsActivity(ctx context.Context, task *SwarmTask, results []string) (string, error) { + if ActivitiesRegistry == nil { + return "", fmt.Errorf("activities not initialized") + } + return ActivitiesRegistry.SynthesizeResultsActivity(ctx, task, results) +} diff --git a/pkg/tools/base.go b/pkg/tools/base.go index b13174633..770d8cb04 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/base.go @@ -6,8 +6,8 @@ import "context" type Tool interface { Name() string Description() string - Parameters() map[string]interface{} - Execute(ctx context.Context, args map[string]interface{}) *ToolResult + Parameters() map[string]any + Execute(ctx context.Context, args map[string]any) *ToolResult } // ContextualTool is an optional interface that tools can implement @@ -69,10 +69,10 @@ type AsyncTool interface { SetCallback(cb AsyncCallback) } -func ToolToSchema(tool Tool) map[string]interface{} { - return map[string]interface{}{ +func ToolToSchema(tool Tool) map[string]any { + return map[string]any{ "type": "function", - "function": map[string]interface{}{ + "function": map[string]any{ "name": tool.Name(), "description": tool.Description(), "parameters": tool.Parameters(), diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 3f2042e38..562fffc84 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -1,4 +1,4 @@ - package tools +package tools import ( "context" @@ -7,6 +7,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/cron" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -28,12 +29,18 @@ type CronTool struct { } // NewCronTool creates a new CronTool -func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus 
*bus.MessageBus, workspace string) *CronTool { +// execTimeout: 0 means no timeout, >0 sets the timeout duration +func NewCronTool( + cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool, + execTimeout time.Duration, config *config.Config, +) *CronTool { + execTool := NewExecToolWithConfig(workspace, restrict, config) + execTool.SetTimeout(execTimeout) return &CronTool{ cronService: cronService, executor: executor, msgBus: msgBus, - execTool: NewExecTool(workspace, false), + execTool: execTool, } } @@ -48,40 +55,40 @@ func (t *CronTool) Description() string { } // Parameters returns the tool parameters schema -func (t *CronTool) Parameters() map[string]interface{} { - return map[string]interface{}{ +func (t *CronTool) Parameters() map[string]any { + return map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "action": map[string]interface{}{ + "properties": map[string]any{ + "action": map[string]any{ "type": "string", "enum": []string{"add", "list", "remove", "enable", "disable"}, "description": "Action to perform. Use 'add' when user wants to schedule a reminder or task.", }, - "message": map[string]interface{}{ + "message": map[string]any{ "type": "string", "description": "The reminder/task message to display when triggered. If 'command' is used, this describes what the command does.", }, - "command": map[string]interface{}{ + "command": map[string]any{ "type": "string", "description": "Optional: Shell command to execute directly (e.g., 'df -h'). If set, the agent will run this command and report output instead of just showing the message. 'deliver' will be forced to false for commands.", }, - "at_seconds": map[string]interface{}{ + "at_seconds": map[string]any{ "type": "integer", "description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). 
Use this for one-time reminders like 'remind me in 10 minutes'.", }, - "every_seconds": map[string]interface{}{ + "every_seconds": map[string]any{ "type": "integer", "description": "Recurring interval in seconds (e.g., 3600 for every hour). Use this ONLY for recurring tasks like 'every 2 hours' or 'daily reminder'.", }, - "cron_expr": map[string]interface{}{ + "cron_expr": map[string]any{ "type": "string", "description": "Cron expression for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am). Use this for complex recurring schedules.", }, - "job_id": map[string]interface{}{ + "job_id": map[string]any{ "type": "string", "description": "Job ID (for remove/enable/disable)", }, - "deliver": map[string]interface{}{ + "deliver": map[string]any{ "type": "boolean", "description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true", }, @@ -99,7 +106,7 @@ func (t *CronTool) SetContext(channel, chatID string) { } // Execute runs the tool with the given arguments -func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult { action, ok := args["action"].(string) if !ok { return ErrorResult("action is required") @@ -121,7 +128,7 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *To } } -func (t *CronTool) addJob(args map[string]interface{}) *ToolResult { +func (t *CronTool) addJob(args map[string]any) *ToolResult { t.mu.RLock() channel := t.channel chatID := t.chatID @@ -147,8 +154,8 @@ func (t *CronTool) addJob(args map[string]interface{}) *ToolResult { if hasAt { atMS := time.Now().UnixMilli() + int64(atSeconds)*1000 schedule = cron.CronSchedule{ - Kind: "at", - AtMS: &atMS, + Kind: "at", + AtMS: &atMS, } } else if hasEvery { everyMS := int64(everySeconds) * 1000 @@ -194,7 +201,7 @@ func (t *CronTool) addJob(args map[string]interface{}) *ToolResult { if err 
!= nil { return ErrorResult(fmt.Sprintf("Error adding job: %v", err)) } - + if command != "" { job.Payload.Command = command // Need to save the updated payload @@ -229,7 +236,7 @@ func (t *CronTool) listJobs() *ToolResult { return SilentResult(result) } -func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult { +func (t *CronTool) removeJob(args map[string]any) *ToolResult { jobID, ok := args["job_id"].(string) if !ok || jobID == "" { return ErrorResult("job_id is required for remove") @@ -241,7 +248,7 @@ func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult { return ErrorResult(fmt.Sprintf("Job %s not found", jobID)) } -func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult { +func (t *CronTool) enableJob(args map[string]any, enable bool) *ToolResult { jobID, ok := args["job_id"].(string) if !ok || jobID == "" { return ErrorResult("job_id is required for enable/disable") @@ -275,7 +282,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { // Execute command if present if job.Payload.Command != "" { - args := map[string]interface{}{ + args := map[string]any{ "command": job.Payload.Command, } @@ -316,7 +323,6 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { channel, chatID, ) - if err != nil { return fmt.Sprintf("Error: %v", err) } diff --git a/pkg/tools/edit.go b/pkg/tools/edit.go index 1e7c33b45..c28ca6ca2 100644 --- a/pkg/tools/edit.go +++ b/pkg/tools/edit.go @@ -30,19 +30,19 @@ func (t *EditFileTool) Description() string { return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file." 
} -func (t *EditFileTool) Parameters() map[string]interface{} { - return map[string]interface{}{ +func (t *EditFileTool) Parameters() map[string]any { + return map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "path": map[string]interface{}{ + "properties": map[string]any{ + "path": map[string]any{ "type": "string", "description": "The file path to edit", }, - "old_text": map[string]interface{}{ + "old_text": map[string]any{ "type": "string", "description": "The exact text to find and replace", }, - "new_text": map[string]interface{}{ + "new_text": map[string]any{ "type": "string", "description": "The text to replace with", }, @@ -51,7 +51,7 @@ func (t *EditFileTool) Parameters() map[string]interface{} { } } -func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +func (t *EditFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult { path, ok := args["path"].(string) if !ok { return ErrorResult("path is required") @@ -72,7 +72,7 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) return ErrorResult(err.Error()) } - if _, err := os.Stat(resolvedPath); os.IsNotExist(err) { + if _, err = os.Stat(resolvedPath); os.IsNotExist(err) { return ErrorResult(fmt.Sprintf("file not found: %s", path)) } @@ -89,12 +89,14 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) count := strings.Count(contentStr, oldText) if count > 1 { - return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count)) + return ErrorResult( + fmt.Sprintf("old_text appears %d times. 
Please provide more context to make it unique", count), + ) } newContent := strings.Replace(contentStr, oldText, newText, 1) - if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil { + if err := os.WriteFile(resolvedPath, []byte(newContent), 0o644); err != nil { return ErrorResult(fmt.Sprintf("failed to write file: %v", err)) } @@ -118,15 +120,15 @@ func (t *AppendFileTool) Description() string { return "Append content to the end of a file" } -func (t *AppendFileTool) Parameters() map[string]interface{} { - return map[string]interface{}{ +func (t *AppendFileTool) Parameters() map[string]any { + return map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "path": map[string]interface{}{ + "properties": map[string]any{ + "path": map[string]any{ "type": "string", "description": "The file path to append to", }, - "content": map[string]interface{}{ + "content": map[string]any{ "type": "string", "description": "The content to append", }, @@ -135,7 +137,7 @@ func (t *AppendFileTool) Parameters() map[string]interface{} { } } -func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +func (t *AppendFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult { path, ok := args["path"].(string) if !ok { return ErrorResult("path is required") @@ -151,7 +153,7 @@ func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{ return ErrorResult(err.Error()) } - f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { return ErrorResult(fmt.Sprintf("failed to open file: %v", err)) } diff --git a/pkg/tools/edit_test.go b/pkg/tools/edit_test.go index c4c02772d..6780dd9f6 100644 --- a/pkg/tools/edit_test.go +++ b/pkg/tools/edit_test.go @@ -12,11 +12,11 @@ import ( func TestEditTool_EditFile_Success(t *testing.T) { tmpDir := t.TempDir() testFile := 
filepath.Join(tmpDir, "test.txt") - os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644) + os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0o644) tool := NewEditFileTool(tmpDir, true) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "path": testFile, "old_text": "World", "new_text": "Universe", @@ -60,7 +60,7 @@ func TestEditTool_EditFile_NotFound(t *testing.T) { tool := NewEditFileTool(tmpDir, true) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "path": testFile, "old_text": "old", "new_text": "new", @@ -83,11 +83,11 @@ func TestEditTool_EditFile_NotFound(t *testing.T) { func TestEditTool_EditFile_OldTextNotFound(t *testing.T) { tmpDir := t.TempDir() testFile := filepath.Join(tmpDir, "test.txt") - os.WriteFile(testFile, []byte("Hello World"), 0644) + os.WriteFile(testFile, []byte("Hello World"), 0o644) tool := NewEditFileTool(tmpDir, true) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "path": testFile, "old_text": "Goodbye", "new_text": "Hello", @@ -110,11 +110,11 @@ func TestEditTool_EditFile_OldTextNotFound(t *testing.T) { func TestEditTool_EditFile_MultipleMatches(t *testing.T) { tmpDir := t.TempDir() testFile := filepath.Join(tmpDir, "test.txt") - os.WriteFile(testFile, []byte("test test test"), 0644) + os.WriteFile(testFile, []byte("test test test"), 0o644) tool := NewEditFileTool(tmpDir, true) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "path": testFile, "old_text": "test", "new_text": "done", @@ -138,11 +138,11 @@ func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) { tmpDir := t.TempDir() otherDir := t.TempDir() testFile := filepath.Join(otherDir, "test.txt") - os.WriteFile(testFile, []byte("content"), 0644) + os.WriteFile(testFile, []byte("content"), 0o644) tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir ctx := context.Background() - args := 
map[string]interface{}{ + args := map[string]any{ "path": testFile, "old_text": "content", "new_text": "new", @@ -165,7 +165,7 @@ func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) { func TestEditTool_EditFile_MissingPath(t *testing.T) { tool := NewEditFileTool("", false) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "old_text": "old", "new_text": "new", } @@ -182,7 +182,7 @@ func TestEditTool_EditFile_MissingPath(t *testing.T) { func TestEditTool_EditFile_MissingOldText(t *testing.T) { tool := NewEditFileTool("", false) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "path": "/tmp/test.txt", "new_text": "new", } @@ -199,7 +199,7 @@ func TestEditTool_EditFile_MissingOldText(t *testing.T) { func TestEditTool_EditFile_MissingNewText(t *testing.T) { tool := NewEditFileTool("", false) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "path": "/tmp/test.txt", "old_text": "old", } @@ -216,11 +216,11 @@ func TestEditTool_EditFile_MissingNewText(t *testing.T) { func TestEditTool_AppendFile_Success(t *testing.T) { tmpDir := t.TempDir() testFile := filepath.Join(tmpDir, "test.txt") - os.WriteFile(testFile, []byte("Initial content"), 0644) + os.WriteFile(testFile, []byte("Initial content"), 0o644) tool := NewAppendFileTool("", false) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "path": testFile, "content": "\nAppended content", } @@ -260,7 +260,7 @@ func TestEditTool_AppendFile_Success(t *testing.T) { func TestEditTool_AppendFile_MissingPath(t *testing.T) { tool := NewAppendFileTool("", false) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "content": "test", } @@ -276,7 +276,7 @@ func TestEditTool_AppendFile_MissingPath(t *testing.T) { func TestEditTool_AppendFile_MissingContent(t *testing.T) { tool := NewAppendFileTool("", false) ctx := context.Background() - args := 
map[string]interface{}{ + args := map[string]any{ "path": "/tmp/test.txt", } diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 237687734..1bf50906e 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -29,13 +29,56 @@ func validatePath(path, workspace string, restrict bool) (string, error) { } } - if restrict && !strings.HasPrefix(absPath, absWorkspace) { - return "", fmt.Errorf("access denied: path is outside the workspace") + if restrict { + if !isWithinWorkspace(absPath, absWorkspace) { + return "", fmt.Errorf("access denied: path is outside the workspace") + } + + var resolved string + workspaceReal := absWorkspace + if resolved, err = filepath.EvalSymlinks(absWorkspace); err == nil { + workspaceReal = resolved + } + + if resolved, err = filepath.EvalSymlinks(absPath); err == nil { + if !isWithinWorkspace(resolved, workspaceReal) { + return "", fmt.Errorf("access denied: symlink resolves outside workspace") + } + } else if os.IsNotExist(err) { + var parentResolved string + if parentResolved, err = resolveExistingAncestor(filepath.Dir(absPath)); err == nil { + if !isWithinWorkspace(parentResolved, workspaceReal) { + return "", fmt.Errorf("access denied: symlink resolves outside workspace") + } + } else if !os.IsNotExist(err) { + return "", fmt.Errorf("failed to resolve path: %w", err) + } + } else { + return "", fmt.Errorf("failed to resolve path: %w", err) + } } return absPath, nil } +func resolveExistingAncestor(path string) (string, error) { + for current := filepath.Clean(path); ; current = filepath.Dir(current) { + if resolved, err := filepath.EvalSymlinks(current); err == nil { + return resolved, nil + } else if !os.IsNotExist(err) { + return "", err + } + if filepath.Dir(current) == current { + return "", os.ErrNotExist + } + } +} + +func isWithinWorkspace(candidate, workspace string) bool { + rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate)) + return err == nil && rel != ".." 
&& !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) +} + type ReadFileTool struct { workspace string restrict bool @@ -53,11 +96,11 @@ func (t *ReadFileTool) Description() string { return "Read the contents of a file" } -func (t *ReadFileTool) Parameters() map[string]interface{} { - return map[string]interface{}{ +func (t *ReadFileTool) Parameters() map[string]any { + return map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "path": map[string]interface{}{ + "properties": map[string]any{ + "path": map[string]any{ "type": "string", "description": "Path to the file to read", }, @@ -66,7 +109,7 @@ func (t *ReadFileTool) Parameters() map[string]interface{} { } } -func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult { path, ok := args["path"].(string) if !ok { return ErrorResult("path is required") @@ -102,15 +145,15 @@ func (t *WriteFileTool) Description() string { return "Write content to a file" } -func (t *WriteFileTool) Parameters() map[string]interface{} { - return map[string]interface{}{ +func (t *WriteFileTool) Parameters() map[string]any { + return map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "path": map[string]interface{}{ + "properties": map[string]any{ + "path": map[string]any{ "type": "string", "description": "Path to the file to write", }, - "content": map[string]interface{}{ + "content": map[string]any{ "type": "string", "description": "Content to write to the file", }, @@ -119,7 +162,7 @@ func (t *WriteFileTool) Parameters() map[string]interface{} { } } -func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +func (t *WriteFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult { path, ok := args["path"].(string) if !ok { return ErrorResult("path is required") @@ -136,11 +179,11 @@ func (t *WriteFileTool) 
Execute(ctx context.Context, args map[string]interface{} } dir := filepath.Dir(resolvedPath) - if err := os.MkdirAll(dir, 0755); err != nil { + if err := os.MkdirAll(dir, 0o755); err != nil { return ErrorResult(fmt.Sprintf("failed to create directory: %v", err)) } - if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil { + if err := os.WriteFile(resolvedPath, []byte(content), 0o644); err != nil { return ErrorResult(fmt.Sprintf("failed to write file: %v", err)) } @@ -164,11 +207,11 @@ func (t *ListDirTool) Description() string { return "List files and directories in a path" } -func (t *ListDirTool) Parameters() map[string]interface{} { - return map[string]interface{}{ +func (t *ListDirTool) Parameters() map[string]any { + return map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "path": map[string]interface{}{ + "properties": map[string]any{ + "path": map[string]any{ "type": "string", "description": "Path to list", }, @@ -177,7 +220,7 @@ func (t *ListDirTool) Parameters() map[string]interface{} { } } -func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +func (t *ListDirTool) Execute(ctx context.Context, args map[string]any) *ToolResult { path, ok := args["path"].(string) if !ok { path = "." 
diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go index 2707f29b5..5daa3dcea 100644 --- a/pkg/tools/filesystem_test.go +++ b/pkg/tools/filesystem_test.go @@ -12,11 +12,11 @@ import ( func TestFilesystemTool_ReadFile_Success(t *testing.T) { tmpDir := t.TempDir() testFile := filepath.Join(tmpDir, "test.txt") - os.WriteFile(testFile, []byte("test content"), 0644) + os.WriteFile(testFile, []byte("test content"), 0o644) tool := &ReadFileTool{} ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "path": testFile, } @@ -43,7 +43,7 @@ func TestFilesystemTool_ReadFile_Success(t *testing.T) { func TestFilesystemTool_ReadFile_NotFound(t *testing.T) { tool := &ReadFileTool{} ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "path": "/nonexistent_file_12345.txt", } @@ -64,7 +64,7 @@ func TestFilesystemTool_ReadFile_NotFound(t *testing.T) { func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) { tool := &ReadFileTool{} ctx := context.Background() - args := map[string]interface{}{} + args := map[string]any{} result := tool.Execute(ctx, args) @@ -86,7 +86,7 @@ func TestFilesystemTool_WriteFile_Success(t *testing.T) { tool := &WriteFileTool{} ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "path": testFile, "content": "hello world", } @@ -125,7 +125,7 @@ func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) { tool := &WriteFileTool{} ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "path": testFile, "content": "test", } @@ -151,7 +151,7 @@ func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) { func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) { tool := &WriteFileTool{} ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "content": "test", } @@ -167,7 +167,7 @@ func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) { func 
TestFilesystemTool_WriteFile_MissingContent(t *testing.T) { tool := &WriteFileTool{} ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "path": "/tmp/test.txt", } @@ -179,7 +179,8 @@ func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) { } // Should mention required parameter - if !strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") { + if !strings.Contains(result.ForLLM, "content is required") && + !strings.Contains(result.ForUser, "content is required") { t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM) } } @@ -187,13 +188,13 @@ func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) { // TestFilesystemTool_ListDir_Success verifies successful directory listing func TestFilesystemTool_ListDir_Success(t *testing.T) { tmpDir := t.TempDir() - os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644) - os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644) - os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755) + os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0o644) + os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0o644) + os.Mkdir(filepath.Join(tmpDir, "subdir"), 0o755) tool := &ListDirTool{} ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "path": tmpDir, } @@ -217,7 +218,7 @@ func TestFilesystemTool_ListDir_Success(t *testing.T) { func TestFilesystemTool_ListDir_NotFound(t *testing.T) { tool := &ListDirTool{} ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "path": "/nonexistent_directory_12345", } @@ -238,7 +239,7 @@ func TestFilesystemTool_ListDir_NotFound(t *testing.T) { func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) { tool := &ListDirTool{} ctx := context.Background() - args := map[string]interface{}{} + args := map[string]any{} result := tool.Execute(ctx, 
args) @@ -247,3 +248,34 @@ func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) { t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM) } } + +// Block paths that look inside workspace but point outside via symlink. +func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) { + root := t.TempDir() + workspace := filepath.Join(root, "workspace") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatalf("failed to create workspace: %v", err) + } + + secret := filepath.Join(root, "secret.txt") + if err := os.WriteFile(secret, []byte("top secret"), 0o644); err != nil { + t.Fatalf("failed to write secret file: %v", err) + } + + link := filepath.Join(workspace, "leak.txt") + if err := os.Symlink(secret, link); err != nil { + t.Skipf("symlink not supported in this environment: %v", err) + } + + tool := NewReadFileTool(workspace, true) + result := tool.Execute(context.Background(), map[string]any{ + "path": link, + }) + + if !result.IsError { + t.Fatalf("expected symlink escape to be blocked") + } + if !strings.Contains(result.ForLLM, "symlink resolves outside workspace") { + t.Fatalf("expected symlink escape error, got: %s", result.ForLLM) + } +} diff --git a/pkg/tools/i2c.go b/pkg/tools/i2c.go new file mode 100644 index 000000000..0387a26d3 --- /dev/null +++ b/pkg/tools/i2c.go @@ -0,0 +1,149 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "path/filepath" + "regexp" + "runtime" +) + +// I2CTool provides I2C bus interaction for reading sensors and controlling peripherals. +type I2CTool struct{} + +func NewI2CTool() *I2CTool { + return &I2CTool{} +} + +func (t *I2CTool) Name() string { + return "i2c" +} + +func (t *I2CTool) Description() string { + return "Interact with I2C bus devices for reading sensors and controlling peripherals. Actions: detect (list buses), scan (find devices on a bus), read (read bytes from device), write (send bytes to device). Linux only." 
+} + +func (t *I2CTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "action": map[string]any{ + "type": "string", + "enum": []string{"detect", "scan", "read", "write"}, + "description": "Action to perform: detect (list available I2C buses), scan (find devices on a bus), read (read bytes from a device), write (send bytes to a device)", + }, + "bus": map[string]any{ + "type": "string", + "description": "I2C bus number (e.g. \"1\" for /dev/i2c-1). Required for scan/read/write.", + }, + "address": map[string]any{ + "type": "integer", + "description": "7-bit I2C device address (0x03-0x77). Required for read/write.", + }, + "register": map[string]any{ + "type": "integer", + "description": "Register address to read from or write to. If set, sends register byte before read/write.", + }, + "data": map[string]any{ + "type": "array", + "items": map[string]any{"type": "integer"}, + "description": "Bytes to write (0-255 each). Required for write action.", + }, + "length": map[string]any{ + "type": "integer", + "description": "Number of bytes to read (1-256). Default: 1. Used with read action.", + }, + "confirm": map[string]any{ + "type": "boolean", + "description": "Must be true for write operations. Safety guard to prevent accidental writes.", + }, + }, + "required": []string{"action"}, + } +} + +func (t *I2CTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + if runtime.GOOS != "linux" { + return ErrorResult("I2C is only supported on Linux. 
This tool requires /dev/i2c-* device files.") + } + + action, ok := args["action"].(string) + if !ok { + return ErrorResult("action is required") + } + + switch action { + case "detect": + return t.detect() + case "scan": + return t.scan(args) + case "read": + return t.readDevice(args) + case "write": + return t.writeDevice(args) + default: + return ErrorResult(fmt.Sprintf("unknown action: %s (valid: detect, scan, read, write)", action)) + } +} + +// detect lists available I2C buses by globbing /dev/i2c-* +func (t *I2CTool) detect() *ToolResult { + matches, err := filepath.Glob("/dev/i2c-*") + if err != nil { + return ErrorResult(fmt.Sprintf("failed to scan for I2C buses: %v", err)) + } + + if len(matches) == 0 { + return SilentResult( + "No I2C buses found. You may need to:\n1. Load the i2c-dev module: modprobe i2c-dev\n2. Check that I2C is enabled in device tree\n3. Configure pinmux for your board (see hardware skill)", + ) + } + + type busInfo struct { + Path string `json:"path"` + Bus string `json:"bus"` + } + + buses := make([]busInfo, 0, len(matches)) + re := regexp.MustCompile(`/dev/i2c-(\d+)`) + for _, m := range matches { + if sub := re.FindStringSubmatch(m); sub != nil { + buses = append(buses, busInfo{Path: m, Bus: sub[1]}) + } + } + + result, _ := json.MarshalIndent(buses, "", " ") + return SilentResult(fmt.Sprintf("Found %d I2C bus(es):\n%s", len(buses), string(result))) +} + +// isValidBusID checks that a bus identifier is a simple number (prevents path injection) +func isValidBusID(id string) bool { + matched, _ := regexp.MatchString(`^\d+$`, id) + return matched +} + +// parseI2CAddress extracts and validates an I2C address from args +func parseI2CAddress(args map[string]any) (int, *ToolResult) { + addrFloat, ok := args["address"].(float64) + if !ok { + return 0, ErrorResult("address is required (e.g. 
0x38 for AHT20)") + } + addr := int(addrFloat) + if addr < 0x03 || addr > 0x77 { + return 0, ErrorResult("address must be in valid 7-bit range (0x03-0x77)") + } + return addr, nil +} + +// parseI2CBus extracts and validates an I2C bus from args +func parseI2CBus(args map[string]any) (string, *ToolResult) { + bus, ok := args["bus"].(string) + if !ok || bus == "" { + return "", ErrorResult("bus is required (e.g. \"1\" for /dev/i2c-1)") + } + if !isValidBusID(bus) { + return "", ErrorResult("invalid bus identifier: must be a number (e.g. \"1\")") + } + return bus, nil +} diff --git a/pkg/tools/i2c_linux.go b/pkg/tools/i2c_linux.go new file mode 100644 index 000000000..4eaaf8f09 --- /dev/null +++ b/pkg/tools/i2c_linux.go @@ -0,0 +1,286 @@ +package tools + +import ( + "encoding/json" + "fmt" + "syscall" + "unsafe" +) + +// I2C ioctl constants from Linux kernel headers (, ) +const ( + i2cSlave = 0x0703 // Set slave address (fails if in use by driver) + i2cFuncs = 0x0705 // Query adapter functionality bitmask + i2cSmbus = 0x0720 // Perform SMBus transaction + + // I2C_FUNC capability bits + i2cFuncSmbusQuick = 0x00010000 + i2cFuncSmbusReadByte = 0x00020000 + + // SMBus transaction types + i2cSmbusRead = 0 + i2cSmbusWrite = 1 + + // SMBus protocol sizes + i2cSmbusQuick = 0 + i2cSmbusByte = 1 +) + +// i2cSmbusData matches the kernel union i2c_smbus_data (34 bytes max). +// For quick and byte transactions only the first byte is used (if at all). +type i2cSmbusData [34]byte + +// i2cSmbusArgs matches the kernel struct i2c_smbus_ioctl_data. +type i2cSmbusArgs struct { + readWrite uint8 + command uint8 + size uint32 + data *i2cSmbusData +} + +// smbusProbe performs a single SMBus probe at the given address. +// Uses SMBus Quick Write (safest) or falls back to SMBus Read Byte for +// EEPROM address ranges where quick write can corrupt AT24RF08 chips. +// This matches i2cdetect's MODE_AUTO behavior. 
+func smbusProbe(fd int, addr int, hasQuick bool) bool { + // EEPROM ranges: use read byte (quick write can corrupt AT24RF08) + useReadByte := (addr >= 0x30 && addr <= 0x37) || (addr >= 0x50 && addr <= 0x5F) + + if !useReadByte && hasQuick { + // SMBus Quick Write: [START] [ADDR|W] [ACK/NACK] [STOP] + // Safest probe — no data transferred + args := i2cSmbusArgs{ + readWrite: i2cSmbusWrite, + command: 0, + size: i2cSmbusQuick, + data: nil, + } + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSmbus, uintptr(unsafe.Pointer(&args))) + return errno == 0 + } + + // SMBus Read Byte: [START] [ADDR|R] [ACK/NACK] [DATA] [STOP] + var data i2cSmbusData + args := i2cSmbusArgs{ + readWrite: i2cSmbusRead, + command: 0, + size: i2cSmbusByte, + data: &data, + } + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSmbus, uintptr(unsafe.Pointer(&args))) + return errno == 0 +} + +// scan probes valid 7-bit addresses on a bus for connected devices. +// Uses the same hybrid probe strategy as i2cdetect's MODE_AUTO: +// SMBus Quick Write for most addresses, SMBus Read Byte for EEPROM ranges. +func (t *I2CTool) scan(args map[string]any) *ToolResult { + bus, errResult := parseI2CBus(args) + if errResult != nil { + return errResult + } + + devPath := fmt.Sprintf("/dev/i2c-%s", bus) + fd, err := syscall.Open(devPath, syscall.O_RDWR, 0) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and i2c-dev module)", devPath, err)) + } + defer syscall.Close(fd) + + // Query adapter capabilities to determine available probe methods. + // I2C_FUNCS writes an unsigned long, which is word-sized on Linux. 
+ var funcs uintptr + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cFuncs, uintptr(unsafe.Pointer(&funcs))) + if errno != 0 { + return ErrorResult(fmt.Sprintf("failed to query I2C adapter capabilities on %s: %v", devPath, errno)) + } + + hasQuick := funcs&i2cFuncSmbusQuick != 0 + hasReadByte := funcs&i2cFuncSmbusReadByte != 0 + + if !hasQuick && !hasReadByte { + return ErrorResult( + fmt.Sprintf("I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely", devPath), + ) + } + + type deviceEntry struct { + Address string `json:"address"` + Status string `json:"status,omitempty"` + } + + var found []deviceEntry + // Scan 0x08-0x77, skipping I2C reserved addresses 0x00-0x07 + for addr := 0x08; addr <= 0x77; addr++ { + // Set slave address — EBUSY means a kernel driver owns this address + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSlave, uintptr(addr)) + if errno != 0 { + if errno == syscall.EBUSY { + found = append(found, deviceEntry{ + Address: fmt.Sprintf("0x%02x", addr), + Status: "busy (in use by kernel driver)", + }) + } + continue + } + + if smbusProbe(fd, addr, hasQuick) { + found = append(found, deviceEntry{ + Address: fmt.Sprintf("0x%02x", addr), + }) + } + } + + if len(found) == 0 { + return SilentResult(fmt.Sprintf("No devices found on %s. 
Check wiring and pull-up resistors.", devPath)) + } + + result, _ := json.MarshalIndent(map[string]any{ + "bus": devPath, + "devices": found, + "count": len(found), + }, "", " ") + return SilentResult(fmt.Sprintf("Scan of %s:\n%s", devPath, string(result))) +} + +// readDevice reads bytes from an I2C device, optionally at a specific register +func (t *I2CTool) readDevice(args map[string]any) *ToolResult { + bus, errResult := parseI2CBus(args) + if errResult != nil { + return errResult + } + + addr, errResult := parseI2CAddress(args) + if errResult != nil { + return errResult + } + + length := 1 + if l, ok := args["length"].(float64); ok { + length = int(l) + } + if length < 1 || length > 256 { + return ErrorResult("length must be between 1 and 256") + } + + devPath := fmt.Sprintf("/dev/i2c-%s", bus) + fd, err := syscall.Open(devPath, syscall.O_RDWR, 0) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to open %s: %v", devPath, err)) + } + defer syscall.Close(fd) + + // Set slave address + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSlave, uintptr(addr)) + if errno != 0 { + return ErrorResult(fmt.Sprintf("failed to set I2C address 0x%02x: %v", addr, errno)) + } + + // If register is specified, write it first + if regFloat, ok := args["register"].(float64); ok { + reg := int(regFloat) + if reg < 0 || reg > 255 { + return ErrorResult("register must be between 0x00 and 0xFF") + } + _, err = syscall.Write(fd, []byte{byte(reg)}) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to write register 0x%02x: %v", reg, err)) + } + } + + // Read data + buf := make([]byte, length) + n, err := syscall.Read(fd, buf) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to read from device 0x%02x: %v", addr, err)) + } + + // Format as hex bytes + hexBytes := make([]string, n) + intBytes := make([]int, n) + for i := 0; i < n; i++ { + hexBytes[i] = fmt.Sprintf("0x%02x", buf[i]) + intBytes[i] = int(buf[i]) + } + + result, _ := 
json.MarshalIndent(map[string]any{ + "bus": devPath, + "address": fmt.Sprintf("0x%02x", addr), + "bytes": intBytes, + "hex": hexBytes, + "length": n, + }, "", " ") + return SilentResult(string(result)) +} + +// writeDevice writes bytes to an I2C device, optionally at a specific register +func (t *I2CTool) writeDevice(args map[string]any) *ToolResult { + confirm, _ := args["confirm"].(bool) + if !confirm { + return ErrorResult( + "write operations require confirm: true. Please confirm with the user before writing to I2C devices, as incorrect writes can misconfigure hardware.", + ) + } + + bus, errResult := parseI2CBus(args) + if errResult != nil { + return errResult + } + + addr, errResult := parseI2CAddress(args) + if errResult != nil { + return errResult + } + + dataRaw, ok := args["data"].([]any) + if !ok || len(dataRaw) == 0 { + return ErrorResult("data is required for write (array of byte values 0-255)") + } + if len(dataRaw) > 256 { + return ErrorResult("data too long: maximum 256 bytes per I2C transaction") + } + + data := make([]byte, 0, len(dataRaw)+1) + + // If register is specified, prepend it to the data + if regFloat, ok := args["register"].(float64); ok { + reg := int(regFloat) + if reg < 0 || reg > 255 { + return ErrorResult("register must be between 0x00 and 0xFF") + } + data = append(data, byte(reg)) + } + + for i, v := range dataRaw { + f, ok := v.(float64) + if !ok { + return ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i)) + } + b := int(f) + if b < 0 || b > 255 { + return ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b)) + } + data = append(data, byte(b)) + } + + devPath := fmt.Sprintf("/dev/i2c-%s", bus) + fd, err := syscall.Open(devPath, syscall.O_RDWR, 0) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to open %s: %v", devPath, err)) + } + defer syscall.Close(fd) + + // Set slave address + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSlave, uintptr(addr)) + if errno 
!= 0 { + return ErrorResult(fmt.Sprintf("failed to set I2C address 0x%02x: %v", addr, errno)) + } + + // Write data + n, err := syscall.Write(fd, data) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to write to device 0x%02x: %v", addr, err)) + } + + return SilentResult(fmt.Sprintf("Wrote %d byte(s) to device 0x%02x on %s", n, addr, devPath)) +} diff --git a/pkg/tools/i2c_other.go b/pkg/tools/i2c_other.go new file mode 100644 index 000000000..7becf8339 --- /dev/null +++ b/pkg/tools/i2c_other.go @@ -0,0 +1,18 @@ +//go:build !linux + +package tools + +// scan is a stub for non-Linux platforms. +func (t *I2CTool) scan(args map[string]any) *ToolResult { + return ErrorResult("I2C is only supported on Linux") +} + +// readDevice is a stub for non-Linux platforms. +func (t *I2CTool) readDevice(args map[string]any) *ToolResult { + return ErrorResult("I2C is only supported on Linux") +} + +// writeDevice is a stub for non-Linux platforms. +func (t *I2CTool) writeDevice(args map[string]any) *ToolResult { + return ErrorResult("I2C is only supported on Linux") +} diff --git a/pkg/tools/message.go b/pkg/tools/message.go index 9c803bacf..15ef4ff73 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -11,6 +11,7 @@ type MessageTool struct { sendCallback SendCallback defaultChannel string defaultChatID string + sentInRound bool // Tracks whether a message was sent in the current processing round } func NewMessageTool() *MessageTool { @@ -25,19 +26,19 @@ func (t *MessageTool) Description() string { return "Send a message to user on a chat channel. Use this when you want to communicate something." 
} -func (t *MessageTool) Parameters() map[string]interface{} { - return map[string]interface{}{ +func (t *MessageTool) Parameters() map[string]any { + return map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "content": map[string]interface{}{ + "properties": map[string]any{ + "content": map[string]any{ "type": "string", "description": "The message content to send", }, - "channel": map[string]interface{}{ + "channel": map[string]any{ "type": "string", "description": "Optional: target channel (telegram, whatsapp, etc.)", }, - "chat_id": map[string]interface{}{ + "chat_id": map[string]any{ "type": "string", "description": "Optional: target chat/user ID", }, @@ -49,13 +50,19 @@ func (t *MessageTool) Parameters() map[string]interface{} { func (t *MessageTool) SetContext(channel, chatID string) { t.defaultChannel = channel t.defaultChatID = chatID + t.sentInRound = false // Reset send tracking for new processing round +} + +// HasSentInRound returns true if the message tool sent a message during the current round. 
+func (t *MessageTool) HasSentInRound() bool { + return t.sentInRound } func (t *MessageTool) SetSendCallback(callback SendCallback) { t.sendCallback = callback } -func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolResult { content, ok := args["content"].(string) if !ok { return &ToolResult{ForLLM: "content is required", IsError: true} @@ -87,9 +94,10 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) } } + t.sentInRound = true // Silent: user already received the message directly return &ToolResult{ ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID), - Silent: true, + Silent: true, } } diff --git a/pkg/tools/message_test.go b/pkg/tools/message_test.go index 4bedbe79b..717c1117b 100644 --- a/pkg/tools/message_test.go +++ b/pkg/tools/message_test.go @@ -19,7 +19,7 @@ func TestMessageTool_Execute_Success(t *testing.T) { }) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "content": "Hello, world!", } @@ -70,7 +70,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { }) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "content": "Test message", "channel": "custom-channel", "chat_id": "custom-chat-id", @@ -104,7 +104,7 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) { }) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "content": "Test message", } @@ -136,7 +136,7 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) { tool.SetContext("test-channel", "test-chat-id") ctx := context.Background() - args := map[string]interface{}{} // content missing + args := map[string]any{} // content missing result := tool.Execute(ctx, args) @@ -158,7 +158,7 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) { }) ctx := context.Background() - args := 
map[string]interface{}{ + args := map[string]any{ "content": "Test message", } @@ -179,7 +179,7 @@ func TestMessageTool_Execute_NotConfigured(t *testing.T) { // No SetSendCallback called ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "content": "Test message", } @@ -219,7 +219,7 @@ func TestMessageTool_Parameters(t *testing.T) { t.Error("Expected type 'object'") } - props, ok := params["properties"].(map[string]interface{}) + props, ok := params["properties"].(map[string]any) if !ok { t.Fatal("Expected properties to be a map") } @@ -231,7 +231,7 @@ func TestMessageTool_Parameters(t *testing.T) { } // Check content property - contentProp, ok := props["content"].(map[string]interface{}) + contentProp, ok := props["content"].(map[string]any) if !ok { t.Error("Expected 'content' property") } @@ -240,7 +240,7 @@ func TestMessageTool_Parameters(t *testing.T) { } // Check channel property (optional) - channelProp, ok := props["channel"].(map[string]interface{}) + channelProp, ok := props["channel"].(map[string]any) if !ok { t.Error("Expected 'channel' property") } @@ -249,7 +249,7 @@ func TestMessageTool_Parameters(t *testing.T) { } // Check chat_id property (optional) - chatIDProp, ok := props["chat_id"].(map[string]interface{}) + chatIDProp, ok := props["chat_id"].(map[string]any) if !ok { t.Error("Expected 'chat_id' property") } diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index c8cf92863..6ecb8ae7c 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -34,16 +34,22 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) { return tool, ok } -func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult { +func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]any) *ToolResult { return r.ExecuteWithContext(ctx, name, args, "", "", nil) } // ExecuteWithContext executes a tool with channel/chatID context and optional async callback. 
// If the tool implements AsyncTool and a non-nil callback is provided, // the callback will be set on the tool before execution. -func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult { +func (r *ToolRegistry) ExecuteWithContext( + ctx context.Context, + name string, + args map[string]any, + channel, chatID string, + asyncCallback AsyncCallback, +) *ToolResult { logger.InfoCF("tool", "Tool execution started", - map[string]interface{}{ + map[string]any{ "tool": name, "args": args, }) @@ -51,7 +57,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args tool, ok := r.Get(name) if !ok { logger.ErrorCF("tool", "Tool not found", - map[string]interface{}{ + map[string]any{ "tool": name, }) return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found")) @@ -66,7 +72,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil { asyncTool.SetCallback(asyncCallback) logger.DebugCF("tool", "Async callback injected", - map[string]interface{}{ + map[string]any{ "tool": name, }) } @@ -78,20 +84,20 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args // Log based on result type if result.IsError { logger.ErrorCF("tool", "Tool execution failed", - map[string]interface{}{ + map[string]any{ "tool": name, "duration": duration.Milliseconds(), "error": result.ForLLM, }) } else if result.Async { logger.InfoCF("tool", "Tool started (async)", - map[string]interface{}{ + map[string]any{ "tool": name, "duration": duration.Milliseconds(), }) } else { logger.InfoCF("tool", "Tool execution completed", - map[string]interface{}{ + map[string]any{ "tool": name, "duration_ms": duration.Milliseconds(), "result_length": len(result.ForLLM), @@ -101,11 +107,11 @@ func (r *ToolRegistry) 
ExecuteWithContext(ctx context.Context, name string, args return result } -func (r *ToolRegistry) GetDefinitions() []map[string]interface{} { +func (r *ToolRegistry) GetDefinitions() []map[string]any { r.mu.RLock() defer r.mu.RUnlock() - definitions := make([]map[string]interface{}, 0, len(r.tools)) + definitions := make([]map[string]any, 0, len(r.tools)) for _, tool := range r.tools { definitions = append(definitions, ToolToSchema(tool)) } @@ -123,14 +129,14 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition { schema := ToolToSchema(tool) // Safely extract nested values with type checks - fn, ok := schema["function"].(map[string]interface{}) + fn, ok := schema["function"].(map[string]any) if !ok { continue } name, _ := fn["name"].(string) desc, _ := fn["description"].(string) - params, _ := fn["parameters"].(map[string]interface{}) + params, _ := fn["parameters"].(map[string]any) definitions = append(definitions, providers.ToolDefinition{ Type: "function", diff --git a/pkg/tools/result_test.go b/pkg/tools/result_test.go index bc798cd70..a234e33f3 100644 --- a/pkg/tools/result_test.go +++ b/pkg/tools/result_test.go @@ -192,7 +192,7 @@ func TestToolResultJSONStructure(t *testing.T) { } // Verify JSON structure - var parsed map[string]interface{} + var parsed map[string]any if err := json.Unmarshal(data, &parsed); err != nil { t.Fatalf("Failed to parse JSON: %v", err) } diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index d35219240..a1ee0b6e1 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -3,6 +3,7 @@ package tools import ( "bytes" "context" + "errors" "fmt" "os" "os/exec" @@ -11,8 +12,9 @@ import ( "runtime" "strings" "time" -) + "github.com/sipeed/picoclaw/pkg/config" +) type ExecTool struct { workingDir string @@ -22,16 +24,82 @@ type ExecTool struct { restrictToWorkspace bool } +var defaultDenyPatterns = []*regexp.Regexp{ + regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`), + regexp.MustCompile(`\bdel\s+/[fq]\b`), + 
regexp.MustCompile(`\brmdir\s+/s\b`), + regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args) + regexp.MustCompile(`\bdd\s+if=`), + regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null) + regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`), + regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`), + regexp.MustCompile(`\$\([^)]+\)`), + regexp.MustCompile(`\$\{[^}]+\}`), + regexp.MustCompile("`[^`]+`"), + regexp.MustCompile(`\|\s*sh\b`), + regexp.MustCompile(`\|\s*bash\b`), + regexp.MustCompile(`;\s*rm\s+-[rf]`), + regexp.MustCompile(`&&\s*rm\s+-[rf]`), + regexp.MustCompile(`\|\|\s*rm\s+-[rf]`), + regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`), + regexp.MustCompile(`<<\s*EOF`), + regexp.MustCompile(`\$\(\s*cat\s+`), + regexp.MustCompile(`\$\(\s*curl\s+`), + regexp.MustCompile(`\$\(\s*wget\s+`), + regexp.MustCompile(`\$\(\s*which\s+`), + regexp.MustCompile(`\bsudo\b`), + regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`), + regexp.MustCompile(`\bchown\b`), + regexp.MustCompile(`\bpkill\b`), + regexp.MustCompile(`\bkillall\b`), + regexp.MustCompile(`\bkill\s+-[9]\b`), + regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`), + regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`), + regexp.MustCompile(`\bnpm\s+install\s+-g\b`), + regexp.MustCompile(`\bpip\s+install\s+--user\b`), + regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`), + regexp.MustCompile(`\byum\s+(install|remove)\b`), + regexp.MustCompile(`\bdnf\s+(install|remove)\b`), + regexp.MustCompile(`\bdocker\s+run\b`), + regexp.MustCompile(`\bdocker\s+exec\b`), + regexp.MustCompile(`\bgit\s+push\b`), + regexp.MustCompile(`\bgit\s+force\b`), + regexp.MustCompile(`\bssh\b.*@`), + regexp.MustCompile(`\beval\b`), + regexp.MustCompile(`\bsource\s+.*\.sh\b`), +} + func NewExecTool(workingDir string, restrict bool) *ExecTool { - denyPatterns := []*regexp.Regexp{ - regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`), - 
regexp.MustCompile(`\bdel\s+/[fq]\b`), - regexp.MustCompile(`\brmdir\s+/s\b`), - regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args) - regexp.MustCompile(`\bdd\s+if=`), - regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null) - regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`), - regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`), + return NewExecToolWithConfig(workingDir, restrict, nil) +} + +func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) *ExecTool { + denyPatterns := make([]*regexp.Regexp, 0) + + enableDenyPatterns := true + if config != nil { + execConfig := config.Tools.Exec + enableDenyPatterns = execConfig.EnableDenyPatterns + if enableDenyPatterns { + if len(execConfig.CustomDenyPatterns) > 0 { + fmt.Printf("Using custom deny patterns: %v\n", execConfig.CustomDenyPatterns) + for _, pattern := range execConfig.CustomDenyPatterns { + re, err := regexp.Compile(pattern) + if err != nil { + fmt.Printf("Invalid custom deny pattern %q: %v\n", pattern, err) + continue + } + denyPatterns = append(denyPatterns, re) + } + } else { + denyPatterns = append(denyPatterns, defaultDenyPatterns...) + } + } else { + // If deny patterns are disabled, we won't add any patterns, allowing all commands. + fmt.Println("Warning: deny patterns are disabled. All commands will be allowed.") + } + } else { + denyPatterns = append(denyPatterns, defaultDenyPatterns...) } return &ExecTool{ @@ -51,15 +119,15 @@ func (t *ExecTool) Description() string { return "Execute a shell command and return its output. Use with caution." 
} -func (t *ExecTool) Parameters() map[string]interface{} { - return map[string]interface{}{ +func (t *ExecTool) Parameters() map[string]any { + return map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "command": map[string]interface{}{ + "properties": map[string]any{ + "command": map[string]any{ "type": "string", "description": "The shell command to execute", }, - "working_dir": map[string]interface{}{ + "working_dir": map[string]any{ "type": "string", "description": "Optional working directory for the command", }, @@ -68,7 +136,7 @@ func (t *ExecTool) Parameters() map[string]interface{} { } } -func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult { command, ok := args["command"].(string) if !ok { return ErrorResult("command is required") @@ -76,7 +144,15 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To cwd := t.workingDir if wd, ok := args["working_dir"].(string); ok && wd != "" { - cwd = wd + if t.restrictToWorkspace && t.workingDir != "" { + resolvedWD, err := validatePath(wd, t.workingDir, true) + if err != nil { + return ErrorResult("Command blocked by safety guard (" + err.Error() + ")") + } + cwd = resolvedWD + } else { + cwd = wd + } } if cwd == "" { @@ -90,7 +166,14 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To return ErrorResult(guardError) } - cmdCtx, cancel := context.WithTimeout(ctx, t.timeout) + // timeout == 0 means no timeout + var cmdCtx context.Context + var cancel context.CancelFunc + if t.timeout > 0 { + cmdCtx, cancel = context.WithTimeout(ctx, t.timeout) + } else { + cmdCtx, cancel = context.WithCancel(ctx) + } defer cancel() var cmd *exec.Cmd @@ -103,18 +186,43 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To cmd.Dir = cwd } + prepareCommandForTermination(cmd) + var stdout, stderr bytes.Buffer 
cmd.Stdout = &stdout cmd.Stderr = &stderr - err := cmd.Run() + if err := cmd.Start(); err != nil { + return ErrorResult(fmt.Sprintf("failed to start command: %v", err)) + } + + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + var err error + select { + case err = <-done: + case <-cmdCtx.Done(): + _ = terminateProcessTree(cmd) + select { + case err = <-done: + case <-time.After(2 * time.Second): + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + err = <-done + } + } + output := stdout.String() if stderr.Len() > 0 { output += "\nSTDERR:\n" + stderr.String() } if err != nil { - if cmdCtx.Err() == context.DeadlineExceeded { + if errors.Is(cmdCtx.Err(), context.DeadlineExceeded) { msg := fmt.Sprintf("Command timed out after %v", t.timeout) return &ToolResult{ ForLLM: msg, diff --git a/pkg/tools/shell_process_unix.go b/pkg/tools/shell_process_unix.go new file mode 100644 index 000000000..7b29a81bf --- /dev/null +++ b/pkg/tools/shell_process_unix.go @@ -0,0 +1,32 @@ +//go:build !windows + +package tools + +import ( + "os/exec" + "syscall" +) + +func prepareCommandForTermination(cmd *exec.Cmd) { + if cmd == nil { + return + } + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} +} + +func terminateProcessTree(cmd *exec.Cmd) error { + if cmd == nil || cmd.Process == nil { + return nil + } + + pid := cmd.Process.Pid + if pid <= 0 { + return nil + } + + // Kill the entire process group spawned by the shell command. + _ = syscall.Kill(-pid, syscall.SIGKILL) + // Fallback kill on the shell process itself. 
+ _ = cmd.Process.Kill() + return nil +} diff --git a/pkg/tools/shell_process_windows.go b/pkg/tools/shell_process_windows.go new file mode 100644 index 000000000..fe23b5c96 --- /dev/null +++ b/pkg/tools/shell_process_windows.go @@ -0,0 +1,27 @@ +//go:build windows + +package tools + +import ( + "os/exec" + "strconv" +) + +func prepareCommandForTermination(cmd *exec.Cmd) { + // no-op on Windows +} + +func terminateProcessTree(cmd *exec.Cmd) error { + if cmd == nil || cmd.Process == nil { + return nil + } + + pid := cmd.Process.Pid + if pid <= 0 { + return nil + } + + _ = exec.Command("taskkill", "/T", "/F", "/PID", strconv.Itoa(pid)).Run() + _ = cmd.Process.Kill() + return nil +} diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go index c06468a39..6d35815e8 100644 --- a/pkg/tools/shell_test.go +++ b/pkg/tools/shell_test.go @@ -14,7 +14,7 @@ func TestShellTool_Success(t *testing.T) { tool := NewExecTool("", false) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "command": "echo 'hello world'", } @@ -41,7 +41,7 @@ func TestShellTool_Failure(t *testing.T) { tool := NewExecTool("", false) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "command": "ls /nonexistent_directory_12345", } @@ -69,7 +69,7 @@ func TestShellTool_Timeout(t *testing.T) { tool.SetTimeout(100 * time.Millisecond) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "command": "sleep 10", } @@ -91,12 +91,12 @@ func TestShellTool_WorkingDir(t *testing.T) { // Create temp directory tmpDir := t.TempDir() testFile := filepath.Join(tmpDir, "test.txt") - os.WriteFile(testFile, []byte("test content"), 0644) + os.WriteFile(testFile, []byte("test content"), 0o644) tool := NewExecTool("", false) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "command": "cat test.txt", "working_dir": tmpDir, } @@ -117,7 +117,7 @@ func 
TestShellTool_DangerousCommand(t *testing.T) { tool := NewExecTool("", false) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "command": "rm -rf /", } @@ -138,7 +138,7 @@ func TestShellTool_MissingCommand(t *testing.T) { tool := NewExecTool("", false) ctx := context.Background() - args := map[string]interface{}{} + args := map[string]any{} result := tool.Execute(ctx, args) @@ -153,7 +153,7 @@ func TestShellTool_StderrCapture(t *testing.T) { tool := NewExecTool("", false) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "command": "sh -c 'echo stdout; echo stderr >&2'", } @@ -174,7 +174,7 @@ func TestShellTool_OutputTruncation(t *testing.T) { ctx := context.Background() // Generate long output (>10000 chars) - args := map[string]interface{}{ + args := map[string]any{ "command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000), } @@ -186,6 +186,66 @@ func TestShellTool_OutputTruncation(t *testing.T) { } } +// TestShellTool_WorkingDir_OutsideWorkspace verifies that working_dir cannot escape the workspace directly +func TestShellTool_WorkingDir_OutsideWorkspace(t *testing.T) { + root := t.TempDir() + workspace := filepath.Join(root, "workspace") + outsideDir := filepath.Join(root, "outside") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatalf("failed to create workspace: %v", err) + } + if err := os.MkdirAll(outsideDir, 0o755); err != nil { + t.Fatalf("failed to create outside dir: %v", err) + } + + tool := NewExecTool(workspace, true) + result := tool.Execute(context.Background(), map[string]any{ + "command": "pwd", + "working_dir": outsideDir, + }) + + if !result.IsError { + t.Fatalf("expected working_dir outside workspace to be blocked, got output: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "blocked") { + t.Errorf("expected 'blocked' in error, got: %s", result.ForLLM) + } +} + +// TestShellTool_WorkingDir_SymlinkEscape verifies 
that a symlink inside the workspace +// pointing outside cannot be used as working_dir to escape the sandbox. +func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) { + root := t.TempDir() + workspace := filepath.Join(root, "workspace") + secretDir := filepath.Join(root, "secret") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatalf("failed to create workspace: %v", err) + } + if err := os.MkdirAll(secretDir, 0o755); err != nil { + t.Fatalf("failed to create secret dir: %v", err) + } + os.WriteFile(filepath.Join(secretDir, "secret.txt"), []byte("top secret"), 0o644) + + // symlink lives inside the workspace but resolves to secretDir outside it + link := filepath.Join(workspace, "escape") + if err := os.Symlink(secretDir, link); err != nil { + t.Skipf("symlinks not supported in this environment: %v", err) + } + + tool := NewExecTool(workspace, true) + result := tool.Execute(context.Background(), map[string]any{ + "command": "cat secret.txt", + "working_dir": link, + }) + + if !result.IsError { + t.Fatalf("expected symlink working_dir escape to be blocked, got output: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "blocked") { + t.Errorf("expected 'blocked' in error, got: %s", result.ForLLM) + } +} + // TestShellTool_RestrictToWorkspace verifies workspace restriction func TestShellTool_RestrictToWorkspace(t *testing.T) { tmpDir := t.TempDir() @@ -193,7 +253,7 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) { tool.SetRestrictToWorkspace(true) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "command": "cat ../../etc/passwd", } @@ -205,6 +265,10 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) { } if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") { - t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + t.Errorf( + "Expected 'blocked' message for path traversal, got ForLLM: 
%s, ForUser: %s", + result.ForLLM, + result.ForUser, + ) } } diff --git a/pkg/tools/shell_timeout_unix_test.go b/pkg/tools/shell_timeout_unix_test.go new file mode 100644 index 000000000..04ef8e441 --- /dev/null +++ b/pkg/tools/shell_timeout_unix_test.go @@ -0,0 +1,61 @@ +//go:build !windows + +package tools + +import ( + "context" + "os" + "path/filepath" + "strconv" + "strings" + "syscall" + "testing" + "time" +) + +func processExists(pid int) bool { + if pid <= 0 { + return false + } + err := syscall.Kill(pid, 0) + return err == nil || err == syscall.EPERM +} + +func TestShellTool_TimeoutKillsChildProcess(t *testing.T) { + tool := NewExecTool(t.TempDir(), false) + tool.SetTimeout(500 * time.Millisecond) + + args := map[string]any{ + // Spawn a child process that would outlive the shell unless process-group kill is used. + "command": "sleep 60 & echo $! > child.pid; wait", + } + + result := tool.Execute(context.Background(), args) + if !result.IsError { + t.Fatalf("expected timeout error, got success: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "timed out") { + t.Fatalf("expected timeout message, got: %s", result.ForLLM) + } + + childPIDPath := filepath.Join(tool.workingDir, "child.pid") + data, err := os.ReadFile(childPIDPath) + if err != nil { + t.Fatalf("failed to read child pid file: %v", err) + } + + childPID, err := strconv.Atoi(strings.TrimSpace(string(data))) + if err != nil { + t.Fatalf("failed to parse child pid: %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if !processExists(childPID) { + return + } + time.Sleep(50 * time.Millisecond) + } + + t.Fatalf("child process %d is still running after timeout", childPID) +} diff --git a/pkg/tools/skills_install.go b/pkg/tools/skills_install.go new file mode 100644 index 000000000..55c0b678d --- /dev/null +++ b/pkg/tools/skills_install.go @@ -0,0 +1,201 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "os" + 
"path/filepath" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/skills" + "github.com/sipeed/picoclaw/pkg/utils" +) + +// InstallSkillTool allows the LLM agent to install skills from registries. +// It shares the same RegistryManager that FindSkillsTool uses, +// so all registries configured in config are available for installation. +type InstallSkillTool struct { + registryMgr *skills.RegistryManager + workspace string + mu sync.Mutex +} + +// NewInstallSkillTool creates a new InstallSkillTool. +// registryMgr is the shared registry manager (same instance as FindSkillsTool). +// workspace is the root workspace directory; skills install to {workspace}/skills/{slug}/. +func NewInstallSkillTool(registryMgr *skills.RegistryManager, workspace string) *InstallSkillTool { + return &InstallSkillTool{ + registryMgr: registryMgr, + workspace: workspace, + mu: sync.Mutex{}, + } +} + +func (t *InstallSkillTool) Name() string { + return "install_skill" +} + +func (t *InstallSkillTool) Description() string { + return "Install a skill from a registry by slug. Downloads and extracts the skill into the workspace. Use find_skills first to discover available skills." 
+} + +func (t *InstallSkillTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "slug": map[string]any{ + "type": "string", + "description": "The unique slug of the skill to install (e.g., 'github', 'docker-compose')", + }, + "version": map[string]any{ + "type": "string", + "description": "Specific version to install (optional, defaults to latest)", + }, + "registry": map[string]any{ + "type": "string", + "description": "Registry to install from (required, e.g., 'clawhub')", + }, + "force": map[string]any{ + "type": "boolean", + "description": "Force reinstall if skill already exists (default false)", + }, + }, + "required": []string{"slug", "registry"}, + } +} + +func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + // Install lock to prevent concurrent directory operations. + // Ideally this should be done at a `slug` level, currently, its at a `workspace` level. + t.mu.Lock() + defer t.mu.Unlock() + + // Validate slug + slug, _ := args["slug"].(string) + if err := utils.ValidateSkillIdentifier(slug); err != nil { + return ErrorResult(fmt.Sprintf("invalid slug %q: error: %s", slug, err.Error())) + } + + // Validate registry + registryName, _ := args["registry"].(string) + if err := utils.ValidateSkillIdentifier(registryName); err != nil { + return ErrorResult(fmt.Sprintf("invalid registry %q: error: %s", registryName, err.Error())) + } + + version, _ := args["version"].(string) + force, _ := args["force"].(bool) + + // Check if already installed. + skillsDir := filepath.Join(t.workspace, "skills") + targetDir := filepath.Join(skillsDir, slug) + + if !force { + if _, err := os.Stat(targetDir); err == nil { + return ErrorResult( + fmt.Sprintf("skill %q already installed at %s. Use force=true to reinstall.", slug, targetDir), + ) + } + } else { + // Force: remove existing if present. + os.RemoveAll(targetDir) + } + + // Resolve which registry to use. 
+ registry := t.registryMgr.GetRegistry(registryName) + if registry == nil { + return ErrorResult(fmt.Sprintf("registry %q not found", registryName)) + } + + // Ensure skills directory exists. + if err := os.MkdirAll(skillsDir, 0o755); err != nil { + return ErrorResult(fmt.Sprintf("failed to create skills directory: %v", err)) + } + + // Download and install (handles metadata, version resolution, extraction). + result, err := registry.DownloadAndInstall(ctx, slug, version, targetDir) + if err != nil { + // Clean up partial install. + rmErr := os.RemoveAll(targetDir) + if rmErr != nil { + logger.ErrorCF("tool", "Failed to remove partial install", + map[string]any{ + "tool": "install_skill", + "target_dir": targetDir, + "error": rmErr.Error(), + }) + } + return ErrorResult(fmt.Sprintf("failed to install %q: %v", slug, err)) + } + + // Moderation: block malware. + if result.IsMalwareBlocked { + rmErr := os.RemoveAll(targetDir) + if rmErr != nil { + logger.ErrorCF("tool", "Failed to remove partial install", + map[string]any{ + "tool": "install_skill", + "target_dir": targetDir, + "error": rmErr.Error(), + }) + } + return ErrorResult(fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug)) + } + + // Write origin metadata. + if err := writeOriginMeta(targetDir, registry.Name(), slug, result.Version); err != nil { + logger.ErrorCF("tool", "Failed to write origin metadata", + map[string]any{ + "tool": "install_skill", + "error": err.Error(), + "target": targetDir, + "registry": registry.Name(), + "slug": slug, + "version": result.Version, + }) + _ = err + } + + // Build result with moderation warning if suspicious. 
+ var output string + if result.IsSuspicious { + output = fmt.Sprintf("⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n", slug) + } + output += fmt.Sprintf("Successfully installed skill %q v%s from %s registry.\nLocation: %s\n", + slug, result.Version, registry.Name(), targetDir) + + if result.Summary != "" { + output += fmt.Sprintf("Description: %s\n", result.Summary) + } + output += "\nThe skill is now available and can be loaded in the current session." + + return SilentResult(output) +} + +// originMeta tracks which registry a skill was installed from. +type originMeta struct { + Version int `json:"version"` + Registry string `json:"registry"` + Slug string `json:"slug"` + InstalledVersion string `json:"installed_version"` + InstalledAt int64 `json:"installed_at"` +} + +func writeOriginMeta(targetDir, registryName, slug, version string) error { + meta := originMeta{ + Version: 1, + Registry: registryName, + Slug: slug, + InstalledVersion: version, + InstalledAt: time.Now().UnixMilli(), + } + + data, err := json.MarshalIndent(meta, "", " ") + if err != nil { + return err + } + + return os.WriteFile(filepath.Join(targetDir, ".skill-origin.json"), data, 0o644) +} diff --git a/pkg/tools/skills_install_test.go b/pkg/tools/skills_install_test.go new file mode 100644 index 000000000..676fcecc0 --- /dev/null +++ b/pkg/tools/skills_install_test.go @@ -0,0 +1,104 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/sipeed/picoclaw/pkg/skills" +) + +func TestInstallSkillToolName(t *testing.T) { + tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir()) + assert.Equal(t, "install_skill", tool.Name()) +} + +func TestInstallSkillToolMissingSlug(t *testing.T) { + tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir()) + result := tool.Execute(context.Background(), map[string]any{}) + assert.True(t, 
result.IsError) + assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string") +} + +func TestInstallSkillToolEmptySlug(t *testing.T) { + tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir()) + result := tool.Execute(context.Background(), map[string]any{ + "slug": " ", + }) + assert.True(t, result.IsError) + assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string") +} + +func TestInstallSkillToolUnsafeSlug(t *testing.T) { + tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir()) + + cases := []string{ + "../etc/passwd", + "path/traversal", + "path\\traversal", + } + + for _, slug := range cases { + result := tool.Execute(context.Background(), map[string]any{ + "slug": slug, + }) + assert.True(t, result.IsError, "slug %q should be rejected", slug) + assert.Contains(t, result.ForLLM, "invalid slug") + } +} + +func TestInstallSkillToolAlreadyExists(t *testing.T) { + workspace := t.TempDir() + skillDir := filepath.Join(workspace, "skills", "existing-skill") + require.NoError(t, os.MkdirAll(skillDir, 0o755)) + + tool := NewInstallSkillTool(skills.NewRegistryManager(), workspace) + result := tool.Execute(context.Background(), map[string]any{ + "slug": "existing-skill", + "registry": "clawhub", + }) + assert.True(t, result.IsError) + assert.Contains(t, result.ForLLM, "already installed") +} + +func TestInstallSkillToolRegistryNotFound(t *testing.T) { + workspace := t.TempDir() + tool := NewInstallSkillTool(skills.NewRegistryManager(), workspace) + result := tool.Execute(context.Background(), map[string]any{ + "slug": "some-skill", + "registry": "nonexistent", + }) + assert.True(t, result.IsError) + assert.Contains(t, result.ForLLM, "registry") + assert.Contains(t, result.ForLLM, "not found") +} + +func TestInstallSkillToolParameters(t *testing.T) { + tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir()) + params := tool.Parameters() + + props, ok := 
params["properties"].(map[string]any) + assert.True(t, ok) + assert.Contains(t, props, "slug") + assert.Contains(t, props, "version") + assert.Contains(t, props, "registry") + assert.Contains(t, props, "force") + + required, ok := params["required"].([]string) + assert.True(t, ok) + assert.Contains(t, required, "slug") + assert.Contains(t, required, "registry") +} + +func TestInstallSkillToolMissingRegistry(t *testing.T) { + tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir()) + result := tool.Execute(context.Background(), map[string]any{ + "slug": "some-skill", + }) + assert.True(t, result.IsError) + assert.Contains(t, result.ForLLM, "invalid registry") +} diff --git a/pkg/tools/skills_search.go b/pkg/tools/skills_search.go new file mode 100644 index 000000000..2b6cffd38 --- /dev/null +++ b/pkg/tools/skills_search.go @@ -0,0 +1,119 @@ +package tools + +import ( + "context" + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/skills" +) + +// FindSkillsTool allows the LLM agent to search for installable skills from registries. +type FindSkillsTool struct { + registryMgr *skills.RegistryManager + cache *skills.SearchCache +} + +// NewFindSkillsTool creates a new FindSkillsTool. +// registryMgr is the shared registry manager (built from config in createToolRegistry). +// cache is the search cache for deduplicating similar queries. +func NewFindSkillsTool(registryMgr *skills.RegistryManager, cache *skills.SearchCache) *FindSkillsTool { + return &FindSkillsTool{ + registryMgr: registryMgr, + cache: cache, + } +} + +func (t *FindSkillsTool) Name() string { + return "find_skills" +} + +func (t *FindSkillsTool) Description() string { + return "Search for installable skills from skill registries. Returns skill slugs, descriptions, versions, and relevance scores. Use this to discover skills before installing them with install_skill." 
+} + +func (t *FindSkillsTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "Search query describing the desired skill capability (e.g., 'github integration', 'database management')", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of results to return (1-20, default 5)", + "minimum": 1.0, + "maximum": 20.0, + }, + }, + "required": []string{"query"}, + } +} + +func (t *FindSkillsTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + query, ok := args["query"].(string) + query = strings.ToLower(strings.TrimSpace(query)) + if !ok || query == "" { + return ErrorResult("query is required and must be a non-empty string") + } + + limit := 5 + if l, ok := args["limit"].(float64); ok { + li := int(l) + if li >= 1 && li <= 20 { + limit = li + } + } + + // Check cache first. + if t.cache != nil { + if cached, hit := t.cache.Get(query); hit { + return SilentResult(formatSearchResults(query, cached, true)) + } + } + + // Search all registries. + results, err := t.registryMgr.SearchAll(ctx, query, limit) + if err != nil { + return ErrorResult(fmt.Sprintf("skill search failed: %v", err)) + } + + // Cache the results. + if t.cache != nil && len(results) > 0 { + t.cache.Put(query, results) + } + + return SilentResult(formatSearchResults(query, results, false)) +} + +func formatSearchResults(query string, results []skills.SearchResult, cached bool) string { + if len(results) == 0 { + return fmt.Sprintf("No skills found for query: %q", query) + } + + var sb strings.Builder + source := "" + if cached { + source = " (cached)" + } + sb.WriteString(fmt.Sprintf("Found %d skills for %q%s:\n\n", len(results), query, source)) + + for i, r := range results { + sb.WriteString(fmt.Sprintf("%d. 
**%s**", i+1, r.Slug)) + if r.Version != "" { + sb.WriteString(fmt.Sprintf(" v%s", r.Version)) + } + sb.WriteString(fmt.Sprintf(" (score: %.3f, registry: %s)\n", r.Score, r.RegistryName)) + if r.DisplayName != "" && r.DisplayName != r.Slug { + sb.WriteString(fmt.Sprintf(" Name: %s\n", r.DisplayName)) + } + if r.Summary != "" { + sb.WriteString(fmt.Sprintf(" %s\n", r.Summary)) + } + sb.WriteString("\n") + } + + sb.WriteString("Use install_skill with the slug to install a skill.") + return sb.String() +} diff --git a/pkg/tools/skills_search_test.go b/pkg/tools/skills_search_test.go new file mode 100644 index 000000000..0e5387cf5 --- /dev/null +++ b/pkg/tools/skills_search_test.go @@ -0,0 +1,90 @@ +package tools + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/sipeed/picoclaw/pkg/skills" +) + +func TestFindSkillsToolName(t *testing.T) { + tool := NewFindSkillsTool(skills.NewRegistryManager(), nil) + assert.Equal(t, "find_skills", tool.Name()) +} + +func TestFindSkillsToolMissingQuery(t *testing.T) { + tool := NewFindSkillsTool(skills.NewRegistryManager(), nil) + result := tool.Execute(context.Background(), map[string]any{}) + assert.True(t, result.IsError) + assert.Contains(t, result.ForLLM, "query is required") +} + +func TestFindSkillsToolEmptyQuery(t *testing.T) { + tool := NewFindSkillsTool(skills.NewRegistryManager(), nil) + result := tool.Execute(context.Background(), map[string]any{ + "query": " ", + }) + assert.True(t, result.IsError) +} + +func TestFindSkillsToolCacheHit(t *testing.T) { + cache := skills.NewSearchCache(10, 5*60*1000*1000*1000) // 5 min + cache.Put("github", []skills.SearchResult{ + {Slug: "github", Score: 0.9, RegistryName: "clawhub"}, + }) + + tool := NewFindSkillsTool(skills.NewRegistryManager(), cache) + result := tool.Execute(context.Background(), map[string]any{ + "query": "github", + }) + + assert.False(t, result.IsError) + assert.Contains(t, result.ForLLM, "github") + assert.Contains(t, 
result.ForLLM, "cached") +} + +func TestFindSkillsToolParameters(t *testing.T) { + tool := NewFindSkillsTool(skills.NewRegistryManager(), nil) + params := tool.Parameters() + + props, ok := params["properties"].(map[string]any) + assert.True(t, ok) + assert.Contains(t, props, "query") + assert.Contains(t, props, "limit") + + required, ok := params["required"].([]string) + assert.True(t, ok) + assert.Contains(t, required, "query") +} + +func TestFindSkillsToolDescription(t *testing.T) { + tool := NewFindSkillsTool(skills.NewRegistryManager(), nil) + assert.NotEmpty(t, tool.Description()) + assert.Contains(t, tool.Description(), "skill") +} + +func TestFormatSearchResultsEmpty(t *testing.T) { + result := formatSearchResults("test query", nil, false) + assert.Contains(t, result, "No skills found") +} + +func TestFormatSearchResultsWithData(t *testing.T) { + results := []skills.SearchResult{ + { + Slug: "github", + Score: 0.95, + DisplayName: "GitHub", + Summary: "GitHub API integration", + Version: "1.0.0", + RegistryName: "clawhub", + }, + } + output := formatSearchResults("github", results, false) + assert.Contains(t, output, "github") + assert.Contains(t, output, "v1.0.0") + assert.Contains(t, output, "0.950") + assert.Contains(t, output, "clawhub") + assert.Contains(t, output, "install_skill") +} diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index 42dd36a33..73d385cb0 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -6,10 +6,11 @@ import ( ) type SpawnTool struct { - manager *SubagentManager - originChannel string - originChatID string - callback AsyncCallback // For async completion notification + manager *SubagentManager + originChannel string + originChatID string + allowlistCheck func(targetAgentID string) bool + callback AsyncCallback // For async completion notification } func NewSpawnTool(manager *SubagentManager) *SpawnTool { @@ -33,18 +34,22 @@ func (t *SpawnTool) Description() string { return "Spawn a subagent to handle a task in the 
background. Use this for complex or time-consuming tasks that can run independently. The subagent will complete the task and report back when done." } -func (t *SpawnTool) Parameters() map[string]interface{} { - return map[string]interface{}{ +func (t *SpawnTool) Parameters() map[string]any { + return map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "task": map[string]interface{}{ + "properties": map[string]any{ + "task": map[string]any{ "type": "string", "description": "The task for subagent to complete", }, - "label": map[string]interface{}{ + "label": map[string]any{ "type": "string", "description": "Optional short label for the task (for display)", }, + "agent_id": map[string]any{ + "type": "string", + "description": "Optional target agent ID to delegate the task to", + }, }, "required": []string{"task"}, } @@ -55,20 +60,32 @@ func (t *SpawnTool) SetContext(channel, chatID string) { t.originChatID = chatID } -func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) { + t.allowlistCheck = check +} + +func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResult { task, ok := args["task"].(string) if !ok { return ErrorResult("task is required") } label, _ := args["label"].(string) + agentID, _ := args["agent_id"].(string) + + // Check allowlist if targeting a specific agent + if agentID != "" && t.allowlistCheck != nil { + if !t.allowlistCheck(agentID) { + return ErrorResult(fmt.Sprintf("not allowed to spawn agent '%s'", agentID)) + } + } if t.manager == nil { return ErrorResult("Subagent manager not configured") } // Pass callback to manager for async completion notification - result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback) + result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, t.callback) if err != nil { return 
ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err)) } diff --git a/pkg/tools/spi.go b/pkg/tools/spi.go new file mode 100644 index 000000000..d6a88a5b0 --- /dev/null +++ b/pkg/tools/spi.go @@ -0,0 +1,158 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "path/filepath" + "regexp" + "runtime" +) + +// SPITool provides SPI bus interaction for high-speed peripheral communication. +type SPITool struct{} + +func NewSPITool() *SPITool { + return &SPITool{} +} + +func (t *SPITool) Name() string { + return "spi" +} + +func (t *SPITool) Description() string { + return "Interact with SPI bus devices for high-speed peripheral communication. Actions: list (find SPI devices), transfer (full-duplex send/receive), read (receive bytes). Linux only." +} + +func (t *SPITool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "action": map[string]any{ + "type": "string", + "enum": []string{"list", "transfer", "read"}, + "description": "Action to perform: list (find available SPI devices), transfer (full-duplex send/receive), read (receive bytes by sending zeros)", + }, + "device": map[string]any{ + "type": "string", + "description": "SPI device identifier (e.g. \"2.0\" for /dev/spidev2.0). Required for transfer/read.", + }, + "speed": map[string]any{ + "type": "integer", + "description": "SPI clock speed in Hz. Default: 1000000 (1 MHz).", + }, + "mode": map[string]any{ + "type": "integer", + "description": "SPI mode (0-3). Default: 0. Mode sets CPOL and CPHA: 0=0,0 1=0,1 2=1,0 3=1,1.", + }, + "bits": map[string]any{ + "type": "integer", + "description": "Bits per word. Default: 8.", + }, + "data": map[string]any{ + "type": "array", + "items": map[string]any{"type": "integer"}, + "description": "Bytes to send (0-255 each). Required for transfer action.", + }, + "length": map[string]any{ + "type": "integer", + "description": "Number of bytes to read (1-4096). 
Required for read action.", + }, + "confirm": map[string]any{ + "type": "boolean", + "description": "Must be true for transfer operations. Safety guard to prevent accidental writes.", + }, + }, + "required": []string{"action"}, + } +} + +func (t *SPITool) Execute(ctx context.Context, args map[string]any) *ToolResult { + if runtime.GOOS != "linux" { + return ErrorResult("SPI is only supported on Linux. This tool requires /dev/spidev* device files.") + } + + action, ok := args["action"].(string) + if !ok { + return ErrorResult("action is required") + } + + switch action { + case "list": + return t.list() + case "transfer": + return t.transfer(args) + case "read": + return t.readDevice(args) + default: + return ErrorResult(fmt.Sprintf("unknown action: %s (valid: list, transfer, read)", action)) + } +} + +// list finds available SPI devices by globbing /dev/spidev* +func (t *SPITool) list() *ToolResult { + matches, err := filepath.Glob("/dev/spidev*") + if err != nil { + return ErrorResult(fmt.Sprintf("failed to scan for SPI devices: %v", err)) + } + + if len(matches) == 0 { + return SilentResult( + "No SPI devices found. You may need to:\n1. Enable SPI in device tree\n2. Configure pinmux for your board (see hardware skill)\n3. 
Check that spidev module is loaded", + ) + } + + type devInfo struct { + Path string `json:"path"` + Device string `json:"device"` + } + + devices := make([]devInfo, 0, len(matches)) + re := regexp.MustCompile(`/dev/spidev(\d+\.\d+)`) + for _, m := range matches { + if sub := re.FindStringSubmatch(m); sub != nil { + devices = append(devices, devInfo{Path: m, Device: sub[1]}) + } + } + + result, _ := json.MarshalIndent(devices, "", " ") + return SilentResult(fmt.Sprintf("Found %d SPI device(s):\n%s", len(devices), string(result))) +} + +// parseSPIArgs extracts and validates common SPI parameters +func parseSPIArgs(args map[string]any) (device string, speed uint32, mode uint8, bits uint8, errMsg string) { + dev, ok := args["device"].(string) + if !ok || dev == "" { + return "", 0, 0, 0, "device is required (e.g. \"2.0\" for /dev/spidev2.0)" + } + matched, _ := regexp.MatchString(`^\d+\.\d+$`, dev) + if !matched { + return "", 0, 0, 0, "invalid device identifier: must be in format \"X.Y\" (e.g. \"2.0\")" + } + + speed = 1000000 // default 1 MHz + if s, ok := args["speed"].(float64); ok { + if s < 1 || s > 125000000 { + return "", 0, 0, 0, "speed must be between 1 Hz and 125 MHz" + } + speed = uint32(s) + } + + mode = 0 + if m, ok := args["mode"].(float64); ok { + if int(m) < 0 || int(m) > 3 { + return "", 0, 0, 0, "mode must be 0-3" + } + mode = uint8(m) + } + + bits = 8 + if b, ok := args["bits"].(float64); ok { + if int(b) < 1 || int(b) > 32 { + return "", 0, 0, 0, "bits must be between 1 and 32" + } + bits = uint8(b) + } + + return dev, speed, mode, bits, "" +} diff --git a/pkg/tools/spi_linux.go b/pkg/tools/spi_linux.go new file mode 100644 index 000000000..9def73662 --- /dev/null +++ b/pkg/tools/spi_linux.go @@ -0,0 +1,198 @@ +package tools + +import ( + "encoding/json" + "fmt" + "runtime" + "syscall" + "unsafe" +) + +// SPI ioctl constants from Linux kernel headers. 
+// Calculated from _IOW('k', nr, size) macro: +// +// direction(1)<<30 | size<<16 | type(0x6B)<<8 | nr +const ( + spiIocWrMode = 0x40016B01 // _IOW('k', 1, __u8) + spiIocWrBitsPerWord = 0x40016B03 // _IOW('k', 3, __u8) + spiIocWrMaxSpeedHz = 0x40046B04 // _IOW('k', 4, __u32) + spiIocMessage1 = 0x40206B00 // _IOW('k', 0, struct spi_ioc_transfer) — 32 bytes +) + +// spiTransfer matches Linux kernel struct spi_ioc_transfer (32 bytes on all architectures). +type spiTransfer struct { + txBuf uint64 + rxBuf uint64 + length uint32 + speedHz uint32 + delayUsecs uint16 + bitsPerWord uint8 + csChange uint8 + txNbits uint8 + rxNbits uint8 + wordDelay uint8 + pad uint8 +} + +// configureSPI opens an SPI device and sets mode, bits per word, and speed +func configureSPI(devPath string, mode uint8, bits uint8, speed uint32) (int, *ToolResult) { + fd, err := syscall.Open(devPath, syscall.O_RDWR, 0) + if err != nil { + return -1, ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and spidev module)", devPath, err)) + } + + // Set SPI mode + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrMode, uintptr(unsafe.Pointer(&mode))) + if errno != 0 { + syscall.Close(fd) + return -1, ErrorResult(fmt.Sprintf("failed to set SPI mode %d: %v", mode, errno)) + } + + // Set bits per word + _, _, errno = syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrBitsPerWord, uintptr(unsafe.Pointer(&bits))) + if errno != 0 { + syscall.Close(fd) + return -1, ErrorResult(fmt.Sprintf("failed to set bits per word %d: %v", bits, errno)) + } + + // Set max speed + _, _, errno = syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrMaxSpeedHz, uintptr(unsafe.Pointer(&speed))) + if errno != 0 { + syscall.Close(fd) + return -1, ErrorResult(fmt.Sprintf("failed to set SPI speed %d Hz: %v", speed, errno)) + } + + return fd, nil +} + +// transfer performs a full-duplex SPI transfer +func (t *SPITool) transfer(args map[string]any) *ToolResult { + confirm, _ := 
args["confirm"].(bool) + if !confirm { + return ErrorResult( + "transfer operations require confirm: true. Please confirm with the user before sending data to SPI devices.", + ) + } + + dev, speed, mode, bits, errMsg := parseSPIArgs(args) + if errMsg != "" { + return ErrorResult(errMsg) + } + + dataRaw, ok := args["data"].([]any) + if !ok || len(dataRaw) == 0 { + return ErrorResult("data is required for transfer (array of byte values 0-255)") + } + if len(dataRaw) > 4096 { + return ErrorResult("data too long: maximum 4096 bytes per SPI transfer") + } + + txBuf := make([]byte, len(dataRaw)) + for i, v := range dataRaw { + f, ok := v.(float64) + if !ok { + return ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i)) + } + b := int(f) + if b < 0 || b > 255 { + return ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b)) + } + txBuf[i] = byte(b) + } + + devPath := fmt.Sprintf("/dev/spidev%s", dev) + fd, errResult := configureSPI(devPath, mode, bits, speed) + if errResult != nil { + return errResult + } + defer syscall.Close(fd) + + rxBuf := make([]byte, len(txBuf)) + + xfer := spiTransfer{ + txBuf: uint64(uintptr(unsafe.Pointer(&txBuf[0]))), + rxBuf: uint64(uintptr(unsafe.Pointer(&rxBuf[0]))), + length: uint32(len(txBuf)), + speedHz: speed, + bitsPerWord: bits, + } + + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocMessage1, uintptr(unsafe.Pointer(&xfer))) + runtime.KeepAlive(txBuf) + runtime.KeepAlive(rxBuf) + if errno != 0 { + return ErrorResult(fmt.Sprintf("SPI transfer failed: %v", errno)) + } + + // Format received bytes + hexBytes := make([]string, len(rxBuf)) + intBytes := make([]int, len(rxBuf)) + for i, b := range rxBuf { + hexBytes[i] = fmt.Sprintf("0x%02x", b) + intBytes[i] = int(b) + } + + result, _ := json.MarshalIndent(map[string]any{ + "device": devPath, + "sent": len(txBuf), + "received": intBytes, + "hex": hexBytes, + }, "", " ") + return SilentResult(string(result)) +} + +// readDevice reads 
bytes from SPI by sending zeros (read-only, no confirm needed) +func (t *SPITool) readDevice(args map[string]any) *ToolResult { + dev, speed, mode, bits, errMsg := parseSPIArgs(args) + if errMsg != "" { + return ErrorResult(errMsg) + } + + length := 0 + if l, ok := args["length"].(float64); ok { + length = int(l) + } + if length < 1 || length > 4096 { + return ErrorResult("length is required for read (1-4096)") + } + + devPath := fmt.Sprintf("/dev/spidev%s", dev) + fd, errResult := configureSPI(devPath, mode, bits, speed) + if errResult != nil { + return errResult + } + defer syscall.Close(fd) + + txBuf := make([]byte, length) // zeros + rxBuf := make([]byte, length) + + xfer := spiTransfer{ + txBuf: uint64(uintptr(unsafe.Pointer(&txBuf[0]))), + rxBuf: uint64(uintptr(unsafe.Pointer(&rxBuf[0]))), + length: uint32(length), + speedHz: speed, + bitsPerWord: bits, + } + + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocMessage1, uintptr(unsafe.Pointer(&xfer))) + runtime.KeepAlive(txBuf) + runtime.KeepAlive(rxBuf) + if errno != 0 { + return ErrorResult(fmt.Sprintf("SPI read failed: %v", errno)) + } + + hexBytes := make([]string, len(rxBuf)) + intBytes := make([]int, len(rxBuf)) + for i, b := range rxBuf { + hexBytes[i] = fmt.Sprintf("0x%02x", b) + intBytes[i] = int(b) + } + + result, _ := json.MarshalIndent(map[string]any{ + "device": devPath, + "bytes": intBytes, + "hex": hexBytes, + "length": len(rxBuf), + }, "", " ") + return SilentResult(string(result)) +} diff --git a/pkg/tools/spi_other.go b/pkg/tools/spi_other.go new file mode 100644 index 000000000..5d078ac3f --- /dev/null +++ b/pkg/tools/spi_other.go @@ -0,0 +1,13 @@ +//go:build !linux + +package tools + +// transfer is a stub for non-Linux platforms. +func (t *SPITool) transfer(args map[string]any) *ToolResult { + return ErrorResult("SPI is only supported on Linux") +} + +// readDevice is a stub for non-Linux platforms. 
+func (t *SPITool) readDevice(args map[string]any) *ToolResult { + return ErrorResult("SPI is only supported on Linux") +} diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 97b130396..91ebff636 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -14,6 +14,7 @@ type SubagentTask struct { ID string Task string Label string + AgentID string OriginChannel string OriginChatID string Status string @@ -22,30 +23,48 @@ type SubagentTask struct { } type SubagentManager struct { - tasks map[string]*SubagentTask - mu sync.RWMutex - provider providers.LLMProvider - defaultModel string - bus *bus.MessageBus - workspace string - tools *ToolRegistry - maxIterations int - nextID int + tasks map[string]*SubagentTask + mu sync.RWMutex + provider providers.LLMProvider + defaultModel string + bus *bus.MessageBus + workspace string + tools *ToolRegistry + maxIterations int + maxTokens int + temperature float64 + hasMaxTokens bool + hasTemperature bool + nextID int } -func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager { +func NewSubagentManager( + provider providers.LLMProvider, + defaultModel, workspace string, + bus *bus.MessageBus, +) *SubagentManager { return &SubagentManager{ - tasks: make(map[string]*SubagentTask), - provider: provider, - defaultModel: defaultModel, - bus: bus, - workspace: workspace, - tools: NewToolRegistry(), + tasks: make(map[string]*SubagentTask), + provider: provider, + defaultModel: defaultModel, + bus: bus, + workspace: workspace, + tools: NewToolRegistry(), maxIterations: 10, - nextID: 1, + nextID: 1, } } +// SetLLMOptions sets max tokens and temperature for subagent LLM calls. 
+func (sm *SubagentManager) SetLLMOptions(maxTokens int, temperature float64) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.maxTokens = maxTokens + sm.hasMaxTokens = true + sm.temperature = temperature + sm.hasTemperature = true +} + // SetTools sets the tool registry for subagent execution. // If not set, subagent will have access to the provided tools. func (sm *SubagentManager) SetTools(tools *ToolRegistry) { @@ -61,7 +80,11 @@ func (sm *SubagentManager) RegisterTool(tool Tool) { sm.tools.Register(tool) } -func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) { +func (sm *SubagentManager) Spawn( + ctx context.Context, + task, label, agentID, originChannel, originChatID string, + callback AsyncCallback, +) (string, error) { sm.mu.Lock() defer sm.mu.Unlock() @@ -72,6 +95,7 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel ID: taskID, Task: task, Label: label, + AgentID: agentID, OriginChannel: originChannel, OriginChatID: originChatID, Status: "running", @@ -123,17 +147,29 @@ After completing the task, provide a clear summary of what was done.` sm.mu.RLock() tools := sm.tools maxIter := sm.maxIterations + maxTokens := sm.maxTokens + temperature := sm.temperature + hasMaxTokens := sm.hasMaxTokens + hasTemperature := sm.hasTemperature sm.mu.RUnlock() + var llmOptions map[string]any + if hasMaxTokens || hasTemperature { + llmOptions = map[string]any{} + if hasMaxTokens { + llmOptions["max_tokens"] = maxTokens + } + if hasTemperature { + llmOptions["temperature"] = temperature + } + } + loopResult, err := RunToolLoop(ctx, ToolLoopConfig{ Provider: sm.provider, Model: sm.defaultModel, Tools: tools, MaxIterations: maxIter, - LLMOptions: map[string]any{ - "max_tokens": 4096, - "temperature": 0.7, - }, + LLMOptions: llmOptions, }, messages, task.OriginChannel, task.OriginChatID) sm.mu.Lock() @@ -166,7 +202,12 @@ After completing the task, provide a 
clear summary of what was done.` task.Status = "completed" task.Result = loopResult.Content result = &ToolResult{ - ForLLM: fmt.Sprintf("Subagent '%s' completed (iterations: %d): %s", task.Label, loopResult.Iterations, loopResult.Content), + ForLLM: fmt.Sprintf( + "Subagent '%s' completed (iterations: %d): %s", + task.Label, + loopResult.Iterations, + loopResult.Content, + ), ForUser: loopResult.Content, Silent: false, IsError: false, @@ -230,15 +271,15 @@ func (t *SubagentTool) Description() string { return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM." } -func (t *SubagentTool) Parameters() map[string]interface{} { - return map[string]interface{}{ +func (t *SubagentTool) Parameters() map[string]any { + return map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "task": map[string]interface{}{ + "properties": map[string]any{ + "task": map[string]any{ "type": "string", "description": "The task for subagent to complete", }, - "label": map[string]interface{}{ + "label": map[string]any{ "type": "string", "description": "Optional short label for the task (for display)", }, @@ -252,7 +293,7 @@ func (t *SubagentTool) SetContext(channel, chatID string) { t.originChatID = chatID } -func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolResult { task, ok := args["task"].(string) if !ok { return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required")) @@ -281,19 +322,30 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) sm.mu.RLock() tools := sm.tools maxIter := sm.maxIterations + maxTokens := sm.maxTokens + temperature := sm.temperature + hasMaxTokens := sm.hasMaxTokens + hasTemperature := sm.hasTemperature sm.mu.RUnlock() + var 
llmOptions map[string]any + if hasMaxTokens || hasTemperature { + llmOptions = map[string]any{} + if hasMaxTokens { + llmOptions["max_tokens"] = maxTokens + } + if hasTemperature { + llmOptions["temperature"] = temperature + } + } + loopResult, err := RunToolLoop(ctx, ToolLoopConfig{ Provider: sm.provider, Model: sm.defaultModel, Tools: tools, MaxIterations: maxIter, - LLMOptions: map[string]any{ - "max_tokens": 4096, - "temperature": 0.7, - }, + LLMOptions: llmOptions, }, messages, t.originChannel, t.originChatID) - if err != nil { return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err) } diff --git a/pkg/tools/subagent_tool_test.go b/pkg/tools/subagent_tool_test.go index 8a7d22f24..59bfdffae 100644 --- a/pkg/tools/subagent_tool_test.go +++ b/pkg/tools/subagent_tool_test.go @@ -10,9 +10,18 @@ import ( ) // MockLLMProvider is a test implementation of LLMProvider -type MockLLMProvider struct{} +type MockLLMProvider struct { + lastOptions map[string]any +} -func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { +func (m *MockLLMProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + options map[string]any, +) (*providers.LLMResponse, error) { + m.lastOptions = options // Find the last user message to generate a response for i := len(messages) - 1; i >= 0; i-- { if messages[i].Role == "user" { @@ -36,6 +45,32 @@ func (m *MockLLMProvider) GetContextWindow() int { return 4096 } +func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + manager.SetLLMOptions(2048, 0.6) + tool := NewSubagentTool(manager) + tool.SetContext("cli", "direct") + + ctx := context.Background() + args := map[string]any{"task": "Do 
something"} + result := tool.Execute(ctx, args) + + if result == nil || result.IsError { + t.Fatalf("Expected successful result, got: %+v", result) + } + + if provider.lastOptions == nil { + t.Fatal("Expected LLM options to be passed, got nil") + } + if provider.lastOptions["max_tokens"] != 2048 { + t.Fatalf("max_tokens = %v, want %d", provider.lastOptions["max_tokens"], 2048) + } + if provider.lastOptions["temperature"] != 0.6 { + t.Fatalf("temperature = %v, want %v", provider.lastOptions["temperature"], 0.6) + } +} + // TestSubagentTool_Name verifies tool name func TestSubagentTool_Name(t *testing.T) { provider := &MockLLMProvider{} @@ -79,13 +114,13 @@ func TestSubagentTool_Parameters(t *testing.T) { } // Check properties - props, ok := params["properties"].(map[string]interface{}) + props, ok := params["properties"].(map[string]any) if !ok { t.Fatal("Properties should be a map") } // Verify task parameter - task, ok := props["task"].(map[string]interface{}) + task, ok := props["task"].(map[string]any) if !ok { t.Fatal("Task parameter should exist") } @@ -94,7 +129,7 @@ func TestSubagentTool_Parameters(t *testing.T) { } // Verify label parameter - label, ok := props["label"].(map[string]interface{}) + label, ok := props["label"].(map[string]any) if !ok { t.Fatal("Label parameter should exist") } @@ -134,7 +169,7 @@ func TestSubagentTool_Execute_Success(t *testing.T) { tool.SetContext("telegram", "chat-123") ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "task": "Write a haiku about coding", "label": "haiku-task", } @@ -189,7 +224,7 @@ func TestSubagentTool_Execute_NoLabel(t *testing.T) { tool := NewSubagentTool(manager) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "task": "Test task without label", } @@ -212,7 +247,7 @@ func TestSubagentTool_Execute_MissingTask(t *testing.T) { tool := NewSubagentTool(manager) ctx := context.Background() - args := map[string]interface{}{ + args := 
map[string]any{ "label": "test", } @@ -239,7 +274,7 @@ func TestSubagentTool_Execute_NilManager(t *testing.T) { tool := NewSubagentTool(nil) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "task": "test task", } @@ -268,7 +303,7 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) { tool.SetContext(channel, chatID) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "task": "Test context passing", } @@ -295,7 +330,7 @@ func TestSubagentTool_ForUserTruncation(t *testing.T) { // Create a task that will generate long response longTask := strings.Repeat("This is a very long task description. ", 100) - args := map[string]interface{}{ + args := map[string]any{ "task": longTask, "label": "long-test", } diff --git a/pkg/tools/swarm_info.go b/pkg/tools/swarm_info.go new file mode 100644 index 000000000..755966e7c --- /dev/null +++ b/pkg/tools/swarm_info.go @@ -0,0 +1,225 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package tools + +import ( + "context" + "fmt" + "os" + "strings" + "sync" +) + +// SwarmInfoTool provides information about the PicoClaw swarm environment +type SwarmInfoTool struct { + nodeID string + nodeRole string + swarmConfig string + hid string + sid string + workers map[string]WorkerInfo + mu sync.RWMutex +} + +type WorkerInfo struct { + ID string `json:"id"` + Role string `json:"role"` + Capabilities []string `json:"capabilities"` + WorkDir string `json:"work_dir"` + Status string `json:"status"` + Host string `json:"host"` + Port int `json:"port"` +} + +// NewSwarmInfoTool creates a new swarm info tool +func NewSwarmInfoTool() *SwarmInfoTool { + return &SwarmInfoTool{ + workers: make(map[string]WorkerInfo), + } +} + +func (t *SwarmInfoTool) Name() string { + return "swarm_info" +} + +func (t *SwarmInfoTool) Description() string { + return "CRITICAL: Call this FIRST when asked about workers, 
nodes, or their directories. Returns complete information about all swarm nodes including: node IDs (e.g., claw-xxx), roles (coordinator/worker), work directories (e.g., /Users/dev/service/worker-a), and capabilities. ALWAYS use this before accessing worker directories to understand the swarm layout." +} + +func (t *SwarmInfoTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "Optional: specific query like 'workdirs', 'nodes', 'current', 'all'", + }, + }, + } +} + +func (t *SwarmInfoTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { + t.mu.Lock() + defer t.mu.Unlock() + + query := "" + if q, ok := args["query"].(string); ok { + query = q + } + + // Try to load worker info from known locations + t.discoverWorkers() + + var result strings.Builder + + if query == "" || query == "all" { + result.WriteString("🦀 PicoClaw Swarm Environment\n\n") + result.WriteString(t.formatAllInfo()) + } else if query == "workdirs" { + result.WriteString("📁 Worker Directories:\n\n") + result.WriteString(t.formatWorkDirs()) + } else if query == "nodes" { + result.WriteString("🔗 Swarm Nodes:\n\n") + result.WriteString(t.formatNodes()) + } else if query == "current" { + result.WriteString("📍 Current Node:\n\n") + result.WriteString(t.formatCurrentNode()) + } else { + result.WriteString("🦀 PicoClaw Swarm Environment\n\n") + result.WriteString(t.formatAllInfo()) + } + + return &ToolResult{ + ForLLM: result.String(), + ForUser: result.String(), + IsError: false, + } +} + +func (t *SwarmInfoTool) discoverWorkers() { + // Add known workers from config + workers := []WorkerInfo{ + { + ID: "coordinator", + Role: "coordinator", + Capabilities: []string{"orchestration", "scheduling"}, + WorkDir: "/Users/dev/service/coordinator", + Status: "online", + Host: "localhost", + }, + { + ID: "worker-a", + Role: "worker", + 
Capabilities: []string{"code", "macos"}, + WorkDir: "/Users/dev/service/worker-a", + Status: "online", + Host: "localhost", + }, + { + ID: "worker-b", + Role: "worker", + Capabilities: []string{"search", "windows"}, + WorkDir: "/Users/dev/service/worker-b", + Status: "online", + Host: "localhost", + }, + } + + for _, w := range workers { + t.workers[w.ID] = w + } +} + +func (t *SwarmInfoTool) formatAllInfo() string { + var s strings.Builder + + s.WriteString("📋 SWARM NODES DIRECTORY MAP\n\n") + + s.WriteString("When asked about worker directories, use these exact paths:\n\n") + + for _, w := range t.workers { + s.WriteString(fmt.Sprintf("【%s】%s\n", strings.ToUpper(w.ID), w.ID)) + s.WriteString(fmt.Sprintf(" Role: %s\n", w.Role)) + s.WriteString(fmt.Sprintf(" WorkDir: %s\n", w.WorkDir)) + s.WriteString(fmt.Sprintf(" Capabilities: %s\n", strings.Join(w.Capabilities, ", "))) + s.WriteString(fmt.Sprintf(" → To list files: ls %s\n\n", w.WorkDir)) + } + + s.WriteString("⚠️ IMPORTANT:\n") + s.WriteString(" - Each worker has its own work directory\n") + s.WriteString(" - Use the full path (e.g., /Users/dev/service/worker-a) when accessing files\n") + s.WriteString(" - Do not use relative paths or 'worker' subdirectory\n") + + return s.String() +} + +func (t *SwarmInfoTool) formatWorkDirs() string { + var s strings.Builder + for _, w := range t.workers { + s.WriteString(fmt.Sprintf("%s: %s\n", w.ID, w.WorkDir)) + s.WriteString(fmt.Sprintf(" → ls %s\n", w.WorkDir)) + } + return s.String() +} + +func (t *SwarmInfoTool) formatNodes() string { + var s strings.Builder + for _, w := range t.workers { + s.WriteString(fmt.Sprintf("- %s: %s\n", w.ID, w.Role)) + s.WriteString(fmt.Sprintf(" Capabilities: %v\n", w.Capabilities)) + s.WriteString(fmt.Sprintf(" WorkDir: %s\n", w.WorkDir)) + } + return s.String() +} + +func (t *SwarmInfoTool) formatCurrentNode() string { + // Try to detect current node by checking working directory + cwd, _ := os.Getwd() + var s strings.Builder + + 
s.WriteString(fmt.Sprintf("Working Directory: %s\n", cwd)) + + // Determine which node this is based on path + if strings.Contains(cwd, "coordinator") { + s.WriteString("\nDetected: coordinator node\n") + s.WriteString("WorkDir: /Users/dev/service/coordinator\n") + } else if strings.Contains(cwd, "worker-a") { + s.WriteString("\nDetected: worker-a node\n") + s.WriteString("WorkDir: /Users/dev/service/worker-a\n") + } else if strings.Contains(cwd, "worker-b") { + s.WriteString("\nDetected: worker-b node\n") + s.WriteString("WorkDir: /Users/dev/service/worker-b\n") + } else { + s.WriteString("\nDetected: gateway/workspace node\n") + s.WriteString("WorkDir: /Users/dev/workspace\n") + } + + return s.String() +} + +// SetNodeInfo sets the current node information +func (t *SwarmInfoTool) SetNodeInfo(nodeID, nodeRole, hid, sid string) { + t.mu.Lock() + defer t.mu.Unlock() + t.nodeID = nodeID + t.nodeRole = nodeRole + t.hid = hid + t.sid = sid +} + +// AddWorker registers a worker's information +func (t *SwarmInfoTool) AddWorker(id, role string, capabilities []string, workDir string) { + t.mu.Lock() + defer t.mu.Unlock() + t.workers[id] = WorkerInfo{ + ID: id, + Role: role, + Capabilities: capabilities, + WorkDir: workDir, + Status: "online", + Host: "localhost", + } +} diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go index 1302079b4..cdfe0d6ce 100644 --- a/pkg/tools/toolloop.go +++ b/pkg/tools/toolloop.go @@ -33,7 +33,12 @@ type ToolLoopResult struct { // RunToolLoop executes the LLM + tool call iteration loop. // This is the core agent logic that can be reused by both main agent and subagents. 
-func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []providers.Message, channel, chatID string) (*ToolLoopResult, error) { +func RunToolLoop( + ctx context.Context, + config ToolLoopConfig, + messages []providers.Message, + channel, chatID string, +) (*ToolLoopResult, error) { iteration := 0 var finalContent string @@ -55,12 +60,8 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider // 2. Set default LLM options llmOpts := config.LLMOptions if llmOpts == nil { - llmOpts = map[string]any{ - "max_tokens": 4096, - "temperature": 0.7, - } + llmOpts = map[string]any{} } - // 3. Call LLM response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts) if err != nil { @@ -83,15 +84,20 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider break } - // 5. Log tool calls - toolNames := make([]string, 0, len(response.ToolCalls)) + normalizedToolCalls := make([]providers.ToolCall, 0, len(response.ToolCalls)) for _, tc := range response.ToolCalls { + normalizedToolCalls = append(normalizedToolCalls, providers.NormalizeToolCall(tc)) + } + + // 5. 
Log tool calls + toolNames := make([]string, 0, len(normalizedToolCalls)) + for _, tc := range normalizedToolCalls { toolNames = append(toolNames, tc.Name) } logger.InfoCF("toolloop", "LLM requested tool calls", map[string]any{ "tools": toolNames, - "count": len(response.ToolCalls), + "count": len(normalizedToolCalls), "iteration": iteration, }) @@ -100,11 +106,13 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider Role: "assistant", Content: response.Content, } - for _, tc := range response.ToolCalls { + for _, tc := range normalizedToolCalls { argumentsJSON, _ := json.Marshal(tc.Arguments) assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ - ID: tc.ID, - Type: "function", + ID: tc.ID, + Type: "function", + Name: tc.Name, + Arguments: tc.Arguments, Function: &providers.FunctionCall{ Name: tc.Name, Arguments: string(argumentsJSON), @@ -114,7 +122,7 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider messages = append(messages, assistantMsg) // 7. 
Execute tool calls - for _, tc := range response.ToolCalls { + for _, tc := range normalizedToolCalls { argsJSON, _ := json.Marshal(tc.Arguments) argsPreview := utils.Truncate(string(argsJSON), 200) logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), diff --git a/pkg/tools/types.go b/pkg/tools/types.go index f8205b8bd..a6015cde3 100644 --- a/pkg/tools/types.go +++ b/pkg/tools/types.go @@ -10,11 +10,11 @@ type Message struct { } type ToolCall struct { - ID string `json:"id"` - Type string `json:"type"` - Function *FunctionCall `json:"function,omitempty"` - Name string `json:"name,omitempty"` - Arguments map[string]interface{} `json:"arguments,omitempty"` + ID string `json:"id"` + Type string `json:"type"` + Function *FunctionCall `json:"function,omitempty"` + Name string `json:"name,omitempty"` + Arguments map[string]any `json:"arguments,omitempty"` } type FunctionCall struct { @@ -36,7 +36,13 @@ type UsageInfo struct { } type LLMProvider interface { - Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) + Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + options map[string]any, + ) (*LLMResponse, error) GetDefaultModel() string } @@ -46,7 +52,7 @@ type ToolDefinition struct { } type ToolFunctionDefinition struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters map[string]interface{} `json:"parameters"` + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]any `json:"parameters"` } diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 3e8b7e9e8..301e00daf 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -13,89 +13,39 @@ import ( ) const ( - userAgent = "Mozilla/5.0 (compatible; picoclaw/1.0)" + userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" ) -type 
WebSearchTool struct { - apiKey string - maxResults int -} - -func NewWebSearchTool(apiKey string, maxResults int) *WebSearchTool { - if maxResults <= 0 || maxResults > 10 { - maxResults = 5 - } - return &WebSearchTool{ - apiKey: apiKey, - maxResults: maxResults, - } +type SearchProvider interface { + Search(ctx context.Context, query string, count int) (string, error) } -func (t *WebSearchTool) Name() string { - return "web_search" +type BraveSearchProvider struct { + apiKey string } -func (t *WebSearchTool) Description() string { - return "Search the web for current information. Returns titles, URLs, and snippets from search results." -} - -func (t *WebSearchTool) Parameters() map[string]interface{} { - return map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "query": map[string]interface{}{ - "type": "string", - "description": "Search query", - }, - "count": map[string]interface{}{ - "type": "integer", - "description": "Number of results (1-10)", - "minimum": 1.0, - "maximum": 10.0, - }, - }, - "required": []string{"query"}, - } -} - -func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { - if t.apiKey == "" { - return ErrorResult("BRAVE_API_KEY not configured") - } - - query, ok := args["query"].(string) - if !ok { - return ErrorResult("query is required") - } - - count := t.maxResults - if c, ok := args["count"].(float64); ok { - if int(c) > 0 && int(c) <= 10 { - count = int(c) - } - } - +func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d", url.QueryEscape(query), count) req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) if err != nil { - return ErrorResult(fmt.Sprintf("failed to create request: %v", err)) + return "", fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Accept", "application/json") - 
req.Header.Set("X-Subscription-Token", t.apiKey) + req.Header.Set("X-Subscription-Token", p.apiKey) client := &http.Client{Timeout: 10 * time.Second} resp, err := client.Do(req) if err != nil { - return ErrorResult(fmt.Sprintf("request failed: %v", err)) + return "", fmt.Errorf("request failed: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return ErrorResult(fmt.Sprintf("failed to read response: %v", err)) + return "", fmt.Errorf("failed to read response: %w", err) } var searchResp struct { @@ -109,16 +59,14 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{} } if err := json.Unmarshal(body, &searchResp); err != nil { - return ErrorResult(fmt.Sprintf("failed to parse response: %v", err)) + // Log error body for debugging + fmt.Printf("Brave API Error Body: %s\n", string(body)) + return "", fmt.Errorf("failed to parse response: %w", err) } results := searchResp.Web.Results if len(results) == 0 { - msg := fmt.Sprintf("No results for: %s", query) - return &ToolResult{ - ForLLM: msg, - ForUser: msg, - } + return fmt.Sprintf("No results for: %s", query), nil } var lines []string @@ -133,10 +81,266 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{} } } - output := strings.Join(lines, "\n") + return strings.Join(lines, "\n"), nil +} + +type DuckDuckGoSearchProvider struct{} + +func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { + searchURL := fmt.Sprintf("https://html.duckduckgo.com/html/?q=%s", url.QueryEscape(query)) + + req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("User-Agent", userAgent) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := 
io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + return p.extractResults(string(body), count, query) +} + +func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query string) (string, error) { + // Simple regex based extraction for DDG HTML + // Strategy: Find all result containers or key anchors directly + + // Try finding the result links directly first, as they are the most critical + // Pattern: Title + // The previous regex was a bit strict. Let's make it more flexible for attributes order/content + reLink := regexp.MustCompile(`]*class="[^"]*result__a[^"]*"[^>]*href="([^"]+)"[^>]*>([\s\S]*?)`) + matches := reLink.FindAllStringSubmatch(html, count+5) + + if len(matches) == 0 { + return fmt.Sprintf("No results found or extraction failed. Query: %s", query), nil + } + + var lines []string + lines = append(lines, fmt.Sprintf("Results for: %s (via DuckDuckGo)", query)) + + // Pre-compile snippet regex to run inside the loop + // We'll search for snippets relative to the link position or just globally if needed + // But simple global search for snippets might mismatch order. 
+ // Since we only have the raw HTML string, let's just extract snippets globally and assume order matches (risky but simple for regex) + // Or better: Let's assume the snippet follows the link in the HTML + + // A better regex approach: iterate through text and find matches in order + // But for now, let's grab all snippets too + reSnippet := regexp.MustCompile(`([\s\S]*?)`) + snippetMatches := reSnippet.FindAllStringSubmatch(html, count+5) + + maxItems := min(len(matches), count) + + for i := 0; i < maxItems; i++ { + urlStr := matches[i][1] + title := stripTags(matches[i][2]) + title = strings.TrimSpace(title) + + // URL decoding if needed + if strings.Contains(urlStr, "uddg=") { + if u, err := url.QueryUnescape(urlStr); err == nil { + idx := strings.Index(u, "uddg=") + if idx != -1 { + urlStr = u[idx+5:] + } + } + } + + lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, title, urlStr)) + + // Attempt to attach snippet if available and index aligns + if i < len(snippetMatches) { + snippet := stripTags(snippetMatches[i][1]) + snippet = strings.TrimSpace(snippet) + if snippet != "" { + lines = append(lines, fmt.Sprintf(" %s", snippet)) + } + } + } + + return strings.Join(lines, "\n"), nil +} + +func stripTags(content string) string { + re := regexp.MustCompile(`<[^>]+>`) + return re.ReplaceAllString(content, "") +} + +type PerplexitySearchProvider struct { + apiKey string +} + +func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) { + searchURL := "https://api.perplexity.ai/chat/completions" + + payload := map[string]any{ + "model": "sonar", + "messages": []map[string]string{ + { + "role": "system", + "content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary.", + }, + { + "role": "user", + "content": fmt.Sprintf("Search for: %s. 
Provide up to %d relevant results.", query, count), + }, + }, + "max_tokens": 1000, + } + + payloadBytes, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", searchURL, strings.NewReader(string(payloadBytes))) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+p.apiKey) + req.Header.Set("User-Agent", userAgent) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("Perplexity API error: %s", string(body)) + } + + var searchResp struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + + if err := json.Unmarshal(body, &searchResp); err != nil { + return "", fmt.Errorf("failed to parse response: %w", err) + } + + if len(searchResp.Choices) == 0 { + return fmt.Sprintf("No results for: %s", query), nil + } + + return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil +} + +type WebSearchTool struct { + provider SearchProvider + maxResults int +} + +type WebSearchToolOptions struct { + BraveAPIKey string + BraveMaxResults int + BraveEnabled bool + DuckDuckGoMaxResults int + DuckDuckGoEnabled bool + PerplexityAPIKey string + PerplexityMaxResults int + PerplexityEnabled bool +} + +func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool { + var provider SearchProvider + maxResults := 5 + + // Priority: Perplexity > Brave > DuckDuckGo + if opts.PerplexityEnabled && opts.PerplexityAPIKey != 
"" { + provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey} + if opts.PerplexityMaxResults > 0 { + maxResults = opts.PerplexityMaxResults + } + } else if opts.BraveEnabled && opts.BraveAPIKey != "" { + provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey} + if opts.BraveMaxResults > 0 { + maxResults = opts.BraveMaxResults + } + } else if opts.DuckDuckGoEnabled { + provider = &DuckDuckGoSearchProvider{} + if opts.DuckDuckGoMaxResults > 0 { + maxResults = opts.DuckDuckGoMaxResults + } + } else { + return nil + } + + return &WebSearchTool{ + provider: provider, + maxResults: maxResults, + } +} + +func (t *WebSearchTool) Name() string { + return "web_search" +} + +func (t *WebSearchTool) Description() string { + return "Search the web for current information. Returns titles, URLs, and snippets from search results." +} + +func (t *WebSearchTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "Search query", + }, + "count": map[string]any{ + "type": "integer", + "description": "Number of results (1-10)", + "minimum": 1.0, + "maximum": 10.0, + }, + }, + "required": []string{"query"}, + } +} + +func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + query, ok := args["query"].(string) + if !ok { + return ErrorResult("query is required") + } + + count := t.maxResults + if c, ok := args["count"].(float64); ok { + if int(c) > 0 && int(c) <= 10 { + count = int(c) + } + } + + result, err := t.provider.Search(ctx, query, count) + if err != nil { + return ErrorResult(fmt.Sprintf("search failed: %v", err)) + } + return &ToolResult{ - ForLLM: fmt.Sprintf("Found %d results for: %s", len(results), query), - ForUser: output, + ForLLM: result, + ForUser: result, } } @@ -161,15 +365,15 @@ func (t *WebFetchTool) Description() string { return "Fetch a URL and extract readable content (HTML to text). 
Use this to get weather info, news, articles, or any web content." } -func (t *WebFetchTool) Parameters() map[string]interface{} { - return map[string]interface{}{ +func (t *WebFetchTool) Parameters() map[string]any { + return map[string]any{ "type": "object", - "properties": map[string]interface{}{ - "url": map[string]interface{}{ + "properties": map[string]any{ + "url": map[string]any{ "type": "string", "description": "URL to fetch", }, - "maxChars": map[string]interface{}{ + "maxChars": map[string]any{ "type": "integer", "description": "Maximum characters to extract", "minimum": 100.0, @@ -179,7 +383,7 @@ func (t *WebFetchTool) Parameters() map[string]interface{} { } } -func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolResult { urlStr, ok := args["url"].(string) if !ok { return ErrorResult("url is required") @@ -244,7 +448,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) var text, extractor string if strings.Contains(contentType, "application/json") { - var jsonData interface{} + var jsonData any if err := json.Unmarshal(body, &jsonData); err == nil { formatted, _ := json.MarshalIndent(jsonData, "", " ") text = string(formatted) @@ -267,7 +471,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) text = text[:maxChars] } - result := map[string]interface{}{ + result := map[string]any{ "url": urlStr, "status": resp.StatusCode, "extractor": extractor, @@ -279,7 +483,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) resultJSON, _ := json.MarshalIndent(result, "", " ") return &ToolResult{ - ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated), + ForLLM: fmt.Sprintf( + "Fetched %d bytes from %s (extractor: %s, truncated: %v)", + len(text), + urlStr, + extractor, + truncated, + 
), ForUser: string(resultJSON), } } @@ -294,8 +504,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string { result = strings.TrimSpace(result) - re = regexp.MustCompile(`\s+`) - result = re.ReplaceAllLiteralString(result, " ") + re = regexp.MustCompile(`[^\S\n]+`) + result = re.ReplaceAllString(result, " ") + re = regexp.MustCompile(`\n{3,}`) + result = re.ReplaceAllString(result, "\n\n") lines := strings.Split(result, "\n") var cleanLines []string diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 30bc7d991..d999d8958 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -20,7 +20,7 @@ func TestWebTool_WebFetch_Success(t *testing.T) { tool := NewWebFetchTool(50000) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "url": server.URL, } @@ -56,7 +56,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) { tool := NewWebFetchTool(50000) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "url": server.URL, } @@ -77,7 +77,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) { func TestWebTool_WebFetch_InvalidURL(t *testing.T) { tool := NewWebFetchTool(50000) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "url": "not-a-valid-url", } @@ -98,7 +98,7 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) { func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) { tool := NewWebFetchTool(50000) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "url": "ftp://example.com/file.txt", } @@ -119,7 +119,7 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) { func TestWebTool_WebFetch_MissingURL(t *testing.T) { tool := NewWebFetchTool(50000) ctx := context.Background() - args := map[string]interface{}{} + args := map[string]any{} result := tool.Execute(ctx, args) @@ -147,7 +147,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) { tool := NewWebFetchTool(1000) // Limit to 1000 chars ctx 
:= context.Background() - args := map[string]interface{}{ + args := map[string]any{ "url": server.URL, } @@ -159,7 +159,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) { } // ForUser should contain truncated content (not the full 20000 chars) - resultMap := make(map[string]interface{}) + resultMap := make(map[string]any) json.Unmarshal([]byte(result.ForUser), &resultMap) if text, ok := resultMap["text"].(string); ok { if len(text) > 1100 { // Allow some margin @@ -173,32 +173,25 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) { } } -// TestWebTool_WebSearch_NoApiKey verifies error handling when API key is missing +// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing func TestWebTool_WebSearch_NoApiKey(t *testing.T) { - tool := NewWebSearchTool("", 5) - ctx := context.Background() - args := map[string]interface{}{ - "query": "test", - } - - result := tool.Execute(ctx, args) - - // Should return error result - if !result.IsError { - t.Errorf("Expected error when API key is missing") + tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""}) + if tool != nil { + t.Errorf("Expected nil tool when Brave API key is empty") } - // Should mention missing API key - if !strings.Contains(result.ForLLM, "BRAVE_API_KEY") && !strings.Contains(result.ForUser, "BRAVE_API_KEY") { - t.Errorf("Expected API key error message, got ForLLM: %s", result.ForLLM) + // Also nil when nothing is enabled + tool = NewWebSearchTool(WebSearchToolOptions{}) + if tool != nil { + t.Errorf("Expected nil tool when no provider is enabled") } } // TestWebTool_WebSearch_MissingQuery verifies error handling for missing query func TestWebTool_WebSearch_MissingQuery(t *testing.T) { - tool := NewWebSearchTool("test-key", 5) + tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5}) ctx := context.Background() - args := map[string]interface{}{} + args := map[string]any{} result 
:= tool.Execute(ctx, args) @@ -213,13 +206,17 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") w.WriteHeader(http.StatusOK) - w.Write([]byte(`

Title

Content

`)) + w.Write( + []byte( + `

Title

Content

`, + ), + ) })) defer server.Close() tool := NewWebFetchTool(50000) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "url": server.URL, } @@ -241,11 +238,86 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) { } } +// TestWebFetchTool_extractText verifies text extraction preserves newlines +func TestWebFetchTool_extractText(t *testing.T) { + tool := &WebFetchTool{} + + tests := []struct { + name string + input string + wantFunc func(t *testing.T, got string) + }{ + { + name: "preserves newlines between block elements", + input: "

Title

\n

Paragraph 1

\n

Paragraph 2

", + wantFunc: func(t *testing.T, got string) { + lines := strings.Split(got, "\n") + if len(lines) < 2 { + t.Errorf("Expected multiple lines, got %d: %q", len(lines), got) + } + if !strings.Contains(got, "Title") || !strings.Contains(got, "Paragraph 1") || + !strings.Contains(got, "Paragraph 2") { + t.Errorf("Missing expected text: %q", got) + } + }, + }, + { + name: "removes script and style tags", + input: "

Keep this

", + wantFunc: func(t *testing.T, got string) { + if strings.Contains(got, "alert") || strings.Contains(got, "body{}") { + t.Errorf("Expected script/style content removed, got: %q", got) + } + if !strings.Contains(got, "Keep this") { + t.Errorf("Expected 'Keep this' to remain, got: %q", got) + } + }, + }, + { + name: "collapses excessive blank lines", + input: "

A

\n\n\n\n\n

B

", + wantFunc: func(t *testing.T, got string) { + if strings.Contains(got, "\n\n\n") { + t.Errorf("Expected excessive blank lines collapsed, got: %q", got) + } + }, + }, + { + name: "collapses horizontal whitespace", + input: "

hello world

", + wantFunc: func(t *testing.T, got string) { + if strings.Contains(got, " ") { + t.Errorf("Expected spaces collapsed, got: %q", got) + } + if !strings.Contains(got, "hello world") { + t.Errorf("Expected 'hello world', got: %q", got) + } + }, + }, + { + name: "empty input", + input: "", + wantFunc: func(t *testing.T, got string) { + if got != "" { + t.Errorf("Expected empty string, got: %q", got) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tool.extractText(tt.input) + tt.wantFunc(t, got) + }) + } +} + // TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain func TestWebTool_WebFetch_MissingDomain(t *testing.T) { tool := NewWebFetchTool(50000) ctx := context.Background() - args := map[string]interface{}{ + args := map[string]any{ "url": "https://", } diff --git a/pkg/utils/download.go b/pkg/utils/download.go new file mode 100644 index 000000000..5d9a13a30 --- /dev/null +++ b/pkg/utils/download.go @@ -0,0 +1,93 @@ +package utils + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// DownloadToFile streams an HTTP response body to a temporary file in small +// chunks (~32KB), keeping peak memory usage constant regardless of file size. +// +// Parameters: +// - ctx: context for cancellation/timeout +// - client: HTTP client to use (caller controls timeouts, transport, etc.) +// - req: fully prepared *http.Request (method, URL, headers, etc.) +// - maxBytes: maximum bytes to download; 0 means no limit +// +// Returns the path to the temporary file. The caller is responsible for +// removing it when done (defer os.Remove(path)). +// +// On any error the temp file is cleaned up automatically. +func DownloadToFile(ctx context.Context, client *http.Client, req *http.Request, maxBytes int64) (string, error) { + // Attach context. 
+ req = req.WithContext(ctx) + + logger.DebugCF("download", "Starting download", map[string]any{ + "url": req.URL.String(), + "max_bytes": maxBytes, + }) + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + // Read a small amount for the error message. + errBody := make([]byte, 512) + n, _ := io.ReadFull(resp.Body, errBody) + return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(errBody[:n])) + } + + // Create temp file. + tmpFile, err := os.CreateTemp("", "picoclaw-dl-*") + if err != nil { + return "", fmt.Errorf("failed to create temp file: %w", err) + } + tmpPath := tmpFile.Name() + + logger.DebugCF("download", "Streaming to temp file", map[string]any{ + "path": tmpPath, + }) + + // Cleanup helper — removes the temp file on any error. + cleanup := func() { + _ = tmpFile.Close() + _ = os.Remove(tmpPath) + } + + // Optionally limit the download size. 
+ var src io.Reader = resp.Body + if maxBytes > 0 { + src = io.LimitReader(resp.Body, maxBytes+1) // +1 to detect overflow + } + + written, err := io.Copy(tmpFile, src) + if err != nil { + cleanup() + return "", fmt.Errorf("download write failed: %w", err) + } + + if maxBytes > 0 && written > maxBytes { + cleanup() + return "", fmt.Errorf("download too large: %d bytes (max %d)", written, maxBytes) + } + + if err := tmpFile.Close(); err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Errorf("failed to close temp file: %w", err) + } + + logger.DebugCF("download", "Download complete", map[string]any{ + "path": tmpPath, + "bytes_written": written, + }) + + return tmpPath, nil +} diff --git a/pkg/utils/media.go b/pkg/utils/media.go index 6345da8fc..a34889fb8 100644 --- a/pkg/utils/media.go +++ b/pkg/utils/media.go @@ -9,6 +9,7 @@ import ( "time" "github.com/google/uuid" + "github.com/sipeed/picoclaw/pkg/logger" ) @@ -65,22 +66,21 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { } mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0700); err != nil { - logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]interface{}{ + if err := os.MkdirAll(mediaDir, 0o700); err != nil { + logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]any{ "error": err.Error(), }) return "" } // Generate unique filename with UUID prefix to prevent conflicts - ext := filepath.Ext(filename) safeName := SanitizeFilename(filename) - localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext) + localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName) // Create HTTP request req, err := http.NewRequest("GET", url, nil) if err != nil { - logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]interface{}{ + logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]any{ "error": err.Error(), }) 
return "" @@ -94,7 +94,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { client := &http.Client{Timeout: opts.Timeout} resp, err := client.Do(req) if err != nil { - logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]interface{}{ + logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]any{ "error": err.Error(), "url": url, }) @@ -103,7 +103,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]interface{}{ + logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]any{ "status": resp.StatusCode, "url": url, }) @@ -112,7 +112,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { out, err := os.Create(localPath) if err != nil { - logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]interface{}{ + logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]any{ "error": err.Error(), }) return "" @@ -122,13 +122,13 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { if _, err := io.Copy(out, resp.Body); err != nil { out.Close() os.Remove(localPath) - logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]interface{}{ + logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]any{ "error": err.Error(), }) return "" } - logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]interface{}{ + logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]any{ "path": localPath, }) diff --git a/pkg/utils/message.go b/pkg/utils/message.go new file mode 100644 index 000000000..1d05950d9 --- /dev/null +++ b/pkg/utils/message.go @@ -0,0 +1,179 @@ +package utils + +import ( + "strings" +) + +// SplitMessage splits long messages into chunks, preserving code block integrity. 
+// The function reserves a buffer (10% of maxLen, min 50) to leave room for closing code blocks, +// but may extend to maxLen when needed. +// Call SplitMessage with the full text content and the maximum allowed length of a single message; +// it returns a slice of message chunks that each respect maxLen and avoid splitting fenced code blocks. +func SplitMessage(content string, maxLen int) []string { + var messages []string + + // Dynamic buffer: 10% of maxLen, but at least 50 chars if possible + codeBlockBuffer := maxLen / 10 + if codeBlockBuffer < 50 { + codeBlockBuffer = 50 + } + if codeBlockBuffer > maxLen/2 { + codeBlockBuffer = maxLen / 2 + } + + for len(content) > 0 { + if len(content) <= maxLen { + messages = append(messages, content) + break + } + + // Effective split point: maxLen minus buffer, to leave room for code blocks + effectiveLimit := maxLen - codeBlockBuffer + if effectiveLimit < maxLen/2 { + effectiveLimit = maxLen / 2 + } + + // Find natural split point within the effective limit + msgEnd := findLastNewline(content[:effectiveLimit], 200) + if msgEnd <= 0 { + msgEnd = findLastSpace(content[:effectiveLimit], 100) + } + if msgEnd <= 0 { + msgEnd = effectiveLimit + } + + // Check if this would end with an incomplete code block + candidate := content[:msgEnd] + unclosedIdx := findLastUnclosedCodeBlock(candidate) + + if unclosedIdx >= 0 { + // Message would end with incomplete code block + // Try to extend up to maxLen to include the closing ``` + if len(content) > msgEnd { + closingIdx := findNextClosingCodeBlock(content, msgEnd) + if closingIdx > 0 && closingIdx <= maxLen { + // Extend to include the closing ``` + msgEnd = closingIdx + } else { + // Code block is too long to fit in one chunk or missing closing fence. + // Try to split inside by injecting closing and reopening fences. 
+ headerEnd := strings.Index(content[unclosedIdx:], "\n") + if headerEnd == -1 { + headerEnd = unclosedIdx + 3 + } else { + headerEnd += unclosedIdx + } + header := strings.TrimSpace(content[unclosedIdx:headerEnd]) + + // If we have a reasonable amount of content after the header, split inside + if msgEnd > headerEnd+20 { + // Find a better split point closer to maxLen + innerLimit := maxLen - 5 // Leave room for "\n```" + betterEnd := findLastNewline(content[:innerLimit], 200) + if betterEnd > headerEnd { + msgEnd = betterEnd + } else { + msgEnd = innerLimit + } + messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```") + content = strings.TrimSpace(header + "\n" + content[msgEnd:]) + continue + } + + // Otherwise, try to split before the code block starts + newEnd := findLastNewline(content[:unclosedIdx], 200) + if newEnd <= 0 { + newEnd = findLastSpace(content[:unclosedIdx], 100) + } + if newEnd > 0 { + msgEnd = newEnd + } else { + // If we can't split before, we MUST split inside (last resort) + if unclosedIdx > 20 { + msgEnd = unclosedIdx + } else { + msgEnd = maxLen - 5 + messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```") + content = strings.TrimSpace(header + "\n" + content[msgEnd:]) + continue + } + } + } + } + } + + if msgEnd <= 0 { + msgEnd = effectiveLimit + } + + messages = append(messages, content[:msgEnd]) + content = strings.TrimSpace(content[msgEnd:]) + } + + return messages +} + +// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ``` +// Returns the position of the opening ``` or -1 if all code blocks are complete +func findLastUnclosedCodeBlock(text string) int { + inCodeBlock := false + lastOpenIdx := -1 + + for i := 0; i < len(text); i++ { + if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' { + // Toggle code block state on each fence + if !inCodeBlock { + // Entering a code block: record this opening fence + lastOpenIdx = i + 
} + inCodeBlock = !inCodeBlock + i += 2 + } + } + + if inCodeBlock { + return lastOpenIdx + } + return -1 +} + +// findNextClosingCodeBlock finds the next closing ``` starting from a position +// Returns the position after the closing ``` or -1 if not found +func findNextClosingCodeBlock(text string, startIdx int) int { + for i := startIdx; i < len(text); i++ { + if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' { + return i + 3 + } + } + return -1 +} + +// findLastNewline finds the last newline character within the last N characters +// Returns the position of the newline or -1 if not found +func findLastNewline(s string, searchWindow int) int { + searchStart := len(s) - searchWindow + if searchStart < 0 { + searchStart = 0 + } + for i := len(s) - 1; i >= searchStart; i-- { + if s[i] == '\n' { + return i + } + } + return -1 +} + +// findLastSpace finds the last space character within the last N characters +// Returns the position of the space or -1 if not found +func findLastSpace(s string, searchWindow int) int { + searchStart := len(s) - searchWindow + if searchStart < 0 { + searchStart = 0 + } + for i := len(s) - 1; i >= searchStart; i-- { + if s[i] == ' ' || s[i] == '\t' { + return i + } + } + return -1 +} diff --git a/pkg/utils/message_test.go b/pkg/utils/message_test.go new file mode 100644 index 000000000..338509437 --- /dev/null +++ b/pkg/utils/message_test.go @@ -0,0 +1,151 @@ +package utils + +import ( + "strings" + "testing" +) + +func TestSplitMessage(t *testing.T) { + longText := strings.Repeat("a", 2500) + longCode := "```go\n" + strings.Repeat("fmt.Println(\"hello\")\n", 100) + "```" // ~2100 chars + + tests := []struct { + name string + content string + maxLen int + expectChunks int // Check number of chunks + checkContent func(t *testing.T, chunks []string) // Custom validation + }{ + { + name: "Empty message", + content: "", + maxLen: 2000, + expectChunks: 0, + }, + { + name: "Short message fits in one chunk", + 
content: "Hello world", + maxLen: 2000, + expectChunks: 1, + }, + { + name: "Simple split regular text", + content: longText, + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + if len(chunks[0]) > 2000 { + t.Errorf("Chunk 0 too large: %d", len(chunks[0])) + } + if len(chunks[0])+len(chunks[1]) != len(longText) { + t.Errorf("Total length mismatch. Got %d, want %d", len(chunks[0])+len(chunks[1]), len(longText)) + } + }, + }, + { + name: "Split at newline", + // 1750 chars then newline, then more chars. + // Dynamic buffer: 2000 / 10 = 200. + // Effective limit: 2000 - 200 = 1800. + // Split should happen at newline because it's at 1750 (< 1800). + // Total length must > 2000 to trigger split. 1750 + 1 + 300 = 2051. + content: strings.Repeat("a", 1750) + "\n" + strings.Repeat("b", 300), + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + if len(chunks[0]) != 1750 { + t.Errorf("Expected chunk 0 to be 1750 length (split at newline), got %d", len(chunks[0])) + } + if chunks[1] != strings.Repeat("b", 300) { + t.Errorf("Chunk 1 content mismatch. Len: %d", len(chunks[1])) + } + }, + }, + { + name: "Long code block split", + content: "Prefix\n" + longCode, + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + // Check that first chunk ends with closing fence + if !strings.HasSuffix(chunks[0], "\n```") { + t.Error("First chunk should end with injected closing fence") + } + // Check that second chunk starts with execution header + if !strings.HasPrefix(chunks[1], "```go") { + t.Error("Second chunk should start with injected code block header") + } + }, + }, + { + name: "Preserve Unicode characters", + content: strings.Repeat("\u4e16", 1000), // 3000 bytes + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + // Just verify we didn't panic and got valid strings. 
+ // Go strings are UTF-8, if we split mid-rune it would be bad, + // but standard slicing might do that. + // Let's assume standard behavior is acceptable or check if it produces invalid rune? + if !strings.Contains(chunks[0], "\u4e16") { + t.Error("Chunk should contain unicode characters") + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := SplitMessage(tc.content, tc.maxLen) + + if tc.expectChunks == 0 { + if len(got) != 0 { + t.Errorf("Expected 0 chunks, got %d", len(got)) + } + return + } + + if len(got) != tc.expectChunks { + t.Errorf("Expected %d chunks, got %d", tc.expectChunks, len(got)) + // Log sizes for debugging + for i, c := range got { + t.Logf("Chunk %d length: %d", i, len(c)) + } + return // Stop further checks if count assumes specific split + } + + if tc.checkContent != nil { + tc.checkContent(t, got) + } + }) + } +} + +func TestSplitMessage_CodeBlockIntegrity(t *testing.T) { + // Focused test for the core requirement: splitting inside a code block preserves syntax highlighting + + // 60 chars total approximately + content := "```go\npackage main\n\nfunc main() {\n\tprintln(\"Hello\")\n}\n```" + maxLen := 40 + + chunks := SplitMessage(content, maxLen) + + if len(chunks) != 2 { + t.Fatalf("Expected 2 chunks, got %d: %q", len(chunks), chunks) + } + + // First chunk must end with "\n```" + if !strings.HasSuffix(chunks[0], "\n```") { + t.Errorf("First chunk should end with closing fence. Got: %q", chunks[0]) + } + + // Second chunk must start with the header "```go" + if !strings.HasPrefix(chunks[1], "```go") { + t.Errorf("Second chunk should start with code block header. 
Got: %q", chunks[1]) + } + + // First chunk should contain meaningful content + if len(chunks[0]) > 40 { + t.Errorf("First chunk exceeded maxLen: length %d", len(chunks[0])) + } +} diff --git a/pkg/utils/skills.go b/pkg/utils/skills.go new file mode 100644 index 000000000..1d2cfac7f --- /dev/null +++ b/pkg/utils/skills.go @@ -0,0 +1,19 @@ +package utils + +import ( + "fmt" + "strings" +) + +// ValidateSkillIdentifier validates that the given skill identifier (slug or registry name) is non-empty +// and does not contain path separators ("/", "\\") or ".." for security. +func ValidateSkillIdentifier(identifier string) error { + trimmed := strings.TrimSpace(identifier) + if trimmed == "" { + return fmt.Errorf("identifier is required and must be a non-empty string") + } + if strings.ContainsAny(trimmed, "/\\") || strings.Contains(trimmed, "..") { + return fmt.Errorf("identifier must not contain path separators or '..' to prevent directory traversal") + } + return nil +} diff --git a/pkg/utils/string.go b/pkg/utils/string.go index 0d9837cb9..7a6aa37cc 100644 --- a/pkg/utils/string.go +++ b/pkg/utils/string.go @@ -14,3 +14,12 @@ func Truncate(s string, maxLen int) string { } return string(runes[:maxLen-3]) + "..." } + +// DerefStr dereferences a pointer to a string and +// returns the value or a fallback if the pointer is nil. +func DerefStr(s *string, fallback string) string { + if s == nil { + return fallback + } + return *s +} diff --git a/pkg/utils/zip.go b/pkg/utils/zip.go new file mode 100644 index 000000000..919ce5a20 --- /dev/null +++ b/pkg/utils/zip.go @@ -0,0 +1,121 @@ +package utils + +import ( + "archive/zip" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// ExtractZipFile extracts a ZIP archive from disk to targetDir. +// It reads entries one at a time from disk, keeping memory usage minimal. +// +// Security: rejects path traversal attempts and symlinks. 
+func ExtractZipFile(zipPath string, targetDir string) error { + reader, err := zip.OpenReader(zipPath) + if err != nil { + return fmt.Errorf("invalid ZIP: %w", err) + } + defer reader.Close() + + logger.DebugCF("zip", "Extracting ZIP", map[string]any{ + "zip_path": zipPath, + "target_dir": targetDir, + "entries": len(reader.File), + }) + + if err := os.MkdirAll(targetDir, 0o755); err != nil { + return fmt.Errorf("failed to create target dir: %w", err) + } + + for _, f := range reader.File { + // Path traversal protection. + cleanName := filepath.Clean(f.Name) + if strings.HasPrefix(cleanName, "..") || filepath.IsAbs(cleanName) { + return fmt.Errorf("zip entry has unsafe path: %q", f.Name) + } + + destPath := filepath.Join(targetDir, cleanName) + + // Double-check the resolved path is within target directory (defense-in-depth). + targetDirClean := filepath.Clean(targetDir) + if !strings.HasPrefix(filepath.Clean(destPath), targetDirClean+string(filepath.Separator)) && + filepath.Clean(destPath) != targetDirClean { + return fmt.Errorf("zip entry escapes target dir: %q", f.Name) + } + + mode := f.FileInfo().Mode() + + // Reject any symlink. + if mode&os.ModeSymlink != 0 { + return fmt.Errorf("zip contains symlink %q; symlinks are not allowed", f.Name) + } + + if f.FileInfo().IsDir() { + if err := os.MkdirAll(destPath, 0o755); err != nil { + return err + } + continue + } + + // Ensure parent directory exists. + if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { + return err + } + + if err := extractSingleFile(f, destPath); err != nil { + return err + } + } + + return nil +} + +// extractSingleFile extracts one zip.File entry to destPath, with a size check. +func extractSingleFile(f *zip.File, destPath string) error { + const maxFileSize = 5 * 1024 * 1024 // 5MB, adjust as appropriate + + // Check the uncompressed size from the header, if available. 
+ if f.UncompressedSize64 > maxFileSize { + return fmt.Errorf("zip entry %q is too large (%d bytes)", f.Name, f.UncompressedSize64) + } + + rc, err := f.Open() + if err != nil { + return fmt.Errorf("failed to open zip entry %q: %w", f.Name, err) + } + defer rc.Close() + + outFile, err := os.Create(destPath) + if err != nil { + return fmt.Errorf("failed to create file %q: %w", destPath, err) + } + // We don't return the close error via return, since it's not a named error return. + // Instead, we log to stderr and remove the partially written file as defensive cleanup. + defer func() { + if cerr := outFile.Close(); cerr != nil { + _ = os.Remove(destPath) + logger.ErrorCF("zip", "Failed to close file", map[string]any{ + "dest_path": destPath, + "error": cerr.Error(), + }) + } + }() + + // Streamed size check: prevent overruns and malicious/corrupt headers. + written, err := io.CopyN(outFile, rc, maxFileSize+1) + if err != nil && err != io.EOF { + _ = os.Remove(destPath) + return fmt.Errorf("failed to extract %q: %w", f.Name, err) + } + if written > maxFileSize { + _ = os.Remove(destPath) + return fmt.Errorf("zip entry %q exceeds max size (%d bytes)", f.Name, written) + } + + return nil +} diff --git a/pkg/voice/transcriber.go b/pkg/voice/transcriber.go index 9af2ea6bb..f973e77fe 100644 --- a/pkg/voice/transcriber.go +++ b/pkg/voice/transcriber.go @@ -29,7 +29,7 @@ type TranscriptionResponse struct { } func NewGroqTranscriber(apiKey string) *GroqTranscriber { - logger.DebugCF("voice", "Creating Groq transcriber", map[string]interface{}{"has_api_key": apiKey != ""}) + logger.DebugCF("voice", "Creating Groq transcriber", map[string]any{"has_api_key": apiKey != ""}) apiBase := "https://api.groq.com/openai/v1" return &GroqTranscriber{ @@ -42,22 +42,22 @@ func NewGroqTranscriber(apiKey string) *GroqTranscriber { } func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) { - logger.InfoCF("voice", "Starting 
transcription", map[string]interface{}{"audio_file": audioFilePath}) + logger.InfoCF("voice", "Starting transcription", map[string]any{"audio_file": audioFilePath}) audioFile, err := os.Open(audioFilePath) if err != nil { - logger.ErrorCF("voice", "Failed to open audio file", map[string]interface{}{"path": audioFilePath, "error": err}) + logger.ErrorCF("voice", "Failed to open audio file", map[string]any{"path": audioFilePath, "error": err}) return nil, fmt.Errorf("failed to open audio file: %w", err) } defer audioFile.Close() fileInfo, err := audioFile.Stat() if err != nil { - logger.ErrorCF("voice", "Failed to get file info", map[string]interface{}{"path": audioFilePath, "error": err}) + logger.ErrorCF("voice", "Failed to get file info", map[string]any{"path": audioFilePath, "error": err}) return nil, fmt.Errorf("failed to get file info: %w", err) } - logger.DebugCF("voice", "Audio file details", map[string]interface{}{ + logger.DebugCF("voice", "Audio file details", map[string]any{ "size_bytes": fileInfo.Size(), "file_name": filepath.Base(audioFilePath), }) @@ -67,44 +67,44 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) part, err := writer.CreateFormFile("file", filepath.Base(audioFilePath)) if err != nil { - logger.ErrorCF("voice", "Failed to create form file", map[string]interface{}{"error": err}) + logger.ErrorCF("voice", "Failed to create form file", map[string]any{"error": err}) return nil, fmt.Errorf("failed to create form file: %w", err) } copied, err := io.Copy(part, audioFile) if err != nil { - logger.ErrorCF("voice", "Failed to copy file content", map[string]interface{}{"error": err}) + logger.ErrorCF("voice", "Failed to copy file content", map[string]any{"error": err}) return nil, fmt.Errorf("failed to copy file content: %w", err) } - logger.DebugCF("voice", "File copied to request", map[string]interface{}{"bytes_copied": copied}) + logger.DebugCF("voice", "File copied to request", map[string]any{"bytes_copied": 
copied}) - if err := writer.WriteField("model", "whisper-large-v3"); err != nil { - logger.ErrorCF("voice", "Failed to write model field", map[string]interface{}{"error": err}) + if err = writer.WriteField("model", "whisper-large-v3"); err != nil { + logger.ErrorCF("voice", "Failed to write model field", map[string]any{"error": err}) return nil, fmt.Errorf("failed to write model field: %w", err) } - if err := writer.WriteField("response_format", "json"); err != nil { - logger.ErrorCF("voice", "Failed to write response_format field", map[string]interface{}{"error": err}) + if err = writer.WriteField("response_format", "json"); err != nil { + logger.ErrorCF("voice", "Failed to write response_format field", map[string]any{"error": err}) return nil, fmt.Errorf("failed to write response_format field: %w", err) } - if err := writer.Close(); err != nil { - logger.ErrorCF("voice", "Failed to close multipart writer", map[string]interface{}{"error": err}) + if err = writer.Close(); err != nil { + logger.ErrorCF("voice", "Failed to close multipart writer", map[string]any{"error": err}) return nil, fmt.Errorf("failed to close multipart writer: %w", err) } url := t.apiBase + "/audio/transcriptions" req, err := http.NewRequestWithContext(ctx, "POST", url, &requestBody) if err != nil { - logger.ErrorCF("voice", "Failed to create request", map[string]interface{}{"error": err}) + logger.ErrorCF("voice", "Failed to create request", map[string]any{"error": err}) return nil, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", writer.FormDataContentType()) req.Header.Set("Authorization", "Bearer "+t.apiKey) - logger.DebugCF("voice", "Sending transcription request to Groq API", map[string]interface{}{ + logger.DebugCF("voice", "Sending transcription request to Groq API", map[string]any{ "url": url, "request_size_bytes": requestBody.Len(), "file_size_bytes": fileInfo.Size(), @@ -112,37 +112,37 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, 
audioFilePath string) resp, err := t.httpClient.Do(req) if err != nil { - logger.ErrorCF("voice", "Failed to send request", map[string]interface{}{"error": err}) + logger.ErrorCF("voice", "Failed to send request", map[string]any{"error": err}) return nil, fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - logger.ErrorCF("voice", "Failed to read response", map[string]interface{}{"error": err}) + logger.ErrorCF("voice", "Failed to read response", map[string]any{"error": err}) return nil, fmt.Errorf("failed to read response: %w", err) } if resp.StatusCode != http.StatusOK { - logger.ErrorCF("voice", "API error", map[string]interface{}{ + logger.ErrorCF("voice", "API error", map[string]any{ "status_code": resp.StatusCode, "response": string(body), }) return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) } - logger.DebugCF("voice", "Received response from Groq API", map[string]interface{}{ + logger.DebugCF("voice", "Received response from Groq API", map[string]any{ "status_code": resp.StatusCode, "response_size_bytes": len(body), }) var result TranscriptionResponse if err := json.Unmarshal(body, &result); err != nil { - logger.ErrorCF("voice", "Failed to unmarshal response", map[string]interface{}{"error": err}) + logger.ErrorCF("voice", "Failed to unmarshal response", map[string]any{"error": err}) return nil, fmt.Errorf("failed to unmarshal response: %w", err) } - logger.InfoCF("voice", "Transcription completed successfully", map[string]interface{}{ + logger.InfoCF("voice", "Transcription completed successfully", map[string]any{ "text_length": len(result.Text), "language": result.Language, "duration_seconds": result.Duration, @@ -154,6 +154,6 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) func (t *GroqTranscriber) IsAvailable() bool { available := t.apiKey != "" - logger.DebugCF("voice", "Checking transcriber availability", 
map[string]interface{}{"available": available}) + logger.DebugCF("voice", "Checking transcriber availability", map[string]any{"available": available}) return available } diff --git a/tasks/prd-tool-result-refactor.md b/tasks/prd-tool-result-refactor.md deleted file mode 100644 index c0e984d53..000000000 --- a/tasks/prd-tool-result-refactor.md +++ /dev/null @@ -1,293 +0,0 @@ -# PRD: Tool 返回值结构化重构 - -## Introduction - -当前 picoclaw 的 Tool 接口返回 `(string, error)`,存在以下问题: - -1. **语义不明确**:返回的字符串是给 LLM 看还是给用户看,无法区分 -2. **字符串匹配黑魔法**:`isToolConfirmationMessage` 靠字符串包含判断是否发送给用户,容易误判 -3. **无法支持异步任务**:心跳触发长任务时会一直阻塞,影响定时器 -4. **状态保存不原子**:`SetLastChannel` 和 `Save` 分离,崩溃时状态不一致 - -本重构将 Tool 返回值改为结构化的 `ToolResult`,明确区分 `ForLLM`(给 AI 看)和 `ForUser`(给用户看),支持异步任务和回调通知,删除字符串匹配逻辑。 - -## Goals - -- Tool 返回结构化的 `ToolResult`,明确区分 LLM 内容和用户内容 -- 支持异步任务执行,心跳触发后不等待完成 -- 异步任务完成时通过回调通知系统 -- 删除 `isToolConfirmationMessage` 字符串匹配黑魔法 -- 状态保存原子化,防止数据不一致 -- 为所有改造添加完整测试覆盖 - -## User Stories - -### US-001: 新增 ToolResult 结构体和辅助函数 -**Description:** 作为开发者,我需要定义新的 ToolResult 结构体和辅助构造函数,以便工具可以明确表达返回结果的语义。 - -**Acceptance Criteria:** -- [ ] `ToolResult` 包含字段:ForLLM, ForUser, Silent, IsError, Async, Err -- [ ] 提供辅助函数:NewToolResult(), SilentResult(), AsyncResult(), ErrorResult(), UserResult() -- [ ] ToolResult 支持 JSON 序列化(除 Err 字段) -- [ ] 添加完整 godoc 注释 -- [ ] `go test ./pkg/tools -run TestToolResult` 通过 - -### US-002: 修改 Tool 接口返回值 -**Description:** 作为开发者,我需要将 Tool 接口的 Execute 方法返回值从 `(string, error)` 改为 `*ToolResult`,以便使用新的结构化返回值。 - -**Acceptance Criteria:** -- [ ] `pkg/tools/base.go` 中 `Tool.Execute()` 签名改为返回 `*ToolResult` -- [ ] 所有实现了 Tool 接口的类型更新方法签名 -- [ ] `go build ./...` 无编译错误 -- [ ] `go vet ./...` 通过 - -### US-003: 修改 ToolRegistry 处理 ToolResult -**Description:** 作为中间层,ToolRegistry 需要处理新的 ToolResult 返回值,并调整日志逻辑以反映异步任务状态。 - -**Acceptance Criteria:** -- [ ] `ExecuteWithContext()` 返回值改为 `*ToolResult` -- [ ] 日志区分:completed / async / failed 三种状态 -- [ ] 异步任务记录启动日志而非完成日志 -- [ ] 错误日志包含 ToolResult.Err 内容 -- [ ] 
`go test ./pkg/tools -run TestRegistry` 通过 - -### US-004: 删除 isToolConfirmationMessage 字符串匹配 -**Description:** 作为代码维护者,我需要删除 `isToolConfirmationMessage` 函数及相关调用,因为 ToolResult.Silent 字段已经解决了这个问题。 - -**Acceptance Criteria:** -- [ ] 删除 `pkg/agent/loop.go` 中的 `isToolConfirmationMessage` 函数 -- [ ] `runAgentLoop` 中移除对该函数的调用 -- [ ] 工具结果是否发送由 ToolResult.Silent 决定 -- [ ] `go build ./...` 无编译错误 - -### US-005: 修改 AgentLoop 工具结果处理逻辑 -**Description:** 作为 agent 主循环,我需要根据 ToolResult 的字段决定如何处理工具执行结果。 - -**Acceptance Criteria:** -- [ ] LLM 收到的消息内容来自 ToolResult.ForLLM -- [ ] 用户收到的消息优先使用 ToolResult.ForUser,其次使用 LLM 最终回复 -- [ ] ToolResult.Silent 为 true 时不发送用户消息 -- [ ] 记录最后执行的工具结果以便后续判断 -- [ ] `go test ./pkg/agent -run TestLoop` 通过 - -### US-006: 心跳支持异步任务执行 -**Description:** 作为心跳服务,我需要触发异步任务后立即返回,不等待任务完成,避免阻塞定时器。 - -**Acceptance Criteria:** -- [ ] `ExecuteHeartbeatWithTools` 检测 ToolResult.Async 标记 -- [ ] 异步任务返回 "Task started in background" 给 LLM -- [ ] 异步任务不阻塞心跳流程 -- [ ] 删除重复的 `ProcessHeartbeat` 函数 -- [ ] `go test ./pkg/heartbeat -run TestAsync` 通过 - -### US-007: 异步任务完成回调机制 -**Description:** 作为系统,我需要支持异步任务完成后的回调通知,以便任务结果能正确发送给用户。 - -**Acceptance Criteria:** -- [ ] 定义 AsyncCallback 函数类型:`func(ctx context.Context, result *ToolResult)` -- [ ] Tool 添加可选接口 `AsyncTool`,包含 `SetCallback(cb AsyncCallback)` -- [ ] 执行异步工具时注入回调函数 -- [ ] 工具内部 goroutine 完成后调用回调 -- [ ] 回调通过 SendToChannel 发送结果给用户 -- [ ] `go test ./pkg/tools -run TestAsyncCallback` 通过 - -### US-008: 状态保存原子化 -**Description:** 作为状态管理,我需要确保状态更新和保存是原子操作,防止程序崩溃时数据不一致。 - -**Acceptance Criteria:** -- [ ] `SetLastChannel` 合并保存逻辑,接受 workspace 参数 -- [ ] 使用临时文件 + rename 实现原子写入 -- [ ] rename 失败时清理临时文件 -- [ ] 更新时间戳在锁内完成 -- [ ] `go test ./pkg/state -run TestAtomicSave` 通过 - -### US-009: 改造 MessageTool -**Description:** 作为消息发送工具,我需要使用新的 ToolResult 返回值,发送成功后静默不通知用户。 - -**Acceptance Criteria:** -- [ ] 发送成功返回 `SilentResult("Message sent to ...")` -- [ ] 发送失败返回 `ErrorResult(...)` -- [ ] ForLLM 包含发送状态描述 -- [ ] ForUser 为空(用户已直接收到消息) -- [ ] `go test 
./pkg/tools -run TestMessageTool` 通过 - -### US-010: 改造 ShellTool -**Description:** 作为 shell 命令工具,我需要将命令结果发送给用户,失败时显示错误信息。 - -**Acceptance Criteria:** -- [ ] 成功返回包含 ForUser = 命令输出的 ToolResult -- [ ] 失败返回 IsError = true 的 ToolResult -- [ ] ForLLM 包含完整输出和退出码 -- [ ] `go test ./pkg/tools -run TestShellTool` 通过 - -### US-011: 改造 FilesystemTool -**Description:** 作为文件操作工具,我需要静默完成文件读写,不向用户发送确认消息。 - -**Acceptance Criteria:** -- [ ] 所有文件操作返回 `SilentResult(...)` -- [ ] 错误时返回 `ErrorResult(...)` -- [ ] ForLLM 包含操作摘要(如 "File updated: /path/to/file") -- [ ] `go test ./pkg/tools -run TestFilesystemTool` 通过 - -### US-012: 改造 WebTool -**Description:** 作为网络请求工具,我需要将抓取的内容发送给用户查看。 - -**Acceptance Criteria:** -- [ ] 成功时 ForUser 包含抓取的内容 -- [ ] ForLLM 包含内容摘要和字节数 -- [ ] 失败时返回 ErrorResult -- [ ] `go test ./pkg/tools -run TestWebTool` 通过 - -### US-013: 改造 EditTool -**Description:** 作为文件编辑工具,我需要静默完成编辑,避免重复内容发送给用户。 - -**Acceptance Criteria:** -- [ ] 编辑成功返回 `SilentResult("File edited: ...")` -- [ ] ForLLM 包含编辑摘要 -- [ ] `go test ./pkg/tools -run TestEditTool` 通过 - -### US-014: 改造 CronTool -**Description:** 作为定时任务工具,我需要静默完成 cron 操作,不发送确认消息。 - -**Acceptance Criteria:** -- [ ] 所有 cron 操作返回 `SilentResult(...)` -- [ ] ForLLM 包含操作摘要(如 "Cron job added: daily-backup") -- [ ] `go test ./pkg/tools -run TestCronTool` 通过 - -### US-015: 改造 SpawnTool -**Description:** 作为子代理生成工具,我需要标记为异步任务,并通过回调通知完成。 - -**Acceptance Criteria:** -- [ ] 实现 `AsyncTool` 接口 -- [ ] 返回 `AsyncResult("Subagent spawned, will report back")` -- [ ] 子代理完成时调用回调发送结果 -- [ ] `go test ./pkg/tools -run TestSpawnTool` 通过 - -### US-016: 改造 SubagentTool -**Description:** 作为子代理工具,我需要将子代理的执行摘要发送给用户。 - -**Acceptance Criteria:** -- [ ] ForUser 包含子代理的输出摘要 -- [ ] ForLLM 包含完整执行详情 -- [ ] `go test ./pkg/tools -run TestSubagentTool` 通过 - -### US-017: 心跳配置默认启用 -**Description:** 作为系统配置,心跳功能应该默认启用,因为这是核心功能。 - -**Acceptance Criteria:** -- [ ] `DefaultConfig()` 中 `Heartbeat.Enabled` 改为 `true` -- [ ] 可通过环境变量 `PICOCLAW_HEARTBEAT_ENABLED=false` 覆盖 -- [ ] 配置文档更新说明默认启用 
-- [ ] `go test ./pkg/config -run TestDefaultConfig` 通过 - -### US-018: 心跳日志写入 memory 目录 -**Description:** 作为心跳服务,日志应该写入 memory 目录以便被 LLM 访问和纳入知识系统。 - -**Acceptance Criteria:** -- [ ] 日志路径从 `workspace/heartbeat.log` 改为 `workspace/memory/heartbeat.log` -- [ ] 目录不存在时自动创建 -- [ ] 日志格式保持不变 -- [ ] `go test ./pkg/heartbeat -run TestLogPath` 通过 - -### US-019: 心跳调用 ExecuteHeartbeatWithTools -**Description:** 作为心跳服务,我需要调用支持异步的工具执行方法。 - -**Acceptance Criteria:** -- [ ] `executeHeartbeat` 调用 `handler.ExecuteHeartbeatWithTools(...)` -- [ ] 删除废弃的 `ProcessHeartbeat` 函数 -- [ ] `go build ./...` 无编译错误 - -### US-020: RecordLastChannel 调用原子化方法 -**Description:** 作为 AgentLoop,我需要调用新的原子化状态保存方法。 - -**Acceptance Criteria:** -- [ ] `RecordLastChannel` 调用 `st.SetLastChannel(al.workspace, lastChannel)` -- [ ] 传参包含 workspace 路径 -- [ ] `go test ./pkg/agent -run TestRecordLastChannel` 通过 - -## Functional Requirements - -- FR-1: ToolResult 结构体包含 ForLLM, ForUser, Silent, IsError, Async, Err 字段 -- FR-2: 提供 5 个辅助构造函数:NewToolResult, SilentResult, AsyncResult, ErrorResult, UserResult -- FR-3: Tool 接口 Execute 方法返回 `*ToolResult` -- FR-4: ToolRegistry 处理 ToolResult 并记录日志(区分 async/completed/failed) -- FR-5: AgentLoop 根据 ToolResult.Silent 决定是否发送用户消息 -- FR-6: 异步任务不阻塞心跳流程,返回 "Task started in background" -- FR-7: 工具可实现 AsyncTool 接口接收完成回调 -- FR-8: 状态保存使用临时文件 + rename 实现原子操作 -- FR-9: 心跳默认启用(Enabled: true) -- FR-10: 心跳日志写入 `workspace/memory/heartbeat.log` - -## Non-Goals (Out of Scope) - -- 不支持工具返回复杂对象(仅结构化文本) -- 不实现任务队列系统(异步任务由工具自己管理) -- 不支持异步任务超时取消 -- 不实现异步任务状态查询 API -- 不修改 LLMProvider 接口 -- 不支持嵌套异步任务 - -## Design Considerations - -### ToolResult 设计原则 -- **ForLLM**: 给 AI 看的内容,用于推理和决策 -- **ForUser**: 给用户看的内容,会通过 channel 发送 -- **Silent**: 为 true 时完全不发送用户消息 -- **Async**: 为 true 时任务在后台执行,立即返回 - -### 异步任务流程 -``` -心跳触发 → LLM 调用工具 → 工具返回 AsyncResult - ↓ - 工具启动 goroutine - ↓ - 任务完成 → 回调通知 → SendToChannel -``` - -### 原子写入实现 -```go -// 写入临时文件 -os.WriteFile(path + ".tmp", data, 0644) -// 原子重命名 -os.Rename(path + ".tmp", 
path) -``` - -## Technical Considerations - -- **破坏性变更**:所有工具实现需要同步修改,不支持向后兼容 -- **Go 版本**:需要 Go 1.21+(确保 atomic 操作支持) -- **测试覆盖**:每个改造的工具需要添加测试用例 -- **并发安全**:State 的原子操作需要正确使用锁 -- **回调设计**:AsyncTool 接口可选,不强制所有工具实现 - -### 回调函数签名 -```go -type AsyncCallback func(ctx context.Context, result *ToolResult) - -type AsyncTool interface { - Tool - SetCallback(cb AsyncCallback) -} -``` - -## Success Metrics - -- 删除 `isToolConfirmationMessage` 后无功能回归 -- 心跳可以触发长任务(如邮件检查)而不阻塞 -- 所有工具改造后测试覆盖率 > 80% -- 状态保存异常情况下无数据丢失 - -## Open Questions - -- [ ] 异步任务失败时如何通知用户?(通过回调发送错误消息) -- [ ] 异步任务是否需要超时机制?(暂不实现,由工具自己处理) -- [ ] 心跳日志是否需要 rotation?(暂不实现,使用外部 logrotate) - -## Implementation Order - -1. **基础设施**:ToolResult + Tool 接口 + Registry (US-001, US-002, US-003) -2. **消费者改造**:AgentLoop 工具结果处理 + 删除字符串匹配 (US-004, US-005) -3. **简单工具验证**:MessageTool 改造验证设计 (US-009) -4. **批量工具改造**:剩余所有工具 (US-010 ~ US-016) -5. **心跳和配置**:心跳异步支持 + 配置修改 (US-006, US-017, US-018, US-019) -6. **状态保存**:原子化保存 (US-008, US-020) diff --git a/workspace/AGENT.md b/workspace/AGENT.md new file mode 100644 index 000000000..5f5fa6480 --- /dev/null +++ b/workspace/AGENT.md @@ -0,0 +1,12 @@ +# Agent Instructions + +You are a helpful AI assistant. Be concise, accurate, and friendly. + +## Guidelines + +- Always explain what you're doing before taking actions +- Ask for clarification when request is ambiguous +- Use tools to help accomplish tasks +- Remember important information in your memory files +- Be proactive and helpful +- Learn from user feedback \ No newline at end of file diff --git a/workspace/IDENTITY.md b/workspace/IDENTITY.md new file mode 100644 index 000000000..dabb0e14b --- /dev/null +++ b/workspace/IDENTITY.md @@ -0,0 +1,56 @@ +# Identity + +## Name +PicoClaw 🦞 + +## Description +Ultra-lightweight personal AI assistant written in Go, inspired by nanobot. 
+ +## Version +0.1.0 + +## Purpose +- Provide intelligent AI assistance with minimal resource usage +- Support multiple LLM providers (OpenAI, Anthropic, Zhipu, etc.) +- Enable easy customization through skills system +- Run on minimal hardware ($10 boards, <10MB RAM) + +## Capabilities + +- Web search and content fetching +- File system operations (read, write, edit) +- Shell command execution +- Multi-channel messaging (Telegram, WhatsApp, Feishu) +- Skill-based extensibility +- Memory and context management + +## Philosophy + +- Simplicity over complexity +- Performance over features +- User control and privacy +- Transparent operation +- Community-driven development + +## Goals + +- Provide a fast, lightweight AI assistant +- Support offline-first operation where possible +- Enable easy customization and extension +- Maintain high quality responses +- Run efficiently on constrained hardware + +## License +MIT License - Free and open source + +## Repository +https://github.com/sipeed/picoclaw + +## Contact +Issues: https://github.com/sipeed/picoclaw/issues +Discussions: https://github.com/sipeed/picoclaw/discussions + +--- + +"Every bit helps, every bit matters." +- Picoclaw \ No newline at end of file diff --git a/workspace/SOUL.md b/workspace/SOUL.md new file mode 100644 index 000000000..0be8834f5 --- /dev/null +++ b/workspace/SOUL.md @@ -0,0 +1,17 @@ +# Soul + +I am picoclaw, a lightweight AI assistant powered by AI. + +## Personality + +- Helpful and friendly +- Concise and to the point +- Curious and eager to learn +- Honest and transparent + +## Values + +- Accuracy over speed +- User privacy and safety +- Transparency in actions +- Continuous improvement \ No newline at end of file diff --git a/workspace/USER.md b/workspace/USER.md new file mode 100644 index 000000000..91398a019 --- /dev/null +++ b/workspace/USER.md @@ -0,0 +1,21 @@ +# User + +Information about user goes here. 
+ +## Preferences + +- Communication style: (casual/formal) +- Timezone: (your timezone) +- Language: (your preferred language) + +## Personal Information + +- Name: (optional) +- Location: (optional) +- Occupation: (optional) + +## Learning Goals + +- What the user wants to learn from AI +- Preferred interaction style +- Areas of interest \ No newline at end of file diff --git a/workspace/memory/MEMORY.md b/workspace/memory/MEMORY.md new file mode 100644 index 000000000..265271db9 --- /dev/null +++ b/workspace/memory/MEMORY.md @@ -0,0 +1,21 @@ +# Long-term Memory + +This file stores important information that should persist across sessions. + +## User Information + +(Important facts about user) + +## Preferences + +(User preferences learned over time) + +## Important Notes + +(Things to remember) + +## Configuration + +- Model preferences +- Channel settings +- Skills enabled \ No newline at end of file diff --git a/skills/github/SKILL.md b/workspace/skills/github/SKILL.md similarity index 100% rename from skills/github/SKILL.md rename to workspace/skills/github/SKILL.md diff --git a/workspace/skills/hardware/SKILL.md b/workspace/skills/hardware/SKILL.md new file mode 100644 index 000000000..e89d1b6e7 --- /dev/null +++ b/workspace/skills/hardware/SKILL.md @@ -0,0 +1,64 @@ +--- +name: hardware +description: Read and control I2C and SPI peripherals on Sipeed boards (LicheeRV Nano, MaixCAM, NanoKVM). +homepage: https://wiki.sipeed.com/hardware/en/lichee/RV_Nano/1_intro.html +metadata: {"nanobot":{"emoji":"🔧","requires":{"tools":["i2c","spi"]}}} +--- + +# Hardware (I2C / SPI) + +Use the `i2c` and `spi` tools to interact with sensors, displays, and other peripherals connected to the board. + +## Quick Start + +``` +# 1. Find available buses +i2c detect + +# 2. Scan for connected devices +i2c scan (bus: "1") + +# 3. Read from a sensor (e.g. AHT20 temperature/humidity) +i2c read (bus: "1", address: 0x38, register: 0xAC, length: 6) + +# 4. 
SPI devices +spi list +spi read (device: "2.0", length: 4) +``` + +## Before You Start — Pinmux Setup + +Most I2C/SPI pins are shared with WiFi on Sipeed boards. You must configure pinmux before use. + +See `references/board-pinout.md` for board-specific commands. + +**Common steps:** +1. Stop WiFi if using shared pins: `/etc/init.d/S30wifi stop` +2. Load i2c-dev module: `modprobe i2c-dev` +3. Configure pinmux with `devmem` (board-specific) +4. Verify with `i2c detect` and `i2c scan` + +## Safety + +- **Write operations** require `confirm: true` — always confirm with the user first +- I2C addresses are validated to 7-bit range (0x03-0x77) +- SPI modes are validated (0-3 only) +- Maximum per-transaction: 256 bytes (I2C), 4096 bytes (SPI) + +## Common Devices + +See `references/common-devices.md` for register maps and usage of popular sensors: +AHT20, BME280, SSD1306 OLED, MPU6050 IMU, DS3231 RTC, INA219 power monitor, PCA9685 PWM, and more. + +## Troubleshooting + +| Problem | Solution | +|---------|----------| +| No I2C buses found | `modprobe i2c-dev` and check device tree | +| Permission denied | Run as root or add user to `i2c` group | +| No devices on scan | Check wiring, pull-up resistors (4.7k typical), and pinmux | +| Bus number changed | I2C adapter numbers can shift between boots; use `i2c detect` to find current assignment | +| WiFi stopped working | I2C-1/SPI-2 share pins with WiFi SDIO; can't use both simultaneously | +| `devmem` not found | Download separately or use `busybox devmem` | +| SPI transfer returns all zeros | Check MISO wiring and device power | +| SPI transfer returns all 0xFF | Device not responding; check CS pin and clock polarity (mode) | diff --git a/workspace/skills/hardware/references/board-pinout.md b/workspace/skills/hardware/references/board-pinout.md new file mode 100644 index 000000000..827dd0613 --- /dev/null +++ b/workspace/skills/hardware/references/board-pinout.md @@ -0,0 +1,131 @@ +# Board Pinout & Pinmux Reference + +## 
LicheeRV Nano (SG2002) + +### I2C Buses + +| Bus | Pins | Notes | +|-----|------|-------| +| I2C-1 | P18 (SCL), P21 (SDA) | **Shared with WiFi SDIO** — must stop WiFi first | +| I2C-3 | Available on header | Check device tree for pin assignment | +| I2C-5 | Software (BitBang) | Slower but no pin conflicts | + +### SPI Buses + +| Bus | Pins | Notes | +|-----|------|-------| +| SPI-2 | P18 (CS), P21 (MISO), P22 (MOSI), P23 (SCK) | **Shared with WiFi** — must stop WiFi first | +| SPI-4 | Software (BitBang) | Slower but no pin conflicts | + +### Setup Steps for I2C-1 + +```bash +# 1. Stop WiFi (shares pins with I2C-1) +/etc/init.d/S30wifi stop + +# 2. Configure pinmux for I2C-1 +devmem 0x030010D0 b 0x2 # P18 → I2C1_SCL +devmem 0x030010DC b 0x2 # P21 → I2C1_SDA + +# 3. Load i2c-dev module +modprobe i2c-dev + +# 4. Verify +ls /dev/i2c-* +``` + +### Setup Steps for SPI-2 + +```bash +# 1. Stop WiFi (shares pins with SPI-2) +/etc/init.d/S30wifi stop + +# 2. Configure pinmux for SPI-2 +devmem 0x030010D0 b 0x1 # P18 → SPI2_CS +devmem 0x030010DC b 0x1 # P21 → SPI2_MISO +devmem 0x030010E0 b 0x1 # P22 → SPI2_MOSI +devmem 0x030010E4 b 0x1 # P23 → SPI2_SCK + +# 3. 
Verify +ls /dev/spidev* +``` + +### Max Tested SPI Speed +- SPI-2 hardware: tested up to **93 MHz** +- `spidev_test` is pre-installed on the official image for loopback testing + +--- + +## MaixCAM + +### I2C Buses + +| Bus | Pins | Notes | +|-----|------|-------| +| I2C-1 | Overlaps with WiFi | Not recommended | +| I2C-3 | Overlaps with WiFi | Not recommended | +| I2C-5 | A15 (SCL), A27 (SDA) | **Recommended** — software I2C, no conflicts | + +### Setup Steps for I2C-5 + +```bash +# Configure pins using pinmap utility +# (MaixCAM uses a pinmap tool instead of devmem) +# Refer to: https://wiki.sipeed.com/hardware/en/maixcam/gpio.html + +# Load i2c-dev +modprobe i2c-dev + +# Verify +ls /dev/i2c-* +``` + +--- + +## MaixCAM2 + +### I2C Buses + +| Bus | Pins | Notes | +|-----|------|-------| +| I2C-6 | A1 (SCL), A0 (SDA) | Available on header | +| I2C-7 | Available | Check device tree | + +### Setup Steps + +```bash +# Configure pinmap for I2C-6 +# A1 → I2C6_SCL, A0 → I2C6_SDA +# Refer to MaixCAM2 documentation for pinmap commands + +modprobe i2c-dev +ls /dev/i2c-* +``` + +--- + +## NanoKVM + +Uses the same SG2002 SoC as LicheeRV Nano. GPIO and I2C access follows the same pinmux procedure. Refer to the LicheeRV Nano section above. + +Check NanoKVM-specific pin headers for available I2C/SPI lines: +- https://wiki.sipeed.com/hardware/en/kvm/NanoKVM/introduction.html + +--- + +## Common Issues + +### devmem not found +The `devmem` utility may not be in the default image. Options: +- Use `busybox devmem` if busybox is installed +- Download devmem from the Sipeed package repository +- Cross-compile from source (single C file) + +### Dynamic bus numbering +I2C adapter numbers can change between boots depending on driver load order. Always use `i2c detect` to find current bus assignments rather than hardcoding bus numbers. + +### Permissions +`/dev/i2c-*` and `/dev/spidev*` typically require root access. 
Options: +- Run picoclaw as root +- Add user to `i2c` and `spi` groups +- Create udev rules: `SUBSYSTEM=="i2c-dev", MODE="0666"` diff --git a/workspace/skills/hardware/references/common-devices.md b/workspace/skills/hardware/references/common-devices.md new file mode 100644 index 000000000..715e8ab7f --- /dev/null +++ b/workspace/skills/hardware/references/common-devices.md @@ -0,0 +1,78 @@ +# Common I2C/SPI Device Reference + +## I2C Devices + +### AHT20 — Temperature & Humidity +- **Address:** 0x38 +- **Init:** Write `[0xBE, 0x08, 0x00]` then wait 10ms +- **Measure:** Write `[0xAC, 0x33, 0x00]`, wait 80ms, read 6 bytes +- **Parse:** Status=byte[0], Humidity=(byte[1]<<12|byte[2]<<4|byte[3]>>4)/2^20*100, Temp=(byte[3]&0x0F<<16|byte[4]<<8|byte[5])/2^20*200-50 +- **Notes:** No register addressing — write command bytes directly (omit `register` param) + +### BME280 / BMP280 — Temperature, Humidity, Pressure +- **Address:** 0x76 or 0x77 (SDO pin selects) +- **Chip ID register:** 0xD0 → BMP280=0x58, BME280=0x60 +- **Data registers:** 0xF7-0xFE (pressure, temperature, humidity) +- **Config:** Write 0xF2 (humidity oversampling), 0xF4 (temp/press oversampling + mode), 0xF5 (standby, filter) +- **Forced measurement:** Write `[0x25]` to register 0xF4, wait 40ms, read 8 bytes from 0xF7 +- **Calibration:** Read 26 bytes from 0x88 and 7 bytes from 0xE1 for compensation formulas +- **Also available via SPI** (mode 0 or 3) + +### SSD1306 — 128x64 OLED Display +- **Address:** 0x3C (or 0x3D if SA0 high) +- **Command prefix:** 0x00 (write to register 0x00) +- **Data prefix:** 0x40 (write to register 0x40) +- **Init sequence:** `[0xAE, 0xD5, 0x80, 0xA8, 0x3F, 0xD3, 0x00, 0x40, 0x8D, 0x14, 0x20, 0x00, 0xA1, 0xC8, 0xDA, 0x12, 0x81, 0xCF, 0xD9, 0xF1, 0xDB, 0x40, 0xA4, 0xA6, 0xAF]` +- **Display on:** 0xAF, **Display off:** 0xAE +- **Also available via SPI** (faster, recommended for animations) + +### MPU6050 — 6-axis Accelerometer + Gyroscope +- **Address:** 0x68 (or 0x69 if AD0 high) +- 
**WHO_AM_I:** Register 0x75 → should return 0x68 +- **Wake up:** Write `[0x00]` to register 0x6B (clear sleep bit) +- **Read accel:** 6 bytes from register 0x3B (XH,XL,YH,YL,ZH,ZL) — signed 16-bit, default ±2g +- **Read gyro:** 6 bytes from register 0x43 — signed 16-bit, default ±250°/s +- **Read temp:** 2 bytes from register 0x41 — Temp°C = value/340 + 36.53 + +### DS3231 — Real-Time Clock +- **Address:** 0x68 +- **Read time:** 7 bytes from register 0x00 (seconds, minutes, hours, day, date, month, year) — BCD encoded +- **Set time:** Write 7 BCD bytes to register 0x00 +- **Temperature:** 2 bytes from register 0x11 (signed, 0.25°C resolution) +- **Status:** Register 0x0F — bit 2 = busy, bit 0 = alarm 1 flag + +### INA219 — Current & Power Monitor +- **Address:** 0x40-0x4F (A0,A1 pin selectable) +- **Config:** Register 0x00 — set voltage range, gain, ADC resolution +- **Shunt voltage:** Register 0x01 (signed 16-bit, LSB=10µV) +- **Bus voltage:** Register 0x02 (bits 15:3, LSB=4mV) +- **Power:** Register 0x03 (after calibration) +- **Current:** Register 0x04 (after calibration) +- **Calibration:** Register 0x05 — set based on shunt resistor value + +### PCA9685 — 16-Channel PWM / Servo Controller +- **Address:** 0x40-0x7F (A0-A5 selectable, default 0x40) +- **Mode 1:** Register 0x00 — bit 4=sleep, bit 5=auto-increment +- **Set PWM freq:** Sleep → write prescale to 0xFE → wake. Prescale = round(25MHz / (4096 × freq)) - 1 +- **Channel N on/off:** Registers 0x06+4*N to 0x09+4*N (ON_L, ON_H, OFF_L, OFF_H) +- **Servo 0°-180°:** ON=0, OFF=150-600 (at 50Hz). Typical: 0°=150, 90°=375, 180°=600 + +### AT24C256 — 256Kbit EEPROM +- **Address:** 0x50-0x57 (A0-A2 selectable) +- **Read:** Write 2-byte address (high, low), then read N bytes +- **Write:** Write 2-byte address + up to 64 bytes (page write), wait 5ms for write cycle +- **Page size:** 64 bytes. Writes that cross page boundary wrap around. 
+ +## SPI Devices + +### MCP3008 — 8-Channel 10-bit ADC +- **Interface:** SPI mode 0, max 3.6 MHz @ 5V +- **Read channel N:** Send `[0x01, (0x80 | N<<4), 0x00]`, result in last 10 bits of bytes 1-2 +- **Formula:** value = ((byte[1] & 0x03) << 8) | byte[2] +- **Voltage:** value × Vref / 1024 + +### W25Q128 — 128Mbit SPI Flash +- **Interface:** SPI mode 0 or 3, up to 104 MHz +- **Read ID:** Send `[0x9F, 0, 0, 0]` → manufacturer + device ID +- **Read data:** Send `[0x03, addr_high, addr_mid, addr_low]` + N zero bytes +- **Status:** Send `[0x05, 0]` → bit 0 = BUSY diff --git a/skills/skill-creator/SKILL.md b/workspace/skills/skill-creator/SKILL.md similarity index 100% rename from skills/skill-creator/SKILL.md rename to workspace/skills/skill-creator/SKILL.md diff --git a/skills/summarize/SKILL.md b/workspace/skills/summarize/SKILL.md similarity index 100% rename from skills/summarize/SKILL.md rename to workspace/skills/summarize/SKILL.md diff --git a/skills/tmux/SKILL.md b/workspace/skills/tmux/SKILL.md similarity index 100% rename from skills/tmux/SKILL.md rename to workspace/skills/tmux/SKILL.md diff --git a/skills/tmux/scripts/find-sessions.sh b/workspace/skills/tmux/scripts/find-sessions.sh similarity index 100% rename from skills/tmux/scripts/find-sessions.sh rename to workspace/skills/tmux/scripts/find-sessions.sh diff --git a/skills/tmux/scripts/wait-for-text.sh b/workspace/skills/tmux/scripts/wait-for-text.sh similarity index 100% rename from skills/tmux/scripts/wait-for-text.sh rename to workspace/skills/tmux/scripts/wait-for-text.sh diff --git a/skills/weather/SKILL.md b/workspace/skills/weather/SKILL.md similarity index 100% rename from skills/weather/SKILL.md rename to workspace/skills/weather/SKILL.md