diff --git a/.gitignore b/.gitignore index 3ff195fbf..c60b6ba8a 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,7 @@ tasks/ # Added by goreleaser init: dist/ + +# PM2 and process manager configs (dev-specific) +ecosystem.config.js +.pm2/ diff --git a/docs/swarm-architecture.md b/docs/swarm-architecture.md new file mode 100644 index 000000000..5b19448ec --- /dev/null +++ b/docs/swarm-architecture.md @@ -0,0 +1,450 @@ +# PicoClaw Swarm Mode Architecture + +## Overview + +PicoClaw Swarm Mode enables multiple PicoClaw instances to work together as a distributed system, providing: +- **Node Discovery**: Automatic peer discovery via UDP gossip protocol +- **Health Monitoring**: Periodic heartbeat and failure detection +- **Load Balancing**: Intelligent task distribution based on node load +- **Handoff Mechanism**: Dynamic task delegation between nodes + +## Architecture + +The swarm architecture is divided into two distinct planes: + +``` +┌──────────────────────────────────────────────────────────────���──┐ +│ PicoClaw Swarm │ +├─────────────────────────────────────────────────────────────────┤ +│ Control Plane │ Data Plane │ +│ ├─ Node Discovery │ ├─ Task Execution │ +│ ├─ Membership Management │ ├─ Session Transfer │ +│ ├─ Health Monitoring │ └─ Message Routing │ +│ └─ Load Monitoring │ │ +└─────────────────────────────────────────────────────────────────┘ +``` + +## Control Plane + +The control plane manages cluster state, node membership, and coordination. + +### 1. 
Node Discovery + +Nodes discover each other using a lightweight UDP gossip protocol: + +```mermaid +sequenceDiagram + participant Node1 + participant Node2 + participant Node3 + + Note over Node1: New node starts + Node1->>Node1: Bind UDP port (7946) + Node1->>Node2: Ping + NodeInfo + Node2->>Node1: Pong + NodeInfo + Node1->>Node2: Gossip: Known Nodes + Node2->>Node3: Forward Node1 info + Node3->>Node2: Ack + Note over Node1,Node3: Cluster formed +``` + +**Gossip Protocol Flow:** + +```mermaid +graph LR + A[Node A] -->|Ping| B[Node B] + B -->|Pong| A + A -->|Sync| B + B -->|Forward| C[Node C] + C -->|Ack| B + B -->|Update| A + A -.->|Eventually| C +``` + +**Key Parameters:** + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `gossip_interval` | 1s | Frequency of gossip messages | +| `push_pull_interval` | 30s | Full state sync interval | +| `node_timeout` | 5s | Time before marking node suspect | +| `dead_node_timeout` | 30s | Time before removing dead node | + +### 2. Membership Management + +Each node maintains a view of the cluster: + +```go +type ClusterView struct { + sync.RWMutex + localNode *NodeInfo + members map[string]*NodeInfo // node_id -> NodeInfo + stateMap map[string]NodeState // node_id -> State +} +``` + +**Node State Machine:** + +```mermaid +stateDiagram-v2 + [*] --> Alive: Node joins + Alive --> Suspect: Missed heartbeat + Suspect --> Alive: Heartbeat recovered + Suspect --> Dead: Timeout exceeded + Dead --> [*] +``` + +**Node Information:** + +```go +type NodeInfo struct { + ID string // Unique node identifier + Addr string // IP address + Port int // Discovery port + AgentCaps map[string]string // Capabilities (models, tools) + LoadScore float64 // Current load (0.0-1.0) + Labels map[string]string // Custom labels + Timestamp int64 // Last update time + Version string // Protocol version +} +``` + +### 3. 
Health Monitoring + +**Heartbeat Flow:** + +```mermaid +sequenceDiagram + participant N1 as Node 1 + participant N2 as Node 2 + participant HM as Health Monitor + + loop Every gossip_interval + N1->>N2: Heartbeat (timestamp, load_score) + N2->>HM: Update state + HM->>HM: Check timeout + alt Timeout exceeded + HM->>HM: Mark as Suspect + HM->>N1: Probe (are you alive?) + alt No response + HM->>HM: Mark as Dead + HM->>All: Broadcast NodeLeft event + end + end + end +``` + +### 4. Load Monitoring + +Each node continuously monitors its resource usage: + +```mermaid +graph TB + subgraph Load Monitor + A[CPU Sample] --> D[Score Calculator] + B[Memory Sample] --> D + C[Session Count] --> D + D --> E[Load Score] + end + + subgraph Weights + A -.->|0.3| D + B -.->|0.3| D + C -.->|0.4| D + end + + E --> F{Threshold Check} + F -->|< 0.8| G[Normal Mode] + F -->|>= 0.8| H[Overloaded - Trigger Handoff] +``` + +**Load Score Formula:** + +``` +LoadScore = (CPUUsage × cpu_weight) + + (MemoryUsage × memory_weight) + + (SessionRatio × session_weight) + +Where: +- CPUUsage = current CPU usage (0.0-1.0) +- MemoryUsage = current memory usage (0.0-1.0) +- SessionRatio = current_sessions / max_sessions +- Default weights: cpu=0.3, memory=0.3, session=0.4 +``` + +## Data Plane + +The data plane handles actual task execution and session state transfer. + +### 1. Request Flow + +```mermaid +sequenceDiagram + participant User + participant LB as Entry Point + participant N1 as Node 1 + participant N2 as Node 2 + + User->>LB: Message + LB->>LB: Check node availability + + alt Node 1 available + LB->>N1: Forward message + N1->>N1: Process with LLM + N1->>User: Response + else Node 1 overloaded + LB->>N2: Handoff request + N2->>N1: Session transfer + N1->>N2: Session state + N2->>User: Response (from N2) + end +``` + +### 2. 
Handoff Mechanism + +**Handoff Decision Flow:** + +```mermaid +flowchart TD + A[Receive Request] --> B{Should Handoff?} + B -->|Local load >= threshold| C[Select Target Node] + B -->|Local load < threshold| D[Process Locally] + + C --> E{Target Available?} + E -->|Yes| F[Initiate Handoff] + E -->|No| G[Retry or Fail] + + F --> H[Serialize Session] + H --> I[Send to Target] + I --> J{Success?} + J -->|Yes| K[Update Routing] + J -->|No| L[Rollback] + + K --> M[Target Processes] + M --> N[Return Response] +``` + +**Handoff Protocol:** + +```mermaid +sequenceDiagram + participant Source as Overloaded Node + participant Target as Selected Node + participant Client + + Source->>Source: Check load threshold + Source->>Target: HandoffRequest{session_id, context} + + Target->>Target: Validate request + alt Accepted + Target->>Source: HandoffAccept + Source->>Target: SessionTransfer{messages, tools, state} + Target->>Target: Restore session + Target->>Source: TransferComplete + Source->>Client: Redirect to Target + Client->>Target: Continue conversation + else Rejected + Target->>Source: HandoffReject{reason} + Source->>Source: Try next node or process locally + end +``` + +### 3. Session Transfer + +**Session State Structure:** + +```go +type SessionState struct { + SessionID string + Messages []Message // Conversation history + Context map[string]any // Shared context + Tools []ToolCall // Pending tool calls + Metadata SessionMeta // Timestamp, user info, etc. 
+} +``` + +**Transfer Flow:** + +```mermaid +stateDiagram-v2 + [*] --> Active: Session created + Active --> Transferring: Handoff initiated + Transferring --> Active: Transfer failed + Transferring --> Migrated: Transfer complete + Migrated --> [*]: Session closed + Active --> [*]: Session closed +``` + +## System Architecture + +### Component Overview + +```mermaid +graph TB + subgraph "Node 1" + D1[Discovery Service] --> M1[Membership Manager] + H1[Handoff Coordinator] --> M1 + L1[Load Monitor] --> H1 + A1[Agent Loop] --> L1 + end + + subgraph "Node 2" + D2[Discovery Service] --> M2[Membership Manager] + H2[Handoff Coordinator] --> M2 + L2[Load Monitor] --> H2 + A2[Agent Loop] --> L2 + end + + D1 <-- UDP Gossip --> D2 + D2 <-- UDP Gossip --> D1 + H1 <-- RPC Handoff --> H2 + H2 <-- RPC Handoff --> H1 + + TG[Telegram Gateway] --> A1 + TG --> A2 +``` + +### Communication Channels + +| Channel | Protocol | Purpose | +|---------|----------|---------| +| Discovery | UDP | Node gossip, heartbeat | +| Handoff RPC | UDP | Session transfer coordination | +| Session Data | UDP | Serialized session state | + +## Configuration + +### Example Configuration + +```json +{ + "swarm": { + "enabled": true, + "node_id": "picoclaw-node-1", + "bind_addr": "127.0.0.1", + "bind_port": 7946, + + "discovery": { + "join_addrs": ["127.0.0.1:7946"], + "gossip_interval": 1, + "push_pull_interval": 30, + "node_timeout": 5, + "dead_node_timeout": 30 + }, + + "handoff": { + "enabled": true, + "load_threshold": 0.8, + "timeout": 30, + "max_retries": 3, + "retry_delay": 5 + }, + + "rpc": { + "port": 7947, + "timeout": 10 + }, + + "load_monitor": { + "enabled": true, + "interval": 5, + "sample_size": 60, + "cpu_weight": 0.3, + "memory_weight": 0.3, + "session_weight": 0.4 + } + } +} +``` + +### Deployment Modes + +**Single Entry Point:** + +```mermaid +graph LR + TG[Telegram Gateway] --> N1[Node 1: Coordinator] + N1 <-- Swarm --> N2[Node 2: Worker] + N2 <-- Swarm --> N1 + N1 <-- Swarm --> N3[Node 
3: Worker] + N3 <-- Swarm --> N1 +``` + +**Multi-Entry Point (with load balancer):** + +```mermaid +graph LR + LB[Load Balancer] --> N1[Node 1] + LB --> N2[Node 2] + N1 <-- Swarm Mesh --> N2 + N2 <-- Swarm Mesh --> N1 + N1 <-- Swarm Mesh --> N3[Node 3] + N3 <-- Swarm Mesh --> N1 + N2 <-- Swarm Mesh --> N3 + N3 <-- Swarm Mesh --> N2 +``` + +## Event System + +The swarm publishes events for monitoring and integration: + +```mermaid +graph LR + A[Node Joined] --> ED[Event Dispatcher] + B[Node Left] --> ED + C[Node Suspect] --> ED + D[Handoff Started] --> ED + E[Handoff Completed] --> ED + + ED --> H[Handlers] + H --> L[Logging] + H --> M[Metrics] + H --> C[Custom Actions] +``` + +**Event Types:** + +| Event | Description | Payload | +|-------|-------------|---------| +| `NodeJoined` | New node discovered | NodeInfo | +| `NodeLeft` | Node removed | NodeID | +| `NodeSuspect` | Node marked suspect | NodeID | +| `NodeAlive` | Node recovered | NodeInfo | +| `HandoffStarted` | Handoff initiated | HandoffOperation | +| `HandoffCompleted` | Handoff finished | HandoffResult | +| `HandoffFailed` | Handoff error | Error | + +## Error Handling + +```mermaid +flowchart TD + A[Operation Failed] --> B{Retryable?} + B -->|Yes| C[Increment retry count] + B -->|No| D[Return error] + + C --> E{Max retries reached?} + E -->|No| F[Wait retry_delay] + F --> G[Retry operation] + + E -->|Yes| H[Mark node suspect] + H --> I[Select alternative node] + + G --> J{Success?} + J -->|Yes| K[Continue] + J -->|No| C +``` + +## Security Considerations + +1. **Discovery**: UDP gossip is unencrypted - use in trusted networks only +2. **Handoff**: Session data transferred without encryption +3. **Authentication**: No node authentication implemented +4. **Recommendation**: Use VPN or private network for production + +## Future Enhancements + +1. **Secure Discovery**: Add mTLS for node communication +2. **Consistent Hashing**: Replace random selection with consistent hashing +3. 
**Session Affinity**: Sticky sessions for better performance +4. **Leader Election**: Automatic coordinator election +5. **Multi-Region**: Geo-distributed cluster support diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 29827d0b2..4671d9c18 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -23,10 +23,12 @@ import ( "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" + picolib "github.com/sipeed/picoclaw/pkg/pico" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/skills" "github.com/sipeed/picoclaw/pkg/state" + "github.com/sipeed/picoclaw/pkg/swarm" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -41,6 +43,18 @@ type AgentLoop struct { fallback *providers.FallbackChain channelManager *channels.Manager mediaStore media.MediaStore + + // Swarm mode support + swarmEnabled bool + swarmInitError error // Set if swarm initialization fails + swarmDiscovery *swarm.DiscoveryService + swarmHandoff *swarm.HandoffCoordinator + swarmLoad *swarm.LoadMonitor + swarmPicoClient *picolib.PicoNodeClient + swarmLeaderElection *swarm.LeaderElection + + // Dynamic command registry (wired from channel manager) + commandRegistry *channels.CommandRegistry } // processOptions configures how a message is processed @@ -74,7 +88,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers stateManager = state.NewManager(defaultAgent.Workspace) } - return &AgentLoop{ + al := &AgentLoop{ bus: msgBus, cfg: cfg, registry: registry, @@ -82,6 +96,13 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers summarizing: sync.Map{}, fallback: fallbackChain, } + + // Initialize swarm mode if enabled + if cfg.Swarm.Enabled { + al.initSwarm() + } + + return al } // registerSharedTools registers tools that are shared across all agents (web, message, spawn). 
@@ -234,6 +255,12 @@ func (al *AgentLoop) Run(ctx context.Context) error { func (al *AgentLoop) Stop() { al.running.Store(false) + // Gracefully shutdown swarm mode components if enabled + if al.swarmEnabled { + if err := al.ShutdownSwarm(); err != nil { + logger.WarnCF("swarm", "Error during swarm shutdown", map[string]any{"error": err}) + } + } } func (al *AgentLoop) RegisterTool(tool tools.Tool) { @@ -246,6 +273,13 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) { func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { al.channelManager = cm + al.commandRegistry = cm.CommandRegistry() + + // Wire up swarm node request handler on the Pico channel + if al.swarmEnabled { + al.registerPicoNodeHandler() + al.registerSwarmCommands() + } } // SetMediaStore injects a MediaStore for media lifecycle management. @@ -360,6 +394,25 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return al.processSystemMessage(ctx, msg) } + // Check for @node-id routing syntax + if al.swarmEnabled { + if targetNodeID, content := tools.ParseNodeMention(msg.Content); targetNodeID != "" { + logger.InfoCF("swarm", "Node mention detected", map[string]any{ + "target": targetNodeID, + "content": utils.Truncate(content, 50), + }) + return al.handleNodeRouting(ctx, msg, targetNodeID, content) + } + + // For non-mentioned messages in swarm mode, only leader processes + if al.swarmLeaderElection != nil && !al.swarmLeaderElection.IsLeader() { + logger.DebugCF("swarm", "Ignoring message (not leader)", map[string]any{ + "content": utils.Truncate(msg.Content, 50), + }) + return "", nil + } + } + // Check for commands if response, handled := al.handleCommand(ctx, msg); handled { return response, nil @@ -403,6 +456,22 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) "matched_by": route.MatchedBy, }) + // Check if we should handoff this request to another node + if al.swarmEnabled && al.shouldHandoff(agent, processOptions{ + 
SessionKey: sessionKey, + Channel: msg.Channel, + ChatID: msg.ChatID, + UserMessage: msg.Content, + }) { + handoffResp, err := al.initiateSwarmHandoff(ctx, agent, sessionKey, msg) + if err == nil && handoffResp != nil && handoffResp.Accepted { + // Handoff was successful, return the response + return fmt.Sprintf("Your request has been handed off to node %s for processing.", handoffResp.NodeID), nil + } + // If handoff failed, continue processing locally + logger.WarnCF("swarm", "Handoff failed, processing locally", map[string]any{"error": err}) + } + return al.runAgentLoop(ctx, agent, processOptions{ SessionKey: sessionKey, Channel: msg.Channel, @@ -475,6 +544,10 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe // runAgentLoop is the core message processing logic. func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) { + // Track active session for swarm load monitoring + al.IncrementSwarmSessions() + defer al.DecrementSwarmSessions() + // 0. Record last channel for heartbeat notifications (skip internal channels) if opts.Channel != "" && opts.ChatID != "" { // Don't record internal channels (cli, system, subagent) @@ -907,6 +980,15 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c go func() { defer al.summarizing.Delete(summarizeKey) logger.Debug("Memory threshold reached. Optimizing conversation history...") + if !constants.IsInternalChannel(channel) { + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: "Memory threshold reached. 
Optimizing conversation history...", + }) + } al.summarizeSession(agent, sessionKey) }() } @@ -1263,6 +1345,19 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) } } + // Check dynamic command registry + if al.commandRegistry != nil { + cmdName := strings.TrimPrefix(parts[0], "/") + if entry, ok := al.commandRegistry.Get(cmdName); ok { + argsStr := strings.Join(args, " ") + response, err := entry.Handler(ctx, argsStr, msg) + if err != nil { + return fmt.Sprintf("Command error: %v", err), true + } + return response, true + } + } + return "", false } @@ -1291,3 +1386,11 @@ func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer { } return &routing.RoutePeer{Kind: parentKind, ID: parentID} } + +// min returns the minimum of two integers. +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/pkg/agent/loop_swarm.go b/pkg/agent/loop_swarm.go new file mode 100644 index 000000000..1c0fd71f6 --- /dev/null +++ b/pkg/agent/loop_swarm.go @@ -0,0 +1,704 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Swarm mode support for multi-agent coordination +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package agent + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels/pico" + "github.com/sipeed/picoclaw/pkg/config" + picolib "github.com/sipeed/picoclaw/pkg/pico" + "github.com/sipeed/picoclaw/pkg/swarm" + "github.com/sipeed/picoclaw/pkg/tools" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// registerPicoNodeHandler registers the incoming node message handler on the Pico channel. 
+func (al *AgentLoop) registerPicoNodeHandler() { + if al.channelManager == nil { + return + } + ch, ok := al.channelManager.GetChannel("pico") + if !ok { + logger.WarnC("swarm", "Pico channel not available, inter-node request handler not registered") + return + } + picoCh, ok := ch.(*pico.PicoChannel) + if !ok { + logger.WarnC("swarm", "Pico channel is not a PicoChannel type, inter-node request handler not registered") + return + } + picoCh.SetNodeRequestHandler(func(payload map[string]any) (map[string]any, error) { + p := picolib.NodePayload(payload) + + switch p.Action() { + case picolib.NodeActionHandoffRequest: + return al.handleIncomingHandoffRequest(p) + case picolib.NodeActionMessage, "": + // Default action: direct message (backward compatible) + response, err := al.handleIncomingNodeMessage(&picolib.DirectMessage{ + MessageID: p.RequestID(), + SourceNodeID: p.SourceNodeID(), + Content: p.Content(), + Channel: p.Channel(), + ChatID: p.ChatID(), + SenderID: p.SenderID(), + Metadata: p.Metadata(), + }) + if err != nil { + return picolib.ErrorReply(err.Error()), nil + } + return picolib.ResponseReply(response), nil + default: + return picolib.ErrorReply(fmt.Sprintf("unknown action: %s", p.Action())), nil + } + }) + logger.InfoC("swarm", "Node request handler registered on Pico channel") +} + +// handleNodeRouting handles routing a message to a specific node in the swarm. +func (al *AgentLoop) handleNodeRouting( + ctx context.Context, + msg bus.InboundMessage, + targetNodeID, content string, +) (string, error) { + // Check if target is this node + if targetNodeID == al.swarmDiscovery.LocalNode().ID { + logger.InfoCF("swarm", "Target is this node, processing locally", map[string]any{"node_id": targetNodeID}) + // Update content and continue processing locally. + // Note: content has already been stripped of @node-id: prefix by ParseNodeMention, + // so calling processMessage is safe and will not cause infinite recursion. 
+ msg.Content = content + return al.processMessage(ctx, msg) + } + + // Find target node + members := al.swarmDiscovery.Members() + var target *swarm.NodeWithState + for _, m := range members { + if m.Node.ID == targetNodeID { + target = m + break + } + } + + if target == nil { + return fmt.Sprintf("❌ Node '%s' not found in cluster. Use /nodes to see available nodes.", targetNodeID), nil + } + + // Check if target is available + if target.State.Status != swarm.NodeStatusAlive { + return fmt.Sprintf("❌ Node '%s' is not alive (status: %s)", targetNodeID, target.State.Status), nil + } + + if target.Node.LoadScore > 0.9 { + return fmt.Sprintf("❌ Node '%s' is overloaded (load: %.0f%%)", targetNodeID, target.Node.LoadScore*100), nil + } + + // Use PicoNodeClient to send message to target node + if al.swarmPicoClient == nil { + return fmt.Sprintf("❌ Node-to-node communication not initialized"), nil + } + + // Construct target Pico address using the node's HTTP port + if target.Node.HTTPPort == 0 { + return fmt.Sprintf("❌ Node '%s' does not have an HTTP port configured", targetNodeID), nil + } + picoAddr := picolib.BuildNodeAddr(target.Node.Addr, target.Node.HTTPPort) + + logger.InfoCF("swarm", "Sending message to remote node via Pico channel", map[string]any{ + "target": targetNodeID, + "pico_addr": picoAddr, + "load": target.Node.LoadScore, + }) + + // Send the message using PicoNodeClient + response, err := al.swarmPicoClient.SendMessage( + ctx, + picoAddr, + targetNodeID, + content, + msg.Channel, + msg.ChatID, + msg.SenderID, + ) + if err != nil { + return fmt.Sprintf("❌ Failed to send message to node '%s': %v", targetNodeID, err), nil + } + + return response, nil +} + +// initSwarm initializes the swarm mode components. 
+func (al *AgentLoop) initSwarm() { + logger.InfoC("swarm", "Initializing swarm mode") + + // Create and start discovery service + swarmConfig := al.convertToSwarmConfig(al.cfg.Swarm) + discovery, err := swarm.NewDiscoveryService(swarmConfig) + if err != nil { + logger.ErrorCF("swarm", "Failed to create discovery service", map[string]any{"error": err.Error()}) + al.swarmInitError = fmt.Errorf("discovery service creation failed: %w", err) + return + } + if err := discovery.Start(); err != nil { + logger.ErrorCF("swarm", "Failed to start discovery service", map[string]any{"error": err.Error()}) + al.swarmInitError = fmt.Errorf("discovery service start failed: %w", err) + return + } + al.swarmDiscovery = discovery + + // Create handoff coordinator + al.swarmHandoff = swarm.NewHandoffCoordinator(discovery, al.convertHandoffConfig()) + + // Create and start load monitor + al.swarmLoad = swarm.NewLoadMonitor(al.convertLoadMonitorConfig()) + if al.cfg.Swarm.LoadMonitor.Enabled { + al.swarmLoad.Start() + al.swarmLoad.OnThreshold(func(score float64) { + discovery.UpdateLoad(score) + }) + } + + al.swarmEnabled = true + + // Initialize leader election if enabled + if al.cfg.Swarm.LeaderElection.Enabled { + al.initSwarmLeaderElection(discovery) + } + + // Initialize inter-node communication via Pico channel + al.initSwarmPicoClient(discovery) + + // Register swarm tools + al.registerSwarmTools(discovery) + + // Subscribe to node events for logging + al.subscribeSwarmEvents(discovery) + + logger.InfoCF("swarm", "Swarm mode initialized", map[string]any{ + "node_id": discovery.LocalNode().ID, + "bind_addr": al.cfg.Swarm.BindAddr, + "bind_port": al.cfg.Swarm.BindPort, + "handoff": al.cfg.Swarm.Handoff.Enabled, + }) +} + +// initSwarmLeaderElection initializes the leader election module. 
+func (al *AgentLoop) initSwarmLeaderElection(discovery *swarm.DiscoveryService) { + // Get membership manager from discovery service + membership := discovery.GetMembershipManager() + if membership == nil { + logger.WarnCF("swarm", "Membership manager not available, leader election disabled", nil) + return + } + + // Convert config.SwarmLeaderElectionConfig to swarm.LeaderElectionConfig + leaderElectionConfig := swarm.LeaderElectionConfig{ + Enabled: al.cfg.Swarm.LeaderElection.Enabled, + ElectionInterval: swarm.Duration{Duration: time.Duration(al.cfg.Swarm.LeaderElection.ElectionInterval) * time.Second}, + LeaderHeartbeatTimeout: swarm.Duration{Duration: time.Duration(al.cfg.Swarm.LeaderElection.LeaderHeartbeatTimeout) * time.Second}, + } + + // Create leader election instance + leaderElection := swarm.NewLeaderElection( + discovery.LocalNode().ID, + membership, + leaderElectionConfig, + ) + + // Start leader election + leaderElection.Start() + al.swarmLeaderElection = leaderElection + + logger.InfoCF("swarm", "Leader election initialized", map[string]any{ + "node_id": discovery.LocalNode().ID, + "enabled": al.cfg.Swarm.LeaderElection.Enabled, + }) +} + +// initSwarmPicoClient sets up the PicoNodeClient for inter-node communication +// and wires the handoff coordinator to use it. 
+func (al *AgentLoop) initSwarmPicoClient(discovery *swarm.DiscoveryService) { + picoToken := al.cfg.Channels.Pico.Token + if picoToken == "" { + logger.WarnCF("swarm", "Pico token not configured, inter-node communication disabled", nil) + return + } + + al.swarmPicoClient = picolib.NewPicoNodeClient(discovery.LocalNode().ID, picoToken) + logger.InfoCF("swarm", "PicoNodeClient initialized for inter-node communication", nil) + + al.swarmHandoff.SetSendFunc( + func(ctx context.Context, targetAddr string, req *swarm.HandoffRequest) (*swarm.HandoffResponse, error) { + return al.sendHandoffViaPico(ctx, targetAddr, req) + }, + ) +} + +// sendHandoffViaPico sends a handoff request to a remote node via the Pico channel +// and parses the response. Uses a mapstructure decoder to handle type mismatches gracefully. +func (al *AgentLoop) sendHandoffViaPico( + ctx context.Context, + targetAddr string, + req *swarm.HandoffRequest, +) (*swarm.HandoffResponse, error) { + payload := swarm.NewHandoffRequestPayload(req) + replyPayload, err := al.swarmPicoClient.SendNodeAction(ctx, targetAddr, payload) + if err != nil { + return nil, fmt.Errorf("handoff request via Pico failed: %w", err) + } + + respData, ok := replyPayload.RawValue(picolib.PayloadKeyHandoffResp) + if !ok { + return nil, fmt.Errorf("missing handoff_response in reply") + } + + // Type-assert respData to map[string]any first + respMap, ok := respData.(map[string]any) + if !ok { + // Fall back to JSON marshal/unmarshal if not a map + respJSON, err := json.Marshal(respData) + if err != nil { + return nil, fmt.Errorf("failed to convert handoff response to JSON: %w", err) + } + var resp swarm.HandoffResponse + if err := json.Unmarshal(respJSON, &resp); err != nil { + return nil, fmt.Errorf("failed to parse handoff response: %w", err) + } + return &resp, nil + } + + // Direct struct assignment from map with type coercion + resp := &swarm.HandoffResponse{ + RequestID: toString(respMap["request_id"]), + Accepted: 
toBool(respMap["accepted"]), + NodeID: toString(respMap["node_id"]), + Reason: toString(respMap["reason"]), + SessionKey: toString(respMap["session_key"]), + Timestamp: toInt64(respMap["timestamp"]), + State: swarm.HandoffState(toString(respMap["state"])), + } + + return resp, nil +} + +// Helper functions for type coercion to handle json.Number and other types +func toString(v any) string { + switch val := v.(type) { + case string: + return val + case json.Number: + return val.String() + default: + return "" + } +} + +func toBool(v any) bool { + switch val := v.(type) { + case bool: + return val + case string: + return val == "true" || val == "1" + case float64: + return val != 0 + default: + return false + } +} + +func toInt64(v any) int64 { + switch val := v.(type) { + case int64: + return val + case float64: + return int64(val) + case json.Number: + i, _ := val.Int64() + return i + case string: + i, _ := strconv.ParseInt(val, 10, 64) + return i + default: + return 0 + } +} + +// sendMessageToNode resolves a target node by ID from the discovery service +// and sends a message to it via the Pico channel. +func (al *AgentLoop) sendMessageToNode( + ctx context.Context, + targetNodeID, content, channel, chatID, senderID string, +) (string, error) { + members := al.swarmDiscovery.Members() + for _, m := range members { + if m.Node.ID == targetNodeID { + if m.Node.HTTPPort == 0 { + return "", fmt.Errorf("node %s does not have an HTTP port configured", targetNodeID) + } + picoAddr := picolib.BuildNodeAddr(m.Node.Addr, m.Node.HTTPPort) + return al.swarmPicoClient.SendMessage(ctx, picoAddr, targetNodeID, content, channel, chatID, senderID) + } + } + return "", fmt.Errorf("node %s not found", targetNodeID) +} + +// registerSwarmTools registers swarm-related tools (handoff, routing). +// Note: /nodes is registered as a command (not a tool) via registerSwarmCommands. 
+func (al *AgentLoop) registerSwarmTools(discovery *swarm.DiscoveryService) { + localNodeID := discovery.LocalNode().ID + + if !al.cfg.Swarm.Handoff.Enabled { + return + } + + handoffTool := tools.NewHandoffTool(al.swarmHandoff) + al.RegisterTool(handoffTool) + + routeTool := tools.NewSwarmRouteTool(discovery, al.swarmHandoff, localNodeID) + if al.swarmPicoClient != nil { + routeTool.SetSendMessageFn(al.sendMessageToNode) + } + al.RegisterTool(routeTool) +} + +// registerSwarmCommands registers swarm slash commands on the channel command registry. +// Called from SetChannelManager after both swarm and channel manager are initialized. +func (al *AgentLoop) registerSwarmCommands() { + if al.commandRegistry == nil || al.swarmDiscovery == nil { + return + } + + localNodeID := al.swarmDiscovery.LocalNode().ID + + al.commandRegistry.Register("nodes", "List swarm cluster nodes", func( + ctx context.Context, args string, msg bus.InboundMessage, + ) (string, error) { + verbose := strings.Contains(args, "verbose") || strings.Contains(args, "-v") + return tools.FormatClusterStatus(al.swarmDiscovery, al.swarmLoad, localNodeID, verbose), nil + }) +} + +// subscribeSwarmEvents subscribes to discovery events and logs node state changes. +func (al *AgentLoop) subscribeSwarmEvents(discovery *swarm.DiscoveryService) { + discovery.Subscribe(func(event *swarm.NodeEvent) { + switch event.Event { + case swarm.EventJoin: + logger.InfoCF("swarm", "Node joined", map[string]any{"node_id": event.Node.ID}) + case swarm.EventLeave: + logger.InfoCF("swarm", "Node left", map[string]any{"node_id": event.Node.ID}) + case swarm.EventUpdate: + logger.DebugCF("swarm", "Node updated", map[string]any{ + "node_id": event.Node.ID, + "load_score": event.Node.LoadScore, + }) + } + }) +} + +// convertHandoffConfig converts config.SwarmHandoffConfig to swarm.HandoffConfig. 
+func (al *AgentLoop) convertHandoffConfig() swarm.HandoffConfig {
+	cfg := al.cfg.Swarm.Handoff
+	return swarm.HandoffConfig{
+		Enabled:       cfg.Enabled,
+		LoadThreshold: cfg.LoadThreshold,
+		Timeout:       swarm.Duration{Duration: time.Duration(cfg.Timeout) * time.Second}, // seconds → Duration; NOTE(review): zero stays zero here, unlike convertToSwarmConfig — confirm intended
+		MaxRetries:    cfg.MaxRetries,
+		RetryDelay:    swarm.Duration{Duration: time.Duration(cfg.RetryDelay) * time.Second}, // seconds → Duration, no default substitution
+	}
+}
+
+// convertLoadMonitorConfig converts config.SwarmLoadMonitorConfig to *swarm.LoadMonitorConfig. Second-based ints become swarm.Duration verbatim; unset (zero) values are NOT replaced with defaults here — TODO confirm this asymmetry with convertToSwarmConfig is intended.
+func (al *AgentLoop) convertLoadMonitorConfig() *swarm.LoadMonitorConfig {
+	cfg := al.cfg.Swarm.LoadMonitor
+	return &swarm.LoadMonitorConfig{
+		Enabled:       cfg.Enabled,
+		Interval:      swarm.Duration{Duration: time.Duration(cfg.Interval) * time.Second},
+		SampleSize:    cfg.SampleSize,
+		CPUWeight:     cfg.CPUWeight,
+		MemoryWeight:  cfg.MemoryWeight,
+		SessionWeight: cfg.SessionWeight,
+	}
+}
+
+// convertToSwarmConfig converts the config.SwarmConfig to swarm.Config. All second-based int fields become swarm.Duration, with the swarm package defaults substituted for unset (<= 0) values.
+func (al *AgentLoop) convertToSwarmConfig(cfg config.SwarmConfig) *swarm.Config {
+	// applyDurationDefault treats non-positive values (i.e. unset config keys) as "use the built-in default"; positive values are interpreted as seconds.
+	applyDurationDefault := func(val int, defaultVal time.Duration) time.Duration {
+		if val <= 0 {
+			return defaultVal
+		}
+		return time.Duration(val) * time.Second
+	}
+
+	return &swarm.Config{
+		Enabled:       cfg.Enabled,
+		NodeID:        cfg.NodeID,
+		BindAddr:      cfg.BindAddr,
+		BindPort:      cfg.BindPort,
+		AdvertiseAddr: cfg.AdvertiseAddr,
+		Discovery: swarm.DiscoveryConfig{
+			JoinAddrs:        cfg.Discovery.JoinAddrs,
+			GossipInterval:   swarm.Duration{Duration: applyDurationDefault(cfg.Discovery.GossipInterval, swarm.DefaultGossipInterval)},
+			PushPullInterval: swarm.Duration{Duration: applyDurationDefault(cfg.Discovery.PushPullInterval, swarm.DefaultPushPullInterval)},
+			NodeTimeout:      swarm.Duration{Duration: applyDurationDefault(cfg.Discovery.NodeTimeout, swarm.DefaultNodeTimeout)},
+			DeadNodeTimeout:  swarm.Duration{Duration: applyDurationDefault(cfg.Discovery.DeadNodeTimeout, swarm.DefaultDeadNodeTimeout)},
+		},
+		Handoff: swarm.HandoffConfig{
+			Enabled:       cfg.Handoff.Enabled,
+			LoadThreshold: cfg.Handoff.LoadThreshold,
+			Timeout:       swarm.Duration{Duration: applyDurationDefault(cfg.Handoff.Timeout, swarm.DefaultHandoffTimeout)},
+			MaxRetries:    cfg.Handoff.MaxRetries,
+			RetryDelay:    swarm.Duration{Duration: applyDurationDefault(cfg.Handoff.RetryDelay, swarm.DefaultHandoffRetryDelay)},
+		},
+		RPC: swarm.RPCConfig{
+			Port:    cfg.RPC.Port,
+			Timeout: swarm.Duration{Duration: applyDurationDefault(cfg.RPC.Timeout, 10*time.Second)}, // 10s fallback is inlined rather than a named swarm default — TODO consider a constant
+		},
+		LoadMonitor: swarm.LoadMonitorConfig{
+			Enabled:       cfg.LoadMonitor.Enabled,
+			Interval:      swarm.Duration{Duration: applyDurationDefault(cfg.LoadMonitor.Interval, swarm.DefaultLoadSampleInterval)},
+			SampleSize:    cfg.LoadMonitor.SampleSize,
+			CPUWeight:     cfg.LoadMonitor.CPUWeight,
+			MemoryWeight:  cfg.LoadMonitor.MemoryWeight,
+			SessionWeight: cfg.LoadMonitor.SessionWeight,
+		},
+		LeaderElection: swarm.LeaderElectionConfig{
+			Enabled:                cfg.LeaderElection.Enabled,
+			ElectionInterval:       swarm.Duration{Duration: applyDurationDefault(cfg.LeaderElection.ElectionInterval, 5*time.Second)},
+			LeaderHeartbeatTimeout: swarm.Duration{Duration: applyDurationDefault(cfg.LeaderElection.LeaderHeartbeatTimeout, 10*time.Second)},
+		},
+		HTTPPort: al.cfg.Gateway.Port, // gateway HTTP port reused as the advertised port for inter-node Pico traffic — TODO confirm
+	}
+}
+
+// handleIncomingHandoffRequest handles a handoff_request action from a remote node.
+func (al *AgentLoop) handleIncomingHandoffRequest(payload picolib.NodePayload) (map[string]any, error) { + if al.swarmHandoff == nil { + return picolib.ErrorReply("handoff coordinator not available"), nil + } + + // Parse request from payload + reqData, ok := payload.RawValue(picolib.PayloadKeyRequest) + if !ok { + return picolib.ErrorReply("missing request in handoff payload"), nil + } + + reqJSON, err := json.Marshal(reqData) + if err != nil { + return picolib.ErrorReply(fmt.Sprintf("failed to marshal request: %v", err)), nil + } + + var req swarm.HandoffRequest + if err := json.Unmarshal(reqJSON, &req); err != nil { + return picolib.ErrorReply(fmt.Sprintf("failed to parse handoff request: %v", err)), nil + } + + resp := al.swarmHandoff.HandleIncomingHandoff(&req) + + return swarm.HandoffResponseReply(resp), nil +} + +// handleIncomingNodeMessage handles a message received from another node via Pico channel. +func (al *AgentLoop) handleIncomingNodeMessage(msg *picolib.DirectMessage) (string, error) { + logger.InfoCF("swarm", "Processing incoming node message", map[string]any{ + "from": msg.SourceNodeID, + "message_id": msg.MessageID, + "content": msg.Content[:min(50, len(msg.Content))], + }) + + // Create an inbound message from the node message + inboundMsg := bus.InboundMessage{ + Content: msg.Content, + Channel: msg.Channel, + ChatID: msg.ChatID, + SenderID: msg.SenderID, + } + + // Process the message through the agent loop + ctx := context.Background() + response, err := al.processMessage(ctx, inboundMsg) + if err != nil { + return "", fmt.Errorf("failed to process message: %w", err) + } + + return response, nil +} + +// shouldHandoff determines if the current request should be handed off to another node. 
+func (al *AgentLoop) shouldHandoff(agent *AgentInstance, opts processOptions) bool {
+	if !al.swarmEnabled || al.swarmHandoff == nil {
+		return false
+	}
+
+	// Delegate only when the local load monitor reports this node as overloaded.
+	if al.swarmLoad != nil && al.swarmLoad.ShouldOffload() {
+		logger.InfoCF("swarm", "Load threshold exceeded, considering handoff", map[string]any{
+			"load_score": al.swarmLoad.GetCurrentLoad().Score,
+		})
+		return true
+	}
+
+	return false
+}
+
+// UpdateSwarmLoad sets the active session count fed into the load score reported to the swarm.
+func (al *AgentLoop) UpdateSwarmLoad(sessionCount int) {
+	if al.swarmLoad != nil {
+		al.swarmLoad.SetSessionCount(sessionCount)
+	}
+}
+
+// IncrementSwarmSessions increments the active session count. No-op when load monitoring is off.
+func (al *AgentLoop) IncrementSwarmSessions() {
+	if al.swarmLoad != nil {
+		al.swarmLoad.IncrementSessions()
+	}
+}
+
+// DecrementSwarmSessions decrements the active session count. No-op when load monitoring is off.
+func (al *AgentLoop) DecrementSwarmSessions() {
+	if al.swarmLoad != nil {
+		al.swarmLoad.DecrementSessions()
+	}
+}
+
+// GetSwarmStatus returns the current swarm status as a JSON-friendly map.
+func (al *AgentLoop) GetSwarmStatus() map[string]any {
+	if !al.swarmEnabled {
+		return map[string]any{"enabled": false}
+	}
+
+	status := map[string]any{
+		"enabled": true,
+		"handoff": al.cfg.Swarm.Handoff.Enabled,
+	}
+
+	if al.swarmLoad != nil {
+		metrics := al.swarmLoad.GetCurrentLoad()
+		status["load"] = map[string]any{
+			"score":           metrics.Score,
+			"cpu_usage":       metrics.CPUUsage,
+			"memory_usage":    metrics.MemoryUsage,
+			"active_sessions": metrics.ActiveSessions,
+			"goroutines":      metrics.Goroutines,
+			"trend":           al.swarmLoad.GetTrend(),
+		}
+	}
+
+	// BUG FIX: node_id was read via al.swarmDiscovery.LocalNode() before this nil check and could panic.
+	if al.swarmDiscovery != nil {
+		status["node_id"] = al.swarmDiscovery.LocalNode().ID
+		status["members"] = len(al.swarmDiscovery.Members())
+	}
+
+	return status
+}
+
+// ShutdownSwarm gracefully shuts down the swarm components.
+func (al *AgentLoop) ShutdownSwarm() error { + if !al.swarmEnabled { + return nil + } + + var errs []string + + if al.swarmLoad != nil { + al.swarmLoad.Stop() + } + + if al.swarmLeaderElection != nil { + al.swarmLeaderElection.Stop() + } + + if al.swarmHandoff != nil { + if err := al.swarmHandoff.Close(); err != nil { + errs = append(errs, fmt.Sprintf("handoff: %v", err)) + } + } + + if al.swarmDiscovery != nil { + if err := al.swarmDiscovery.Stop(); err != nil { + errs = append(errs, fmt.Sprintf("discovery: %v", err)) + } + } + + al.swarmEnabled = false + + if len(errs) > 0 { + return fmt.Errorf("swarm shutdown errors: %s", strings.Join(errs, ", ")) + } + return nil +} + +// initiateSwarmHandoff initiates a handoff to another node. +func (al *AgentLoop) initiateSwarmHandoff( + ctx context.Context, + agent *AgentInstance, + sessionKey string, + msg bus.InboundMessage, +) (*swarm.HandoffResponse, error) { + if al.swarmHandoff == nil { + return nil, swarm.ErrDiscoveryDisabled + } + + // Build session history for handoff + sessionMessages := make([]picolib.SessionMessage, 0) + history := agent.Sessions.GetHistory(sessionKey) + + for _, m := range history { + if m.Role == "user" || m.Role == "assistant" { + sessionMessages = append(sessionMessages, picolib.SessionMessage{ + Role: m.Role, + Content: m.Content, + }) + } + } + + // Create handoff request + req := &swarm.HandoffRequest{ + Reason: swarm.ReasonOverloaded, + SessionKey: sessionKey, + SessionMessages: sessionMessages, + Context: map[string]any{ + "channel": msg.Channel, + "chat_id": msg.ChatID, + "sender": msg.SenderID, + "agent_id": agent.ID, + }, + Metadata: map[string]string{ + "original_channel": msg.Channel, + "original_chat_id": msg.ChatID, + }, + } + + logger.InfoCF("swarm", "Initiating handoff", map[string]any{ + "session_key": sessionKey, + "reason": req.Reason, + "history_len": len(sessionMessages), + }) + + // Execute handoff + resp, err := al.swarmHandoff.InitiateHandoff(ctx, req) + + if resp != 
nil { + logger.InfoCF("swarm", "Handoff response received", map[string]any{ + "accepted": resp.Accepted, + "node_id": resp.NodeID, + "state": resp.State, + }) + } + + return resp, err +} diff --git a/pkg/channels/commands.go b/pkg/channels/commands.go new file mode 100644 index 000000000..e0ebbc11a --- /dev/null +++ b/pkg/channels/commands.go @@ -0,0 +1,77 @@ +package channels + +import ( + "context" + "sort" + "sync" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +// CommandHandler processes a slash command and returns a text response. +// args contains everything after the command name (e.g. for "/nodes verbose", args = "verbose"). +// msg provides the full inbound message context (channel, sender, chat, etc.). +type CommandHandler func(ctx context.Context, args string, msg bus.InboundMessage) (string, error) + +// CommandEntry holds a registered command. +type CommandEntry struct { + Name string // command name without leading slash, e.g. "nodes" + Description string // human-readable description, e.g. "List swarm cluster nodes" + Handler CommandHandler // function to execute +} + +// CommandRegistry is a thread-safe registry of slash commands. +// External modules register commands here; the agent loop checks it when +// processing inbound messages that start with "/". +type CommandRegistry struct { + mu sync.RWMutex + commands map[string]*CommandEntry +} + +// NewCommandRegistry creates an empty command registry. +func NewCommandRegistry() *CommandRegistry { + return &CommandRegistry{ + commands: make(map[string]*CommandEntry), + } +} + +// Register adds or replaces a command in the registry. +// name should NOT include the leading slash. +func (r *CommandRegistry) Register(name, description string, handler CommandHandler) { + r.mu.Lock() + defer r.mu.Unlock() + r.commands[name] = &CommandEntry{ + Name: name, + Description: description, + Handler: handler, + } +} + +// Get looks up a command by name. Returns nil, false if not found. 
+func (r *CommandRegistry) Get(name string) (*CommandEntry, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + entry, ok := r.commands[name] + return entry, ok +} + +// List returns all registered commands sorted by name. +func (r *CommandRegistry) List() []*CommandEntry { + r.mu.RLock() + defer r.mu.RUnlock() + entries := make([]*CommandEntry, 0, len(r.commands)) + for _, e := range r.commands { + entries = append(entries, e) + } + sort.Slice(entries, func(i, j int) bool { + return entries[i].Name < entries[j].Name + }) + return entries +} + +// Remove unregisters a command by name. No-op if not found. +func (r *CommandRegistry) Remove(name string) { + r.mu.Lock() + defer r.mu.Unlock() + delete(r.commands, name) +} diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 31af9672c..f54b6edd9 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -79,12 +79,13 @@ type Manager struct { bus *bus.MessageBus config *config.Config mediaStore media.MediaStore + commands *CommandRegistry dispatchTask *asyncTask mux *http.ServeMux httpServer *http.Server mu sync.RWMutex - placeholders sync.Map // "channel:chatID" → placeholderID (string) - typingStops sync.Map // "channel:chatID" → func() + placeholders sync.Map // "channel:chatID" → placeholderEntry + typingStops sync.Map // "channel:chatID" → typingEntry reactionUndos sync.Map // "channel:chatID" → reactionEntry } @@ -154,6 +155,7 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.Medi bus: messageBus, config: cfg, mediaStore: store, + commands: NewCommandRegistry(), } if err := m.initChannels(); err != nil { @@ -780,6 +782,12 @@ func (m *Manager) RegisterChannel(name string, channel Channel) { m.channels[name] = channel } +// CommandRegistry returns the shared command registry. +// External modules use this to register slash commands that work across all channels. 
+func (m *Manager) CommandRegistry() *CommandRegistry { + return m.commands +} + func (m *Manager) UnregisterChannel(name string) { m.mu.Lock() defer m.mu.Unlock() diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go index 2ae82d8da..7dd1e9a30 100644 --- a/pkg/channels/pico/pico.go +++ b/pkg/channels/pico/pico.go @@ -50,12 +50,13 @@ func (pc *picoConn) close() { // It serves as the reference implementation for all optional capability interfaces. type PicoChannel struct { *channels.BaseChannel - config config.PicoConfig - upgrader websocket.Upgrader - connections sync.Map // connID → *picoConn - connCount atomic.Int32 - ctx context.Context - cancel context.CancelFunc + config config.PicoConfig + upgrader websocket.Upgrader + connections sync.Map // connID → *picoConn + connCount atomic.Int32 + ctx context.Context + cancel context.CancelFunc + nodeRequestHandler func(payload map[string]any) (map[string]any, error) } // NewPicoChannel creates a new Pico Protocol channel. @@ -403,6 +404,9 @@ func (c *PicoChannel) handleMessage(pc *picoConn, msg PicoMessage) { case TypeMessageSend: c.handleMessageSend(pc, msg) + case TypeNodeRequest: + c.handleNodeRequest(pc, msg) + default: errMsg := newError("unknown_type", fmt.Sprintf("unknown message type: %s", msg.Type)) pc.writeJSON(errMsg) @@ -452,6 +456,59 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, metadata, sender) } +// SetNodeRequestHandler sets the handler for incoming inter-node request messages. +// The handler receives the full payload map and returns a reply payload map. +func (c *PicoChannel) SetNodeRequestHandler(handler func(payload map[string]any) (map[string]any, error)) { + c.nodeRequestHandler = handler +} + +// handleNodeRequest processes an incoming node.request message from a swarm peer. +// Note: This requires a valid Pico token (authenticated WebSocket connection). 
+// In swarm mode with inter-node communication, both nodes must share the same token. +// Future enhancement: Add node-identity verification via signatures for production deployments. +func (c *PicoChannel) handleNodeRequest(pc *picoConn, msg PicoMessage) { + requestID, _ := msg.Payload["request_id"].(string) + sourceNodeID, _ := msg.Payload["source_node_id"].(string) + action, _ := msg.Payload["action"].(string) + + logger.InfoCF("pico", "Received node request", map[string]any{ + "request_id": requestID, + "source_node_id": sourceNodeID, + "action": action, + }) + + var replyPayload map[string]any + + if c.nodeRequestHandler == nil { + replyPayload = map[string]any{ + "request_id": requestID, + "error": "no node request handler registered", + } + } else { + var err error + replyPayload, err = c.nodeRequestHandler(msg.Payload) + if err != nil { + replyPayload = map[string]any{ + "request_id": requestID, + "error": err.Error(), + } + } + // Ensure request_id is always in the reply + if replyPayload != nil { + replyPayload["request_id"] = requestID + } + } + + reply := newMessage(TypeNodeReply, replyPayload) + reply.ID = msg.ID + if err := pc.writeJSON(reply); err != nil { + logger.ErrorCF("pico", "Failed to send node reply", map[string]any{ + "error": err.Error(), + "request_id": requestID, + }) + } +} + // truncate truncates a string to maxLen runes. func truncate(s string, maxLen int) string { runes := []rune(s) diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go index 0a630e193..a03214164 100644 --- a/pkg/channels/pico/protocol.go +++ b/pkg/channels/pico/protocol.go @@ -3,6 +3,8 @@ package pico import "time" // Protocol message types. +// NOTE: These should match pkg/pico/protocol/protocol.go to keep protocol definitions in sync. +// The canonical definitions are in pkg/pico/protocol which has zero dependencies. const ( // TypeMessageSend is sent from client to server. 
TypeMessageSend = "message.send" @@ -17,6 +19,10 @@ const ( TypeTypingStop = "typing.stop" TypeError = "error" TypePong = "pong" + + // Node-to-node communication (swarm mode) + TypeNodeRequest = "node.request" + TypeNodeReply = "node.reply" ) // PicoMessage is the wire format for all Pico Protocol messages. diff --git a/pkg/config/config.go b/pkg/config/config.go index 2e0215278..f6a572a7c 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -58,6 +58,7 @@ type Config struct { Tools ToolsConfig `json:"tools"` Heartbeat HeartbeatConfig `json:"heartbeat"` Devices DevicesConfig `json:"devices"` + Swarm SwarmConfig `json:"swarm,omitempty"` // Swarm mode configuration } // MarshalJSON implements custom JSON marshaling for Config @@ -774,3 +775,115 @@ func (c *Config) ValidateModelList() error { } return nil } + +// SwarmConfig contains configuration for swarm mode. +type SwarmConfig struct { + // Enabled enables swarm mode. + Enabled bool `json:"enabled" env:"PICOCLAW_SWARM_ENABLED"` + + // NodeID is the unique identifier for this node. + NodeID string `json:"node_id,omitempty" env:"PICOCLAW_SWARM_NODE_ID"` + + // BindAddr is the address to bind for gossip and RPC. + BindAddr string `json:"bind_addr,omitempty" env:"PICOCLAW_SWARM_BIND_ADDR"` + + // BindPort is the port for gossip protocol. + BindPort int `json:"bind_port,omitempty" env:"PICOCLAW_SWARM_BIND_PORT"` + + // AdvertiseAddr is the address to advertise to other nodes. + AdvertiseAddr string `json:"advertise_addr,omitempty" env:"PICOCLAW_SWARM_ADVERTISE_ADDR"` + + // Discovery configuration for node discovery. + Discovery SwarmDiscoveryConfig `json:"discovery"` + + // Handoff configuration for task handoff. + Handoff SwarmHandoffConfig `json:"handoff"` + + // RPC configuration for inter-node communication. + RPC SwarmRPCConfig `json:"rpc"` + + // LoadMonitor configuration for load monitoring. 
+ LoadMonitor SwarmLoadMonitorConfig `json:"load_monitor"` + + // LeaderElection configuration for leader election. + LeaderElection SwarmLeaderElectionConfig `json:"leader_election"` +} + +// SwarmDiscoveryConfig contains configuration for node discovery. +type SwarmDiscoveryConfig struct { + // JoinAddrs is a list of existing nodes to join. + JoinAddrs []string `json:"join_addrs,omitempty"` + + // GossipInterval is the interval between gossip messages (in seconds). + GossipInterval int `json:"gossip_interval,omitempty"` + + // PushPullInterval is the interval for full state sync (in seconds). + PushPullInterval int `json:"push_pull_interval,omitempty"` + + // NodeTimeout is the timeout before marking a node as suspect (in seconds). + NodeTimeout int `json:"node_timeout,omitempty"` + + // DeadNodeTimeout is the timeout before marking a node as dead (in seconds). + DeadNodeTimeout int `json:"dead_node_timeout,omitempty"` +} + +// SwarmHandoffConfig contains configuration for task handoff. +type SwarmHandoffConfig struct { + // Enabled enables task handoff. + Enabled bool `json:"enabled"` + + // LoadThreshold is the load score threshold (0-1) above which + // tasks will be handed off to other nodes. + LoadThreshold float64 `json:"load_threshold,omitempty"` + + // Timeout is the timeout for a handoff operation (in seconds). + Timeout int `json:"timeout,omitempty"` + + // MaxRetries is the maximum number of retries for handoff. + MaxRetries int `json:"max_retries,omitempty"` + + // RetryDelay is the delay between retries (in seconds). + RetryDelay int `json:"retry_delay,omitempty"` +} + +// SwarmRPCConfig contains configuration for RPC communication. +type SwarmRPCConfig struct { + // Port is the port for RPC communication. + Port int `json:"port,omitempty" env:"PICOCLAW_SWARM_RPC_PORT"` + + // Timeout is the default timeout for RPC calls (in seconds). + Timeout int `json:"timeout,omitempty"` +} + +// SwarmLoadMonitorConfig contains configuration for load monitoring. 
+type SwarmLoadMonitorConfig struct { + // Enabled enables load monitoring. + Enabled bool `json:"enabled"` + + // Interval is the interval between load samples (in seconds). + Interval int `json:"interval,omitempty"` + + // SampleSize is the number of samples to keep for averaging. + SampleSize int `json:"sample_size,omitempty"` + + // CPUWeight is the weight for CPU usage in load score (0-1). + CPUWeight float64 `json:"cpu_weight,omitempty"` + + // MemoryWeight is the weight for memory usage in load score (0-1). + MemoryWeight float64 `json:"memory_weight,omitempty"` + + // SessionWeight is the weight for active sessions in load score (0-1). + SessionWeight float64 `json:"session_weight,omitempty"` +} + +// SwarmLeaderElectionConfig contains configuration for leader election. +type SwarmLeaderElectionConfig struct { + // Enabled enables leader election. + Enabled bool `json:"enabled"` + + // ElectionInterval is how often to check leadership (in seconds). + ElectionInterval int `json:"election_interval,omitempty"` + + // LeaderHeartbeatTimeout is how long before assuming leader is dead (in seconds). + LeaderHeartbeatTimeout int `json:"leader_heartbeat_timeout,omitempty"` +} diff --git a/pkg/pico/addr.go b/pkg/pico/addr.go new file mode 100644 index 000000000..4d9152074 --- /dev/null +++ b/pkg/pico/addr.go @@ -0,0 +1,14 @@ +// Package pico provides a reusable WebSocket client for the Pico Protocol. +// +// In addition to the low-level Client, this package defines shared types and +// helpers used by both the Pico channel (server) and the swarm subsystem +// (client) for inter-node communication. +package pico + +import "fmt" + +// BuildNodeAddr constructs a host:port address string for inter-node +// communication. It replaces the scattered fmt.Sprintf("%s:%d", ...) calls. 
+func BuildNodeAddr(addr string, port int) string {
+	return fmt.Sprintf("%s:%d", addr, port)
+}
diff --git a/pkg/pico/client.go b/pkg/pico/client.go
new file mode 100644
index 000000000..b7ba7391d
--- /dev/null
+++ b/pkg/pico/client.go
@@ -0,0 +1,96 @@
+// Package pico provides a reusable WebSocket client for the Pico Protocol.
+//
+// The Client type encapsulates the connect → send → receive → close lifecycle
+// for a single request-reply exchange with a Pico WebSocket endpoint. It is
+// intentionally stateless: each call to SendRequest opens a new connection,
+// performs the exchange, and closes the connection.
+//
+// This package depends only on pkg/pico/protocol and gorilla/websocket;
+// it has no knowledge of swarm, channels, or any other higher-level construct.
+package pico
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"net/url"
+	"time"
+
+	"github.com/gorilla/websocket"
+
+	"github.com/sipeed/picoclaw/pkg/pico/protocol"
+)
+
+// DefaultReadTimeout is the maximum time to wait for a reply after sending.
+const DefaultReadTimeout = 30 * time.Second
+
+// Client is a lightweight, stateless Pico WebSocket client.
+// Each SendRequest call dials a new connection, performs a single
+// request-reply exchange, and closes the connection.
+type Client struct {
+	token string
+}
+
+// NewClient creates a new Pico WebSocket client.
+// If token is non-empty it is sent as a Bearer token in the upgrade request.
+func NewClient(token string) *Client {
+	return &Client{token: token}
+}
+
+// BuildWSURL constructs the canonical Pico WebSocket URL for a given
+// host address and session ID. The session ID is query-escaped so that
+// IDs containing reserved URL characters cannot corrupt the query string.
+func BuildWSURL(addr, sessionID string) string {
+	return fmt.Sprintf("ws://%s/pico/ws?session_id=%s", addr, url.QueryEscape(sessionID))
+}
+
+// SendRequest dials the target Pico WebSocket endpoint, sends msg, and blocks
+// until a single reply message is received (or the context / read timeout fires).
+// +// The caller is responsible for constructing the outbound protocol.Message +// (including Type, ID, Payload, etc.) and for interpreting the reply. +func (c *Client) SendRequest( + ctx context.Context, + addr, sessionID string, + msg protocol.Message, +) (protocol.Message, error) { + wsURL := BuildWSURL(addr, sessionID) + + header := http.Header{} + if c.token != "" { + header.Set("Authorization", "Bearer "+c.token) + } + + dialer := websocket.Dialer{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + conn, resp, err := dialer.DialContext(ctx, wsURL, header) + if err != nil { + if resp != nil { + return protocol.Message{}, fmt.Errorf("pico WebSocket dial failed: %s (status: %d)", err.Error(), resp.StatusCode) + } + return protocol.Message{}, fmt.Errorf("pico WebSocket dial failed: %w", err) + } + defer conn.Close() + + if writeErr := conn.WriteJSON(msg); writeErr != nil { + return protocol.Message{}, fmt.Errorf("failed to send pico request: %w", writeErr) + } + + conn.SetReadDeadline(time.Now().Add(DefaultReadTimeout)) + + _, rawReply, err := conn.ReadMessage() + if err != nil { + return protocol.Message{}, fmt.Errorf("failed to read pico reply: %w", err) + } + + var reply protocol.Message + if err := json.Unmarshal(rawReply, &reply); err != nil { + return protocol.Message{}, fmt.Errorf("failed to parse pico reply: %w", err) + } + + return reply, nil +} diff --git a/pkg/pico/node_client.go b/pkg/pico/node_client.go new file mode 100644 index 000000000..8c7f0a980 --- /dev/null +++ b/pkg/pico/node_client.go @@ -0,0 +1,123 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package pico + +import ( + "context" + "fmt" + "time" + + "github.com/google/uuid" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/pico/protocol" +) + +// PicoNodeClient sends inter-node messages by connecting to a target node's +// Pico WebSocket endpoint. 
It delegates the low-level WebSocket lifecycle to +// Client and focuses on swarm-specific payload construction. +type PicoNodeClient struct { + sourceNodeID string + client *Client +} + +// NewPicoNodeClient creates a new Pico-based inter-node client. +func NewPicoNodeClient(sourceNodeID, token string) *PicoNodeClient { + return &PicoNodeClient{ + sourceNodeID: sourceNodeID, + client: NewClient(token), + } +} + +// SendMessage connects to the target node's Pico WebSocket, sends a node.request +// with action "message", and blocks until a node.reply is received (or timeout). +func (c *PicoNodeClient) SendMessage( + ctx context.Context, + targetAddr, targetNodeID, content, channel, chatID, senderID string, +) (string, error) { + payload := NewMessagePayload(c.sourceNodeID, content, channel, chatID, senderID) + return c.sendRequest(ctx, targetAddr, targetNodeID, payload) +} + +// SendNodeAction sends an action-based request to a target node via Pico. +// The payload must contain an "action" key. Returns the raw reply payload. +func (c *PicoNodeClient) SendNodeAction( + ctx context.Context, + targetAddr string, + payload NodePayload, +) (NodePayload, error) { + requestID := uuid.New().String() + payload[PayloadKeySourceNodeID] = c.sourceNodeID + payload[PayloadKeyRequestID] = requestID + + logger.InfoCF("pico", "Sending node action via Pico", map[string]any{ + "action": payload.Action(), + "target_addr": targetAddr, + "request_id": requestID, + }) + + return c.doSend(ctx, targetAddr, requestID, payload) +} + +// sendRequest is the internal method for sending a request and returning the "response" string. 
+func (c *PicoNodeClient) sendRequest( + ctx context.Context, + targetAddr, targetNodeID string, + payload NodePayload, +) (string, error) { + requestID := uuid.New().String() + payload[PayloadKeyRequestID] = requestID + + logger.InfoCF("pico", "Sending node request via Pico", map[string]any{ + "action": payload.Action(), + "target_node_id": targetNodeID, + "target_addr": targetAddr, + "request_id": requestID, + }) + + replyPayload, err := c.doSend(ctx, targetAddr, requestID, payload) + if err != nil { + return "", err + } + + if errStr := replyPayload.ErrorMsg(); errStr != "" { + return "", fmt.Errorf("node error: %s", errStr) + } + + return replyPayload.Response(), nil +} + +// doSend builds a protocol.Message, delegates to Client.SendRequest, +// and validates the reply envelope before returning the payload. +func (c *PicoNodeClient) doSend( + ctx context.Context, + targetAddr, requestID string, + payload NodePayload, +) (NodePayload, error) { + reqMsg := protocol.Message{ + Type: protocol.TypeNodeRequest, + ID: requestID, + SessionID: c.sourceNodeID, + Timestamp: time.Now().UnixMilli(), + Payload: payload, + } + + reply, err := c.client.SendRequest(ctx, targetAddr, c.sourceNodeID, reqMsg) + if err != nil { + return nil, err + } + + if reply.Type != protocol.TypeNodeReply { + return nil, fmt.Errorf("unexpected reply type: %s (expected %s)", reply.Type, protocol.TypeNodeReply) + } + + replyPayload := NodePayload(reply.Payload) + if replyPayload.RequestID() != requestID { + return nil, fmt.Errorf("request ID mismatch: got %s, want %s", replyPayload.RequestID(), requestID) + } + + return replyPayload, nil +} diff --git a/pkg/pico/node_payload.go b/pkg/pico/node_payload.go new file mode 100644 index 000000000..ad1cd8ed4 --- /dev/null +++ b/pkg/pico/node_payload.go @@ -0,0 +1,131 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package pico + +// NodePayload is the typed payload exchanged 
between nodes via the +// Pico channel node.request / node.reply protocol. It provides named +// constants for field keys and typed accessor methods so that callers +// never need to use raw string literals. +type NodePayload map[string]any + +// Payload field key constants. +const ( + PayloadKeyAction = "action" + PayloadKeyRequestID = "request_id" + PayloadKeySourceNodeID = "source_node_id" + PayloadKeyContent = "content" + PayloadKeyChannel = "channel" + PayloadKeyChatID = "chat_id" + PayloadKeySenderID = "sender_id" + PayloadKeyMetadata = "metadata" + PayloadKeyError = "error" + PayloadKeyResponse = "response" + PayloadKeyRequest = "request" // nested handoff request object + PayloadKeyHandoffResp = "handoff_response" // nested handoff response object +) + +// Node action constants used for action-based routing over Pico. +const ( + NodeActionMessage = "message" + NodeActionHandoffRequest = "handoff_request" +) + +// --------------------------------------------------------------------------- +// Accessors +// --------------------------------------------------------------------------- + +// str is a small helper that extracts a string value from the payload. +func (p NodePayload) str(key string) string { + v, _ := p[key].(string) + return v +} + +// Action returns the action field (e.g. "message", "handoff_request"). +func (p NodePayload) Action() string { return p.str(PayloadKeyAction) } + +// RequestID returns the request_id field. +func (p NodePayload) RequestID() string { return p.str(PayloadKeyRequestID) } + +// SourceNodeID returns the source_node_id field. +func (p NodePayload) SourceNodeID() string { return p.str(PayloadKeySourceNodeID) } + +// Content returns the content field. +func (p NodePayload) Content() string { return p.str(PayloadKeyContent) } + +// Channel returns the channel field. +func (p NodePayload) Channel() string { return p.str(PayloadKeyChannel) } + +// ChatID returns the chat_id field. 
+func (p NodePayload) ChatID() string { return p.str(PayloadKeyChatID) } + +// SenderID returns the sender_id field. +func (p NodePayload) SenderID() string { return p.str(PayloadKeySenderID) } + +// ErrorMsg returns the error field. +func (p NodePayload) ErrorMsg() string { return p.str(PayloadKeyError) } + +// Response returns the response field. +func (p NodePayload) Response() string { return p.str(PayloadKeyResponse) } + +// Metadata extracts the metadata map, converting map[string]any to map[string]string. +func (p NodePayload) Metadata() map[string]string { + raw, ok := p[PayloadKeyMetadata] + if !ok { + return nil + } + m, ok := raw.(map[string]any) + if !ok { + return nil + } + result := make(map[string]string, len(m)) + for k, v := range m { + if s, ok := v.(string); ok { + result[k] = s + } + } + return result +} + +// RawValue returns the raw value for an arbitrary key. +func (p NodePayload) RawValue(key string) (any, bool) { + v, ok := p[key] + return v, ok +} + +// --------------------------------------------------------------------------- +// Builder helpers +// --------------------------------------------------------------------------- + +// NewNodePayload creates an empty NodePayload. +func NewNodePayload() NodePayload { + return make(NodePayload) +} + +// NewMessagePayload creates a NodePayload pre-filled for a "message" action. +func NewMessagePayload(sourceNodeID, content, channel, chatID, senderID string) NodePayload { + return NodePayload{ + PayloadKeyAction: NodeActionMessage, + PayloadKeySourceNodeID: sourceNodeID, + PayloadKeyContent: content, + PayloadKeyChannel: channel, + PayloadKeyChatID: chatID, + PayloadKeySenderID: senderID, + } +} + +// --------------------------------------------------------------------------- +// Reply constructors +// --------------------------------------------------------------------------- + +// ErrorReply creates a reply payload carrying an error message. 
+func ErrorReply(msg string) NodePayload { + return NodePayload{PayloadKeyError: msg} +} + +// ResponseReply creates a reply payload carrying a string response. +func ResponseReply(response string) NodePayload { + return NodePayload{PayloadKeyResponse: response} +} diff --git a/pkg/pico/protocol/protocol.go b/pkg/pico/protocol/protocol.go new file mode 100644 index 000000000..3049a5e4a --- /dev/null +++ b/pkg/pico/protocol/protocol.go @@ -0,0 +1,54 @@ +// Package protocol defines the Pico Protocol wire format shared by +// the Pico channel (server) and the swarm PicoNodeClient (client). +// This package has zero internal dependencies to stay at the bottom +// of the dependency graph. +package protocol + +import "time" + +// Message type constants for the Pico Protocol. +const ( + // Client → Server + TypeMessageSend = "message.send" + TypeMediaSend = "media.send" + TypePing = "ping" + + // Server → Client + TypeMessageCreate = "message.create" + TypeMessageUpdate = "message.update" + TypeMediaCreate = "media.create" + TypeTypingStart = "typing.start" + TypeTypingStop = "typing.stop" + TypeError = "error" + TypePong = "pong" + + // Inter-node (swarm) + TypeNodeRequest = "node.request" + TypeNodeReply = "node.reply" +) + +// Message is the wire format for all Pico Protocol messages. +type Message struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + SessionID string `json:"session_id,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` + Payload map[string]any `json:"payload,omitempty"` +} + +// NewMessage creates a Message with the given type, payload, and current timestamp. +func NewMessage(msgType string, payload map[string]any) Message { + return Message{ + Type: msgType, + Timestamp: time.Now().UnixMilli(), + Payload: payload, + } +} + +// NewError creates an error Message with code and human-readable message. 
+func NewError(code, message string) Message { + return NewMessage(TypeError, map[string]any{ + "code": code, + "message": message, + }) +} diff --git a/pkg/pico/types.go b/pkg/pico/types.go new file mode 100644 index 000000000..36640f3d8 --- /dev/null +++ b/pkg/pico/types.go @@ -0,0 +1,44 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package pico + +// SessionMessage represents a message in a session. +// This is shared across handoff, session transfer, and inter-node communication. +type SessionMessage struct { + Role string `json:"role"` + Content string `json:"content"` + Timestamp int64 `json:"timestamp,omitempty"` + ToolCalls []ToolCallData `json:"tool_calls,omitempty"` +} + +// ToolCallData represents tool call information in a message. +type ToolCallData struct { + ID string `json:"id"` + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + Result string `json:"result,omitempty"` + Extra map[string]any `json:"extra,omitempty"` +} + +// DirectMessage represents a direct message sent to another node. +type DirectMessage struct { + MessageID string `json:"message_id"` + SourceNodeID string `json:"source_node_id"` + TargetNodeID string `json:"target_node_id"` + Content string `json:"content"` + Channel string `json:"channel"` + ChatID string `json:"chat_id"` + SenderID string `json:"sender_id"` + Metadata map[string]string `json:"metadata,omitempty"` + Timestamp int64 `json:"timestamp"` +} + +// DirectMessageResponse represents a response to a direct message. 
+type DirectMessageResponse struct { + MessageID string `json:"message_id"` + Response string `json:"response"` + Error string `json:"error,omitempty"` +} diff --git a/pkg/swarm/config.go b/pkg/swarm/config.go new file mode 100644 index 000000000..b3e7cf518 --- /dev/null +++ b/pkg/swarm/config.go @@ -0,0 +1,280 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Swarm mode support for multi-agent coordination +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "encoding/json" + "time" +) + +// Config contains all configuration for swarm mode. +type Config struct { + // Enabled enables swarm mode. + Enabled bool `json:"enabled" env:"PICOCLAW_SWARM_ENABLED"` + + // NodeID is the unique identifier for this node. + // If empty, a hostname-based ID will be generated. + NodeID string `json:"node_id,omitempty" env:"PICOCLAW_SWARM_NODE_ID"` + + // BindAddr is the address to bind for gossip and RPC. + BindAddr string `json:"bind_addr,omitempty" env:"PICOCLAW_SWARM_BIND_ADDR"` + + // BindPort is the port for gossip protocol. + BindPort int `json:"bind_port,omitempty" env:"PICOCLAW_SWARM_BIND_PORT"` + + // AdvertiseAddr is the address to advertise to other nodes. + // If empty, BindAddr will be used. + AdvertiseAddr string `json:"advertise_addr,omitempty" env:"PICOCLAW_SWARM_ADVERTISE_ADDR"` + + // AdvertisePort is the port to advertise to other nodes. + // If 0, BindPort will be used. + AdvertisePort int `json:"advertise_port,omitempty" env:"PICOCLAW_SWARM_ADVERTISE_PORT"` + + // Discovery configuration for node discovery. + Discovery DiscoveryConfig `json:"discovery"` + + // Handoff configuration for task handoff. + Handoff HandoffConfig `json:"handoff"` + + // RPC configuration for inter-node communication. + RPC RPCConfig `json:"rpc"` + + // LoadMonitor configuration for load monitoring. + LoadMonitor LoadMonitorConfig `json:"load_monitor"` + + // LeaderElection configuration for leader election. 
+ LeaderElection LeaderElectionConfig `json:"leader_election"` + + // Metrics configuration for observability. + Metrics MetricsConfig `json:"metrics"` + + // HTTPPort is the HTTP/gateway port where the Pico channel is served. + // Used for inter-node communication via the Pico WebSocket protocol. + HTTPPort int `json:"http_port,omitempty"` +} + +// DiscoveryConfig contains configuration for node discovery. +type DiscoveryConfig struct { + // JoinAddrs is a list of existing nodes to join. + JoinAddrs []string `json:"join_addrs,omitempty"` + + // GossipInterval is the interval between gossip messages. + GossipInterval Duration `json:"gossip_interval,omitempty"` + + // PushPullInterval is the interval for full state sync. + PushPullInterval Duration `json:"push_pull_interval,omitempty"` + + // NodeTimeout is the timeout before marking a node as suspect. + NodeTimeout Duration `json:"node_timeout,omitempty"` + + // DeadNodeTimeout is the timeout before marking a node as dead. + DeadNodeTimeout Duration `json:"dead_node_timeout,omitempty"` + + // AuthSecret is the shared secret for node authentication. + // If empty, authentication is disabled (not recommended for production). + AuthSecret string `json:"auth_secret,omitempty"` + + // RequireAuth requires all nodes to be authenticated. + RequireAuth bool `json:"require_auth"` + + // EnableMessageSigning enables HMAC signing of all messages. + EnableMessageSigning bool `json:"enable_message_signing"` +} + +// HandoffConfig contains configuration for task handoff. +type HandoffConfig struct { + // Enabled enables task handoff. + Enabled bool `json:"enabled"` + + // LoadThreshold is the load score threshold (0-1) above which + // tasks will be handed off to other nodes. + LoadThreshold float64 `json:"load_threshold,omitempty"` + + // Timeout is the timeout for a handoff operation. + Timeout Duration `json:"timeout,omitempty"` + + // MaxRetries is the maximum number of retries for handoff. 
+ MaxRetries int `json:"max_retries,omitempty"` + + // RetryDelay is the delay between retries. + RetryDelay Duration `json:"retry_delay,omitempty"` +} + +// RPCConfig contains configuration for RPC communication. +type RPCConfig struct { + // Port is the port for RPC communication. + Port int `json:"port,omitempty" env:"PICOCLAW_SWARM_RPC_PORT"` + + // Timeout is the default timeout for RPC calls. + Timeout Duration `json:"timeout,omitempty"` +} + +// LoadMonitorConfig contains configuration for load monitoring. +type LoadMonitorConfig struct { + // Enabled enables load monitoring. + Enabled bool `json:"enabled"` + + // Interval is the interval between load samples. + Interval Duration `json:"interval,omitempty"` + + // SampleSize is the number of samples to keep for averaging. + SampleSize int `json:"sample_size,omitempty"` + + // CPUWeight is the weight for CPU usage in load score (0-1). + CPUWeight float64 `json:"cpu_weight,omitempty"` + + // MemoryWeight is the weight for memory usage in load score (0-1). + MemoryWeight float64 `json:"memory_weight,omitempty"` + + // SessionWeight is the weight for active sessions in load score (0-1). + SessionWeight float64 `json:"session_weight,omitempty"` + + // OffloadThreshold is the load score threshold above which tasks should be offloaded (0-1). + OffloadThreshold float64 `json:"offload_threshold,omitempty"` + + // MaxMemoryBytes is the maximum memory to use for normalization (default: 1GB). + MaxMemoryBytes uint64 `json:"max_memory_bytes,omitempty"` + + // MaxGoroutines is the maximum goroutine count for normalization (default: 1000). + MaxGoroutines int `json:"max_goroutines,omitempty"` + + // MaxSessions is the maximum session count for normalization (default: 100). + MaxSessions int `json:"max_sessions,omitempty"` +} + +// LeaderElectionConfig contains configuration for leader election. +type LeaderElectionConfig struct { + // Enabled enables leader election. 
+ Enabled bool `json:"enabled"` + + // ElectionInterval is how often to check leadership. + ElectionInterval Duration `json:"election_interval,omitempty"` + + // LeaderHeartbeatTimeout is how long before assuming leader is dead. + LeaderHeartbeatTimeout Duration `json:"leader_heartbeat_timeout,omitempty"` +} + +// MetricsConfig contains configuration for metrics collection. +type MetricsConfig struct { + // Enabled enables metrics collection. + Enabled bool `json:"enabled"` + + // ExportInterval is how often to export metrics. + ExportInterval Duration `json:"export_interval,omitempty"` + + // PrometheusEnabled enables Prometheus format export. + PrometheusEnabled bool `json:"prometheus_enabled"` + + // PrometheusEndpoint is the HTTP endpoint for Prometheus metrics. + PrometheusEndpoint string `json:"prometheus_endpoint,omitempty"` +} + +// Duration is a wrapper around time.Duration for JSON parsing. +type Duration struct { + time.Duration +} + +// UnmarshalJSON parses a duration from JSON. +func (d *Duration) UnmarshalJSON(b []byte) error { + // Check if it's a string (quoted) + if len(b) > 0 && b[0] == '"' { + var s string + if err := parseJSONString(b, &s); err != nil { + return err + } + var err error + d.Duration, err = time.ParseDuration(s) + return err + } + + // Otherwise it's a number (milliseconds) + var v float64 + if err := parseJSONNumber(b, &v); err != nil { + return err + } + d.Duration = time.Duration(v) + return nil +} + +// MarshalJSON converts a duration to JSON. +func (d Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(d.Duration.String()) +} + +// DefaultConfig returns the default swarm configuration. 
+func DefaultConfig() *Config { + return &Config{ + Enabled: false, + NodeID: "", + BindAddr: "0.0.0.0", + BindPort: DefaultBindPort, + Discovery: DiscoveryConfig{ + JoinAddrs: nil, + GossipInterval: Duration{DefaultGossipInterval}, + PushPullInterval: Duration{DefaultPushPullInterval}, + NodeTimeout: Duration{DefaultNodeTimeout}, + DeadNodeTimeout: Duration{DefaultDeadNodeTimeout}, + AuthSecret: "", + RequireAuth: false, + EnableMessageSigning: false, + }, + Handoff: HandoffConfig{ + Enabled: true, + LoadThreshold: DefaultLoadThreshold, + Timeout: Duration{DefaultHandoffTimeout}, + MaxRetries: DefaultMaxHandoffRetries, + RetryDelay: Duration{DefaultHandoffRetryDelay}, + }, + RPC: RPCConfig{ + Port: DefaultRPCPort, + Timeout: Duration{10 * time.Second}, + }, + LoadMonitor: LoadMonitorConfig{ + Enabled: true, + Interval: Duration{DefaultLoadSampleInterval}, + SampleSize: DefaultLoadSampleSize, + CPUWeight: DefaultCPUWeight, + MemoryWeight: DefaultMemoryWeight, + SessionWeight: DefaultSessionWeight, + OffloadThreshold: DefaultOffloadThreshold, + MaxMemoryBytes: DefaultMaxMemoryBytes, + MaxGoroutines: DefaultMaxGoroutines, + MaxSessions: DefaultMaxSessions, + }, + LeaderElection: LeaderElectionConfig{ + Enabled: false, + ElectionInterval: Duration{5 * time.Second}, + LeaderHeartbeatTimeout: Duration{10 * time.Second}, + }, + Metrics: MetricsConfig{ + Enabled: false, + ExportInterval: Duration{10 * time.Second}, + PrometheusEnabled: false, + PrometheusEndpoint: "/metrics", + }, + } +} + +// parseJSONString parses a JSON string (including quotes). +func parseJSONString(b []byte, s *string) error { + if len(b) < 2 || b[0] != '"' || b[len(b)-1] != '"' { + return &json.UnmarshalTypeError{} + } + *s = string(b[1 : len(b)-1]) + return nil +} + +// parseJSONNumber parses a JSON number. 
+func parseJSONNumber(b []byte, f *float64) error { + n, err := json.Number(string(b)).Int64() + if err == nil { + *f = float64(n) + return nil + } + *f, err = json.Number(string(b)).Float64() + return err +} diff --git a/pkg/swarm/constants.go b/pkg/swarm/constants.go new file mode 100644 index 000000000..6ad788092 --- /dev/null +++ b/pkg/swarm/constants.go @@ -0,0 +1,113 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Swarm mode support for multi-agent coordination +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import "time" + +// GossipMessageType represents the type of a gossip protocol message. +type GossipMessageType string + +const ( + GossipTypePing GossipMessageType = "ping" + GossipTypePong GossipMessageType = "pong" + GossipTypeJoin GossipMessageType = "join" + GossipTypeUpdate GossipMessageType = "update" + GossipTypeSync GossipMessageType = "sync" +) + +// LoadTrend represents the direction of load change over time. +type LoadTrend string + +const ( + LoadTrendIncreasing LoadTrend = "increasing" + LoadTrendDecreasing LoadTrend = "decreasing" + LoadTrendStable LoadTrend = "stable" +) + +const ( + // Default values for configurable parameters + + // DefaultBindPort is the default port for gossip protocol. + DefaultBindPort = 7946 + + // DefaultRPCPort is the default port for RPC communication. + DefaultRPCPort = 7947 + + // DefaultNodeTimeout is the default timeout before marking a node as suspect. + DefaultNodeTimeout = 5 * time.Second + + // DefaultDeadNodeTimeout is the default timeout before removing a dead node. + DefaultDeadNodeTimeout = 30 * time.Second + + // DefaultGossipInterval is the default interval between gossip messages. + DefaultGossipInterval = 1 * time.Second + + // DefaultPushPullInterval is the default interval for full state sync. + DefaultPushPullInterval = 30 * time.Second + + // DefaultHandoffTimeout is the default timeout for a handoff operation. 
+ DefaultHandoffTimeout = 30 * time.Second + + // DefaultHandoffRetryDelay is the default delay between handoff retries. + DefaultHandoffRetryDelay = 5 * time.Second + + // DefaultMaxHandoffRetries is the default maximum number of handoff retries. + DefaultMaxHandoffRetries = 3 + + // DefaultLoadSampleInterval is the default interval between load samples. + DefaultLoadSampleInterval = 5 * time.Second + + // DefaultLoadSampleSize is the default number of load samples to keep. + DefaultLoadSampleSize = 60 + + // Thresholds and limits + + // DefaultLoadThreshold is the default load score threshold for handoff. + DefaultLoadThreshold = 0.8 + + // DefaultAvailableLoadThreshold is the threshold below which a node is considered available (0-1). + DefaultAvailableLoadThreshold = 0.9 + + // DefaultOffloadThreshold is the default threshold above which tasks should be offloaded (0-1). + DefaultOffloadThreshold = 0.8 + + // DefaultMaxMemoryBytes is the default max memory for normalization (1GB). + DefaultMaxMemoryBytes = 1024 * 1024 * 1024 + + // DefaultMaxGoroutines is the default max goroutine count for normalization. + DefaultMaxGoroutines = 1000 + + // DefaultMaxSessions is the default max session count for normalization. + DefaultMaxSessions = 100 + + // DefaultCPUWeight is the default weight for CPU in load score calculation. + DefaultCPUWeight = 0.3 + + // DefaultMemoryWeight is the default weight for memory in load score calculation. + DefaultMemoryWeight = 0.3 + + // DefaultSessionWeight is the default weight for sessions in load score calculation. + DefaultSessionWeight = 0.4 + + // Buffer sizes + + // MaxGossipMessageSize is the maximum size of a gossip message (64KB). + MaxGossipMessageSize = 64 * 1024 + + // UDP write deadline + + // DefaultUDPWriteDeadline is the default write deadline for UDP operations. + DefaultUDPWriteDeadline = 5 * time.Second + + // Trend analysis + + // TrendIncreasingThreshold is the slope threshold for detecting increasing trend. 
+ TrendIncreasingThreshold = 0.01 + + // TrendDecreasingThreshold is the slope threshold for detecting decreasing trend. + TrendDecreasingThreshold = -0.01 +) diff --git a/pkg/swarm/discovery.go b/pkg/swarm/discovery.go new file mode 100644 index 000000000..5450b4e18 --- /dev/null +++ b/pkg/swarm/discovery.go @@ -0,0 +1,531 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Swarm mode support for multi-agent coordination +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "encoding/json" + "fmt" + "net" + "os" + "sync" + "time" + + "github.com/google/uuid" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// DiscoveryService handles node discovery using a gossip protocol. +// For lightweight implementation, we use a simple UDP-based gossip +// instead of the heavier memberlist library. +type DiscoveryService struct { + config *Config + localNode *NodeInfo + membership *MembershipManager + eventHandler *EventDispatcher + conn *net.UDPConn + rpcConn net.Listener + auth *AuthProvider + + mu sync.RWMutex + running bool + stopChan chan struct{} + once sync.Once + + // Sequence number for updates + seqNum uint64 +} + +// NewDiscoveryService creates a new discovery service. 
+func NewDiscoveryService(cfg *Config) (*DiscoveryService, error) { + if cfg.NodeID == "" { + // Generate node ID from hostname + hostname, _ := os.Hostname() + if hostname == "" { + hostname = "picoclaw" + } + cfg.NodeID = fmt.Sprintf("%s-%s", hostname, uuid.New().String()[:8]) + } + + // Determine advertise address + advAddr := cfg.AdvertiseAddr + if advAddr == "" || advAddr == "0.0.0.0" { + advAddr = getLocalIP() + if advAddr == "" { + advAddr = "127.0.0.1" + } + } + + advPort := cfg.AdvertisePort + if advPort == 0 { + advPort = cfg.BindPort + } + + localNode := &NodeInfo{ + ID: cfg.NodeID, + Addr: advAddr, + Port: cfg.RPC.Port, + AgentCaps: make(map[string]string), + LoadScore: 0, + Labels: make(map[string]string), + HTTPPort: cfg.HTTPPort, + Timestamp: time.Now().UnixNano(), + Version: "1.0.0", // PicoClaw version + } + + ds := &DiscoveryService{ + config: cfg, + localNode: localNode, + eventHandler: NewEventDispatcher(), + stopChan: make(chan struct{}), + } + + // Initialize auth provider if secret is configured + if cfg.Discovery.AuthSecret != "" { + ds.auth = NewAuthProvider(cfg.NodeID, cfg.Discovery.AuthSecret) + if cfg.Discovery.RequireAuth || cfg.Discovery.EnableMessageSigning { + logger.InfoC("swarm", "Authentication enabled for swarm") + } + } + + // Initialize membership manager + ds.membership = NewMembershipManager(ds, cfg.Discovery) + + return ds, nil +} + +// Start starts the discovery service. 
+func (ds *DiscoveryService) Start() error { + ds.mu.Lock() + if ds.running { + ds.mu.Unlock() + return nil + } + ds.running = true + ds.mu.Unlock() + + // Bind UDP socket for gossip + addr := fmt.Sprintf("%s:%d", ds.config.BindAddr, ds.config.BindPort) + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return fmt.Errorf("failed to resolve UDP address: %w", err) + } + + ds.conn, err = net.ListenUDP("udp", udpAddr) + if err != nil { + return fmt.Errorf("failed to listen on UDP: %w", err) + } + + // Start gossip listener + go ds.gossipListener() + + // Start periodic gossip + go ds.gossipLoop() + + // Join existing cluster if addresses provided + if len(ds.config.Discovery.JoinAddrs) > 0 { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + ds.Join(ctx, ds.config.Discovery.JoinAddrs) + } + + // Add self to membership + ds.membership.UpdateNode(ds.localNode) + + return nil +} + +// Stop stops the discovery service. +func (ds *DiscoveryService) Stop() error { + ds.once.Do(func() { + ds.mu.Lock() + ds.running = false + ds.mu.Unlock() + + if ds.stopChan != nil { + close(ds.stopChan) + } + + if ds.conn != nil { + ds.conn.Close() + } + + if ds.rpcConn != nil { + ds.rpcConn.Close() + } + }) + return nil +} + +// Join joins a cluster by contacting existing nodes. +func (ds *DiscoveryService) Join(ctx context.Context, addrs []string) (int, error) { + count := 0 + + for _, addr := range addrs { + // Send join message to each address + err := ds.sendJoin(ctx, addr) + if err == nil { + count++ + } + } + + // Warn if no addresses could be reached + if count == 0 && len(addrs) > 0 { + logger.WarnCF("swarm", "Failed to join cluster: no nodes were reachable. Node starting in isolated mode", map[string]any{ + "attempted_addrs": addrs, + "local_node_id": ds.localNode.ID, + }) + } + + return count, nil +} + +// Members returns all known members of the cluster. 
+func (ds *DiscoveryService) Members() []*NodeWithState { + return ds.membership.GetMembers() +} + +// GetMembershipManager returns the membership manager. +func (ds *DiscoveryService) GetMembershipManager() *MembershipManager { + return ds.membership +} + +// LocalNode returns the local node info. +func (ds *DiscoveryService) LocalNode() *NodeInfo { + ds.mu.RLock() + defer ds.mu.RUnlock() + return ds.localNode +} + +// UpdateLocalInfo updates the local node's information. +func (ds *DiscoveryService) UpdateLocalInfo(info *NodeInfo) { + ds.mu.Lock() + ds.localNode = info + ds.localNode.Timestamp = time.Now().UnixNano() + ds.seqNum++ + ds.mu.Unlock() + + // Update membership + ds.membership.UpdateNode(info) + + // Broadcast update + ds.broadcastUpdate() +} + +// UpdateLoad updates the local node's load score. +func (ds *DiscoveryService) UpdateLoad(score float64) { + ds.mu.Lock() + ds.localNode.LoadScore = score + ds.localNode.Timestamp = time.Now().UnixNano() + ds.seqNum++ + info := ds.localNode + ds.mu.Unlock() + + ds.membership.UpdateNode(info) + ds.broadcastUpdate() +} + +// UpdateCapabilities updates the local node's agent capabilities. +func (ds *DiscoveryService) UpdateCapabilities(caps map[string]string) { + ds.mu.Lock() + ds.localNode.AgentCaps = caps + ds.localNode.Timestamp = time.Now().UnixNano() + ds.seqNum++ + info := ds.localNode + ds.mu.Unlock() + + ds.membership.UpdateNode(info) + ds.broadcastUpdate() +} + +// Subscribe registers a handler for node events and returns its ID. +func (ds *DiscoveryService) Subscribe(handler EventHandler) EventHandlerID { + return ds.eventHandler.Subscribe(handler) +} + +// Unsubscribe removes a node event handler by ID. +func (ds *DiscoveryService) Unsubscribe(id EventHandlerID) { + ds.eventHandler.Unsubscribe(id) +} + +// gossipListener listens for incoming gossip messages. 
+func (ds *DiscoveryService) gossipListener() { + buf := make([]byte, MaxGossipMessageSize) + + for { + select { + case <-ds.stopChan: + return + default: + } + + ds.conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + n, addr, err := ds.conn.ReadFromUDP(buf) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue + } + return + } + + if n > 0 { + go ds.handleGossip(buf[:n], addr) + } + } +} + +// gossipLoop periodically gossips node state to random members. +func (ds *DiscoveryService) gossipLoop() { + ticker := time.NewTicker(ds.config.Discovery.GossipInterval.Duration) + defer ticker.Stop() + + for { + select { + case <-ds.stopChan: + return + case <-ticker.C: + ds.broadcastUpdate() + } + } +} + +// GossipMessage represents a gossip message. +type GossipMessage struct { + Type GossipMessageType `json:"type"` // ping, pong, join, update, sync + FromNode string `json:"from_node"` + SeqNum uint64 `json:"seq_num"` + Timestamp int64 `json:"timestamp"` + Payload []byte `json:"payload,omitempty"` + Nodes []*NodeInfo `json:"nodes,omitempty"` // For memberlist exchange + AuthToken *AuthToken `json:"auth_token,omitempty"` +} + +// handleGossip handles an incoming gossip message. 
+func (ds *DiscoveryService) handleGossip(data []byte, addr *net.UDPAddr) { + var msg GossipMessage + if err := json.Unmarshal(data, &msg); err != nil { + return + } + + // Verify authentication if required + if ds.config.Discovery.RequireAuth { + if msg.AuthToken == nil || ds.auth == nil || !ds.auth.VerifyToken(msg.AuthToken) { + logger.WarnCF("swarm", "Rejected unauthenticated message", map[string]any{"from": addr.String()}) + return + } + } + + // Verify message signature if enabled + if ds.config.Discovery.EnableMessageSigning && msg.AuthToken != nil { + // The signature is in the token, so verification above handles it + } + + switch msg.Type { + case GossipTypePing: + ds.handlePing(msg, addr) + case GossipTypePong: + ds.handlePong(msg) + case GossipTypeJoin: + ds.handleJoin(msg, addr) + case GossipTypeUpdate: + ds.handleUpdate(msg) + case GossipTypeSync: + ds.handleSync(msg, addr) + } +} + +// handlePing handles a ping message. +func (ds *DiscoveryService) handlePing(msg GossipMessage, addr *net.UDPAddr) { + // Respond with pong + pong := GossipMessage{ + Type: GossipTypePong, + FromNode: ds.localNode.ID, + Timestamp: time.Now().UnixNano(), + } + + data, err := json.Marshal(pong) + if err != nil { + logger.ErrorCF("swarm", "failed to marshal pong message", map[string]any{"error": err}) + return + } + if _, err := ds.conn.WriteToUDP(data, addr); err != nil { + logger.DebugCF("swarm", "failed to send pong", map[string]any{"to": addr.String(), "error": err}) + } + + // Update membership if this is a known node + if len(msg.Nodes) > 0 { + for _, node := range msg.Nodes { + if node.ID != ds.localNode.ID { + ds.membership.UpdateNode(node) + } + } + } +} + +// handlePong handles a pong message. +func (ds *DiscoveryService) handlePong(msg GossipMessage) { + // Update last seen for this node + ds.membership.RecordHeartbeat(msg.FromNode) +} + +// handleJoin handles a join request from a new node. 
+func (ds *DiscoveryService) handleJoin(msg GossipMessage, addr *net.UDPAddr) { + // Send our member list back + members := ds.Members() + nodes := make([]*NodeInfo, 0, len(members)+1) + nodes = append(nodes, ds.localNode) + for _, m := range members { + if m.Node.ID != ds.localNode.ID { + nodes = append(nodes, m.Node) + } + } + + response := GossipMessage{ + Type: GossipTypeSync, + FromNode: ds.localNode.ID, + Timestamp: time.Now().UnixNano(), + Nodes: nodes, + } + + data, err := json.Marshal(response) + if err != nil { + logger.ErrorCF("swarm", "failed to marshal sync message", map[string]any{"error": err}) + return + } + if _, err := ds.conn.WriteToUDP(data, addr); err != nil { + logger.ErrorCF("swarm", "failed to send sync", map[string]any{"to": addr.String(), "error": err}) + return + } + + // Emit join event + event := &NodeEvent{ + Event: EventJoin, + Time: time.Now().UnixNano(), + } + if len(msg.Nodes) > 0 { + event.Node = msg.Nodes[0] + } + ds.eventHandler.Dispatch(event) +} + +// handleUpdate handles a node update message. +func (ds *DiscoveryService) handleUpdate(msg GossipMessage) { + if len(msg.Nodes) == 0 { + return + } + + for _, node := range msg.Nodes { + if node.ID != ds.localNode.ID { + existing, ok := ds.membership.GetNode(node.ID) + if !ok || node.Timestamp > existing.Node.Timestamp { + ds.membership.UpdateNode(node) + } + } + } +} + +// handleSync handles a sync response with member list. +func (ds *DiscoveryService) handleSync(msg GossipMessage, addr *net.UDPAddr) { + for _, node := range msg.Nodes { + if node.ID != ds.localNode.ID { + ds.membership.UpdateNode(node) + } + } +} + +// broadcastUpdate broadcasts local state to random members. 
+func (ds *DiscoveryService) broadcastUpdate() { + members := ds.membership.GetMembers() + if len(members) == 0 { + return + } + + msg := GossipMessage{ + Type: GossipTypeUpdate, + FromNode: ds.localNode.ID, + SeqNum: ds.seqNum, + Timestamp: time.Now().UnixNano(), + Nodes: []*NodeInfo{ds.localNode}, + } + + // Add auth token if authentication is enabled + if ds.auth != nil { + token, err := ds.auth.GenerateToken() + if err != nil { + logger.ErrorCF("swarm", "failed to generate auth token", map[string]any{"error": err}) + } else { + msg.AuthToken = token + } + } + + data, err := json.Marshal(msg) + if err != nil { + logger.ErrorCF("swarm", "failed to marshal broadcast update", map[string]any{"error": err}) + return + } + + // Send to a few random members + for _, member := range members { + if member.Node.ID != ds.localNode.ID { + addr := fmt.Sprintf("%s:%d", member.Node.Addr, member.Node.Port) + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + logger.DebugCF("swarm", "failed to resolve address", map[string]any{"address": addr, "error": err}) + continue + } + if _, err := ds.conn.WriteToUDP(data, udpAddr); err != nil { + logger.DebugCF("swarm", "failed to send update", map[string]any{"address": addr, "error": err}) + } + } + } +} + +// sendJoin sends a join message to a specific address. 
+func (ds *DiscoveryService) sendJoin(ctx context.Context, addr string) error { + joinAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return err + } + + msg := GossipMessage{ + Type: GossipTypeJoin, + FromNode: ds.localNode.ID, + Timestamp: time.Now().UnixNano(), + Nodes: []*NodeInfo{ds.localNode}, + } + + data, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal join message: %w", err) + } + + // Set deadline + ds.conn.SetWriteDeadline(time.Now().Add(DefaultUDPWriteDeadline)) + _, err = ds.conn.WriteToUDP(data, joinAddr) + if err != nil { + return fmt.Errorf("failed to send join to %s: %w", addr, err) + } + return nil +} + +// getLocalIP returns the local IP address. +func getLocalIP() string { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "" + } + + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { + if ipnet.IP.To4() != nil { + return ipnet.IP.String() + } + } + } + return "" +} diff --git a/pkg/swarm/errors.go b/pkg/swarm/errors.go new file mode 100644 index 000000000..ed8cad7c7 --- /dev/null +++ b/pkg/swarm/errors.go @@ -0,0 +1,50 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Swarm mode support for multi-agent coordination +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import "errors" + +var ( + // ErrNodeNotFound is returned when a node is not found in the cluster. + ErrNodeNotFound = errors.New("node not found") + + // ErrNodeNotAvailable is returned when a node is not available for handoff. + ErrNodeNotAvailable = errors.New("node not available") + + // ErrNoHealthyNodes is returned when no healthy nodes are available. + ErrNoHealthyNodes = errors.New("no healthy nodes available") + + // ErrHandoffTimeout is returned when a handoff operation times out. + ErrHandoffTimeout = errors.New("handoff timeout") + + // ErrHandoffRejected is returned when a handoff is rejected by the target node. 
+ ErrHandoffRejected = errors.New("handoff rejected") + + // ErrHandoffInProgress is returned when a handoff is already in progress. + ErrHandoffInProgress = errors.New("handoff already in progress") + + // ErrInvalidNodeInfo is returned when node information is invalid. + ErrInvalidNodeInfo = errors.New("invalid node information") + + // ErrDiscoveryDisabled is returned when discovery is disabled. + ErrDiscoveryDisabled = errors.New("discovery disabled") + + // ErrTransportClosed is returned when the transport is closed. + ErrTransportClosed = errors.New("transport closed") + + // ErrSessionNotFound is returned when a session is not found. + ErrSessionNotFound = errors.New("session not found") + + // ErrCapabilityNotSupported is returned when a required capability is not supported. + ErrCapabilityNotSupported = errors.New("capability not supported") + + // ErrAuthenticationFailed is returned when authentication fails. + ErrAuthenticationFailed = errors.New("authentication failed") + + // ErrInvalidSignature is returned when a signature verification fails. + ErrInvalidSignature = errors.New("invalid signature") +) diff --git a/pkg/swarm/handoff.go b/pkg/swarm/handoff.go new file mode 100644 index 000000000..718436edb --- /dev/null +++ b/pkg/swarm/handoff.go @@ -0,0 +1,326 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Swarm mode support for multi-agent coordination +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/google/uuid" + + picolib "github.com/sipeed/picoclaw/pkg/pico" +) + +// HandoffReason represents the reason for a handoff. 
+type HandoffReason string + +const ( + ReasonOverloaded HandoffReason = "overloaded" // Load is too high + ReasonNoCapability HandoffReason = "no_capability" // Missing capability + ReasonUserRequest HandoffReason = "user_request" // User explicitly requested + ReasonNodeLeave HandoffReason = "node_leave" // Node is leaving + ReasonShutdown HandoffReason = "shutdown" // Graceful shutdown +) + +// HandoffState represents the state of a handoff operation. +type HandoffState string + +const ( + HandoffStatePending HandoffState = "pending" + HandoffStateAccepted HandoffState = "accepted" + HandoffStateRejected HandoffState = "rejected" + HandoffStateCompleted HandoffState = "completed" + HandoffStateFailed HandoffState = "failed" + HandoffStateTimeout HandoffState = "timeout" +) + +// HandoffRequest represents a request to hand off a session. +type HandoffRequest struct { + RequestID string `json:"request_id"` + Reason HandoffReason `json:"reason"` + SessionKey string `json:"session_key"` + SessionMessages []picolib.SessionMessage `json:"session_messages,omitempty"` + Context map[string]any `json:"context,omitempty"` + RequiredCap string `json:"required_cap,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + FromNodeID string `json:"from_node_id"` + FromNodeAddr string `json:"from_node_addr"` + Timestamp int64 `json:"timestamp"` +} + +// HandoffResponse represents the response to a handoff request. +type HandoffResponse struct { + RequestID string `json:"request_id"` + Accepted bool `json:"accepted"` + NodeID string `json:"node_id"` + Reason string `json:"reason,omitempty"` + SessionKey string `json:"session_key,omitempty"` // New session key on target + Timestamp int64 `json:"timestamp"` + State HandoffState `json:"state"` +} + +// HandoffSendFunc is the function signature for sending a handoff request to a +// target node via Pico channel. 
The caller provides the target node address and +// the request; the function returns the response synchronously. +type HandoffSendFunc func(ctx context.Context, targetAddr string, req *HandoffRequest) (*HandoffResponse, error) + +// HandoffCoordinator coordinates handoff operations between nodes. +// Communication is handled via an injected send function (typically backed by PicoNodeClient). +type HandoffCoordinator struct { + discovery *DiscoveryService + membership *MembershipManager + config HandoffConfig + + pending map[string]*HandoffOperation // request_id -> operation + mu sync.RWMutex + + // Injected communication function (set via SetSendFunc) + sendRequestFn HandoffSendFunc + + // Accept/reject callbacks + onHandoffRequest func(*HandoffRequest) *HandoffResponse + onHandoffComplete func(*HandoffRequest, *HandoffResponse) +} + +// HandoffOperation represents an ongoing handoff operation. +type HandoffOperation struct { + Request *HandoffRequest + Response *HandoffResponse + State HandoffState + StartTime time.Time + LastUpdate time.Time + RetryCount int + TargetNode *NodeWithState +} + +// NewHandoffCoordinator creates a new handoff coordinator. +// Unlike the previous version, this does NOT open a UDP socket. +// Communication is injected via SetSendFunc. +func NewHandoffCoordinator(ds *DiscoveryService, config HandoffConfig) *HandoffCoordinator { + return &HandoffCoordinator{ + discovery: ds, + membership: ds.membership, + config: config, + pending: make(map[string]*HandoffOperation), + } +} + +// Close cleans up the handoff coordinator. +func (hc *HandoffCoordinator) Close() error { + return nil +} + +// SetSendFunc injects the function used to send handoff requests to remote nodes. +func (hc *HandoffCoordinator) SetSendFunc(fn HandoffSendFunc) { + hc.sendRequestFn = fn +} + +// CanHandle checks if the local node can handle a request. 
+func (hc *HandoffCoordinator) CanHandle(requiredCap string) bool { + if !hc.config.Enabled { + return false + } + + // Check load + loadScore := hc.discovery.localNode.LoadScore + if loadScore > hc.config.LoadThreshold { + return false + } + + // Check capability + if requiredCap != "" { + hasCap := false + for _, cap := range hc.discovery.localNode.AgentCaps { + if cap == requiredCap { + hasCap = true + break + } + } + if !hasCap { + return false + } + } + + return true +} + +// InitiateHandoff initiates a handoff to another node. +func (hc *HandoffCoordinator) InitiateHandoff(ctx context.Context, req *HandoffRequest) (*HandoffResponse, error) { + if hc.sendRequestFn == nil { + return nil, fmt.Errorf("handoff send function not configured") + } + + if req.RequestID == "" { + req.RequestID = uuid.New().String() + } + + req.FromNodeID = hc.discovery.localNode.ID + req.FromNodeAddr = hc.discovery.localNode.Addr + req.Timestamp = time.Now().UnixNano() + + // Find target node + targetNode, findErr := hc.findTargetNode(req) + if findErr != nil { + // Intentionally return nil error: convert internal error to a rejected response. 
+ reason := findErr.Error() + return &HandoffResponse{ //nolint:nilerr // intentional: wraps error as rejected response + RequestID: req.RequestID, + Accepted: false, + Reason: reason, + State: HandoffStateFailed, + }, nil + } + + // Create operation + op := &HandoffOperation{ + Request: req, + State: HandoffStatePending, + StartTime: time.Now(), + LastUpdate: time.Now(), + TargetNode: targetNode, + } + + hc.mu.Lock() + hc.pending[req.RequestID] = op + hc.mu.Unlock() + + // Build target address (use HTTPPort for Pico channel) + targetAddr := picolib.BuildNodeAddr(targetNode.Node.Addr, targetNode.Node.HTTPPort) + + // Send request synchronously via Pico + resp, err := hc.sendRequestFn(ctx, targetAddr, req) + if err != nil { + // First attempt failed, try retries + resp = &HandoffResponse{ + RequestID: req.RequestID, + Accepted: false, + Reason: err.Error(), + State: HandoffStateFailed, + } + } + + // Retry if needed + for !resp.Accepted && op.RetryCount < hc.config.MaxRetries { + op.RetryCount++ + + // Find new target + newTarget, findErr := hc.findTargetNode(req) + if findErr != nil { + continue + } + op.TargetNode = newTarget + + // Delay before retry + time.Sleep(hc.config.RetryDelay.Duration) + + newAddr := picolib.BuildNodeAddr(newTarget.Node.Addr, newTarget.Node.HTTPPort) + retryResp, retryErr := hc.sendRequestFn(ctx, newAddr, req) + if retryErr != nil { + resp = &HandoffResponse{ + RequestID: req.RequestID, + Accepted: false, + Reason: retryErr.Error(), + State: HandoffStateFailed, + } + continue + } + + resp = retryResp + if resp.Accepted { + break + } + } + + // Clean up + hc.mu.Lock() + delete(hc.pending, req.RequestID) + hc.mu.Unlock() + + // Notify callback + if hc.onHandoffComplete != nil { + go hc.onHandoffComplete(req, resp) + } + + return resp, nil +} + +// HandleIncomingHandoff handles a handoff request received from another node +// (via the Pico channel node request handler). Returns a HandoffResponse. 
+func (hc *HandoffCoordinator) HandleIncomingHandoff(req *HandoffRequest) *HandoffResponse { + // Check if we can handle it + accepted := hc.CanHandle(req.RequiredCap) + response := &HandoffResponse{ + RequestID: req.RequestID, + Accepted: accepted, + NodeID: hc.discovery.localNode.ID, + State: HandoffStateAccepted, + Timestamp: time.Now().UnixNano(), + } + + if !accepted { + response.Reason = "cannot handle (overloaded or missing capability)" + response.State = HandoffStateRejected + } + + // Call custom handler if set + if hc.onHandoffRequest != nil { + response = hc.onHandoffRequest(req) + } + + return response +} + +// findTargetNode finds a suitable target node for handoff. +func (hc *HandoffCoordinator) findTargetNode(req *HandoffRequest) (*NodeWithState, error) { + var candidates []*NodeWithState + + if req.RequiredCap != "" { + // Find nodes with required capability + candidates = hc.membership.SelectByCapability([]string{req.RequiredCap}) + } else { + // Find all available nodes + candidates = hc.membership.GetAvailableMembers() + } + + if len(candidates) == 0 { + return nil, ErrNoHealthyNodes + } + + // Select least loaded node + target := candidates[0] + for _, c := range candidates[1:] { + if c.Node.LoadScore < target.Node.LoadScore { + target = c + } + } + + return target, nil +} + +// SetRequestHandler sets a custom handler for handoff requests. +func (hc *HandoffCoordinator) SetRequestHandler(handler func(*HandoffRequest) *HandoffResponse) { + hc.onHandoffRequest = handler +} + +// SetCompleteHandler sets a callback for handoff completion. +func (hc *HandoffCoordinator) SetCompleteHandler(handler func(*HandoffRequest, *HandoffResponse)) { + hc.onHandoffComplete = handler +} + +// GetPending returns all pending handoff operations. 
+func (hc *HandoffCoordinator) GetPending() []*HandoffOperation { + hc.mu.RLock() + defer hc.mu.RUnlock() + + result := make([]*HandoffOperation, 0, len(hc.pending)) + for _, op := range hc.pending { + result = append(result, op) + } + return result +} diff --git a/pkg/swarm/leader_election.go b/pkg/swarm/leader_election.go new file mode 100644 index 000000000..58dab483d --- /dev/null +++ b/pkg/swarm/leader_election.go @@ -0,0 +1,268 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Swarm mode support for multi-agent coordination +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "context" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// LeaderElection handles leader election using a simple bully algorithm variant. +type LeaderElection struct { + localNodeID string + membership *MembershipManager + config LeaderElectionConfig + + mu sync.RWMutex + currentLeader string + isLeader bool + electionInProgress bool + leaderChangeCh chan string + stopCh chan struct{} +} + +// NewLeaderElection creates a new leader election instance. +func NewLeaderElection(nodeID string, membership *MembershipManager, config LeaderElectionConfig) *LeaderElection { + return &LeaderElection{ + localNodeID: nodeID, + membership: membership, + config: config, + leaderChangeCh: make(chan string, 10), + stopCh: make(chan struct{}), + } +} + +// Start starts the leader election process. +func (le *LeaderElection) Start() { + // Start election checker + go le.electionChecker() + + // Start leader heartbeat monitor + go le.leaderMonitor() +} + +// Stop stops the leader election process. +func (le *LeaderElection) Stop() { + close(le.stopCh) +} + +// IsLeader returns true if this node is the current leader. +func (le *LeaderElection) IsLeader() bool { + le.mu.RLock() + defer le.mu.RUnlock() + return le.isLeader +} + +// GetLeader returns the current leader ID. 
+func (le *LeaderElection) GetLeader() string { + le.mu.RLock() + defer le.mu.RUnlock() + return le.currentLeader +} + +// LeaderChanges returns a channel that receives leader ID changes. +func (le *LeaderElection) LeaderChanges() <-chan string { + return le.leaderChangeCh +} + +// electionChecker periodically checks if we should become leader. +func (le *LeaderElection) electionChecker() { + interval := le.config.ElectionInterval.Duration + if interval <= 0 { + interval = 5 * time.Second + } + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-le.stopCh: + return + case <-ticker.C: + le.checkElection() + } + } +} + +// checkElection runs the leader election algorithm. +func (le *LeaderElection) checkElection() { + le.mu.Lock() + defer le.mu.Unlock() + + members := le.membership.GetMembers() + + // Filter to only alive nodes for leader candidacy + aliveMembers := make([]*NodeWithState, 0, len(members)) + for _, m := range members { + if m.State != nil && m.State.Status == NodeStatusAlive { + aliveMembers = append(aliveMembers, m) + } + } + + if len(aliveMembers) == 0 { + // No alive members in view (including self not yet registered), become leader as fallback + le.becomeLeader() + return + } + + // Find the node with the lowest ID (simple deterministic leader selection) + candidateID := le.localNodeID + for _, m := range aliveMembers { + if m.Node.ID < candidateID { + candidateID = m.Node.ID + } + } + + // Update current leader + if le.currentLeader != candidateID { + le.currentLeader = candidateID + + if candidateID == le.localNodeID { + le.becomeLeader() + } else { + le.becomeFollower() + } + } +} + +// becomeLeader marks this node as the leader. 
+func (le *LeaderElection) becomeLeader() { + if !le.isLeader { + le.isLeader = true + logger.InfoCF("swarm", "This node is now the leader", map[string]any{"node_id": le.localNodeID}) + + // Notify listeners (non-blocking) + select { + case le.leaderChangeCh <- le.localNodeID: + default: + logger.WarnC("swarm", "Leader change notification dropped, channel full") + } + } +} + +// becomeFollower marks this node as a follower. +func (le *LeaderElection) becomeFollower() { + wasLeader := le.isLeader + le.isLeader = false + + if wasLeader { + logger.InfoCF("swarm", "This node is now a follower", map[string]any{ + "node_id": le.localNodeID, + "new_leader": le.currentLeader, + }) + + // Notify listeners of leader change (non-blocking) + select { + case le.leaderChangeCh <- le.currentLeader: + default: + logger.WarnC("swarm", "Leader change notification dropped, channel full") + } + } +} + +// leaderMonitor monitors if the current leader is still alive. +func (le *LeaderElection) leaderMonitor() { + interval := le.config.LeaderHeartbeatTimeout.Duration + if interval <= 0 { + interval = 10 * time.Second + } + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-le.stopCh: + return + case <-ticker.C: + le.monitorLeader() + } + } +} + +// monitorLeader checks if the current leader is still alive. 
+func (le *LeaderElection) monitorLeader() { + le.mu.RLock() + leaderID := le.currentLeader + amLeader := le.isLeader + le.mu.RUnlock() + + if amLeader || leaderID == "" { + return + } + + // Check if leader is still in the membership and healthy + needReelection := false + node, exists := le.membership.GetNode(leaderID) + if !exists { + logger.WarnCF("swarm", "Leader no longer in membership, triggering reelection", + map[string]any{"leader_id": leaderID}) + needReelection = true + } else if node.State != nil && node.State.Status != NodeStatusAlive { + logger.WarnCF("swarm", "Leader is no longer alive, triggering reelection", + map[string]any{"leader_id": leaderID, "status": node.State.Status}) + needReelection = true + } + + if needReelection { + le.mu.Lock() + le.currentLeader = "" + le.mu.Unlock() + le.checkElection() + } +} + +// ElectLeader triggers a new leader election. +func (le *LeaderElection) ElectLeader(ctx context.Context) (string, error) { + le.mu.Lock() + le.currentLeader = "" // Clear current leader to trigger reelection + le.mu.Unlock() + + // Run election immediately + le.checkElection() + + // Wait for new leader + ticker := time.NewTicker(time.Millisecond * 100) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-ticker.C: + le.mu.RLock() + leader := le.currentLeader + le.mu.RUnlock() + + if leader != "" { + return leader, nil + } + } + } +} + +// LeadershipState represents the current leadership state. +type LeadershipState struct { + LeaderID string `json:"leader_id"` + IsLeader bool `json:"is_leader"` + LastChange time.Time `json:"last_change"` + MemberCount int `json:"member_count"` +} + +// GetState returns the current leadership state. 
+func (le *LeaderElection) GetState() LeadershipState { + le.mu.RLock() + defer le.mu.RUnlock() + + return LeadershipState{ + LeaderID: le.currentLeader, + IsLeader: le.isLeader, + MemberCount: len(le.membership.GetMembers()), + } +} diff --git a/pkg/swarm/load_monitor.go b/pkg/swarm/load_monitor.go new file mode 100644 index 000000000..250060d03 --- /dev/null +++ b/pkg/swarm/load_monitor.go @@ -0,0 +1,295 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Swarm mode support for multi-agent coordination +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "runtime" + "sync" + "time" +) + +// LoadMonitor monitors system load and calculates a load score. +type LoadMonitor struct { + config *LoadMonitorConfig + samples []float64 + mu sync.RWMutex + sessionCount int + ticker *time.Ticker + stopChan chan struct{} + onThreshold []func(float64) +} + +// NewLoadMonitor creates a new load monitor. +func NewLoadMonitor(config *LoadMonitorConfig) *LoadMonitor { + if config.SampleSize <= 0 { + config.SampleSize = 60 + } + if config.Interval.Duration <= 0 { + config.Interval = Duration{5 * time.Second} + } + + lm := &LoadMonitor{ + config: config, + samples: make([]float64, 0, config.SampleSize), + stopChan: make(chan struct{}), + onThreshold: make([]func(float64), 0), + } + return lm +} + +// Start begins monitoring load. +func (lm *LoadMonitor) Start() { + if lm.ticker != nil { + return + } + + lm.ticker = time.NewTicker(lm.config.Interval.Duration) + go lm.run() +} + +// Stop stops monitoring load. +func (lm *LoadMonitor) Stop() { + if lm.ticker == nil { + return + } + + lm.ticker.Stop() + close(lm.stopChan) + lm.ticker = nil +} + +// run is the main monitoring loop. 
+func (lm *LoadMonitor) run() { + for { + select { + case <-lm.ticker.C: + score := lm.calculateScore() + lm.addSample(score) + + // Check threshold callbacks + if lm.shouldOffload() { + lm.mu.RLock() + callbacks := make([]func(float64), len(lm.onThreshold)) + copy(callbacks, lm.onThreshold) + lm.mu.RUnlock() + + for _, cb := range callbacks { + go cb(score) + } + } + case <-lm.stopChan: + return + } + } +} + +// LoadMetrics represents current load metrics. +type LoadMetrics struct { + CPUUsage float64 `json:"cpu_usage"` + MemoryUsage float64 `json:"memory_usage"` + ActiveSessions int `json:"active_sessions"` + Goroutines int `json:"goroutines"` + Score float64 `json:"score"` + Timestamp int64 `json:"timestamp"` +} + +// GetCurrentLoad returns the current load metrics. +func (lm *LoadMonitor) GetCurrentLoad() *LoadMetrics { + metrics := &LoadMetrics{ + ActiveSessions: lm.GetSessionCount(), + Goroutines: runtime.NumGoroutine(), + Timestamp: time.Now().UnixNano(), + } + + // Get memory usage + var m runtime.MemStats + runtime.ReadMemStats(&m) + + // Normalize using configured thresholds + maxMem := lm.config.MaxMemoryBytes + if maxMem == 0 { + maxMem = 1024 * 1024 * 1024 // Default 1GB + } + metrics.MemoryUsage = normalizeMemory(m.Alloc, maxMem) + + maxGoroutines := lm.config.MaxGoroutines + if maxGoroutines == 0 { + maxGoroutines = 1000 + } + metrics.CPUUsage = normalizeCPU(metrics.Goroutines, maxGoroutines) + + maxSessions := lm.config.MaxSessions + if maxSessions == 0 { + maxSessions = 100 + } + sessionUsage := normalizeSessions(metrics.ActiveSessions, maxSessions) + + // Calculate weighted score + config := lm.config + metrics.Score = (metrics.CPUUsage * config.CPUWeight) + + (metrics.MemoryUsage * config.MemoryWeight) + + (sessionUsage * config.SessionWeight) + + // Clamp score to [0, 1] + if metrics.Score < 0 { + metrics.Score = 0 + } else if metrics.Score > 1 { + metrics.Score = 1 + } + + return metrics +} + +// calculateScore calculates the current load 
score. +func (lm *LoadMonitor) calculateScore() float64 { + return lm.GetCurrentLoad().Score +} + +// addSample adds a load sample to the history. +func (lm *LoadMonitor) addSample(score float64) { + lm.mu.Lock() + defer lm.mu.Unlock() + + lm.samples = append(lm.samples, score) + if len(lm.samples) > lm.config.SampleSize { + lm.samples = lm.samples[1:] + } +} + +// GetAverageScore returns the average load score over the sample window. +func (lm *LoadMonitor) GetAverageScore() float64 { + lm.mu.RLock() + defer lm.mu.RUnlock() + + if len(lm.samples) == 0 { + return lm.calculateScore() + } + + sum := 0.0 + for _, s := range lm.samples { + sum += s + } + return sum / float64(len(lm.samples)) +} + +// GetSessionCount returns the current number of active sessions. +func (lm *LoadMonitor) GetSessionCount() int { + lm.mu.RLock() + defer lm.mu.RUnlock() + return lm.sessionCount +} + +// SetSessionCount sets the current number of active sessions. +func (lm *LoadMonitor) SetSessionCount(count int) { + lm.mu.Lock() + defer lm.mu.Unlock() + lm.sessionCount = count +} + +// IncrementSessions increments the session count. +func (lm *LoadMonitor) IncrementSessions() { + lm.mu.Lock() + defer lm.mu.Unlock() + lm.sessionCount++ +} + +// DecrementSessions decrements the session count. +func (lm *LoadMonitor) DecrementSessions() { + lm.mu.Lock() + defer lm.mu.Unlock() + if lm.sessionCount > 0 { + lm.sessionCount-- + } +} + +// ShouldOffload returns true if the load is high enough to offload tasks. +func (lm *LoadMonitor) ShouldOffload() bool { + return lm.shouldOffload() +} + +// shouldOffload internal check for offloading. 
+func (lm *LoadMonitor) shouldOffload() bool { + avgScore := lm.GetAverageScore() + currentScore := lm.calculateScore() + + // Use configured offload threshold, or default to 0.8 + threshold := lm.config.OffloadThreshold + if threshold <= 0 { + threshold = 0.8 + } + + // Use a combination of current and average for smoother behavior + combinedScore := (currentScore*0.7 + avgScore*0.3) + return combinedScore > threshold +} + +// OnThreshold registers a callback when the load threshold is exceeded. +func (lm *LoadMonitor) OnThreshold(callback func(float64)) { + lm.mu.Lock() + defer lm.mu.Unlock() + lm.onThreshold = append(lm.onThreshold, callback) +} + +// GetTrend returns the load trend. +func (lm *LoadMonitor) GetTrend() LoadTrend { + lm.mu.RLock() + defer lm.mu.RUnlock() + + if len(lm.samples) < 3 { + return LoadTrendStable + } + + // Simple linear regression to detect trend + n := float64(len(lm.samples)) + sumX := n * (n - 1) / 2 + sumY := 0.0 + sumXY := 0.0 + + for i, s := range lm.samples { + x := float64(i) + sumY += s + sumXY += x * s + } + + slope := (n*sumXY - sumX*sumY) / (n * (n - 1) * (2*n - 1) / 6) + + if slope > TrendIncreasingThreshold { + return LoadTrendIncreasing + } else if slope < TrendDecreasingThreshold { + return LoadTrendDecreasing + } + return LoadTrendStable +} + +// Helper functions for normalization + +func normalizeMemory(alloc uint64, maxMem uint64) float64 { + // Use configured max memory threshold + usage := float64(alloc) / float64(maxMem) + if usage > 1 { + return 1 + } + return usage +} + +func normalizeCPU(goroutines int, maxGoroutines int) float64 { + // Use configured max goroutine threshold + usage := float64(goroutines) / float64(maxGoroutines) + if usage > 1 { + return 1 + } + return usage +} + +func normalizeSessions(sessions int, maxSessions int) float64 { + // Use configured max sessions threshold + usage := float64(sessions) / float64(maxSessions) + if usage > 1 { + return 1 + } + return usage +} diff --git 
a/pkg/swarm/membership.go b/pkg/swarm/membership.go new file mode 100644 index 000000000..ffa04c23c --- /dev/null +++ b/pkg/swarm/membership.go @@ -0,0 +1,365 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Swarm mode support for multi-agent coordination +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "math/rand" + "sync" + "time" +) + +// MembershipManager manages cluster membership. +type MembershipManager struct { + discovery *DiscoveryService + view *ClusterView + config DiscoveryConfig + mu sync.RWMutex + + // Event callbacks + onJoin []func(*NodeInfo) + onLeave []func(*NodeInfo) + onUpdate []func(*NodeInfo) +} + +// NewMembershipManager creates a new membership manager. +func NewMembershipManager(ds *DiscoveryService, config DiscoveryConfig) *MembershipManager { + localNodeID := ds.LocalNode().ID + return &MembershipManager{ + discovery: ds, + view: NewClusterView(localNodeID), + config: config, + onJoin: make([]func(*NodeInfo), 0), + onLeave: make([]func(*NodeInfo), 0), + onUpdate: make([]func(*NodeInfo), 0), + } +} + +// GetNode retrieves a node by ID. +func (mm *MembershipManager) GetNode(nodeID string) (*NodeWithState, bool) { + return mm.view.Get(nodeID) +} + +// GetMembers returns all members. +func (mm *MembershipManager) GetMembers() []*NodeWithState { + mm.mu.RLock() + defer mm.mu.RUnlock() + return mm.view.List() +} + +// GetAliveMembers returns all alive members. +func (mm *MembershipManager) GetAliveMembers() []*NodeWithState { + mm.mu.RLock() + defer mm.mu.RUnlock() + + members := mm.view.GetAliveNodes() + result := make([]*NodeWithState, 0, len(members)) + for _, m := range members { + if m.Node.ID != mm.discovery.LocalNode().ID { + result = append(result, m) + } + } + return result +} + +// GetAvailableMembers returns all available members (alive and not overloaded). 
+func (mm *MembershipManager) GetAvailableMembers() []*NodeWithState { + mm.mu.RLock() + defer mm.mu.RUnlock() + + members := mm.view.GetAvailableNodes() + result := make([]*NodeWithState, 0, len(members)) + for _, m := range members { + if m.Node.ID != mm.discovery.LocalNode().ID { + result = append(result, m) + } + } + return result +} + +// UpdateNode updates or adds a node to the membership. +func (mm *MembershipManager) UpdateNode(node *NodeInfo) *NodeWithState { + mm.mu.Lock() + + existing, existed := mm.view.Get(node.ID) + nws := mm.view.AddOrUpdate(node) + + if !existed { + // New node joined + nws.State.Status = NodeStatusAlive + nws.State.StatusSince = time.Now().UnixNano() + nws.State.LastSeen = time.Now().UnixNano() + + // Notify callbacks + for _, cb := range mm.onJoin { + go cb(node) + } + + // Dispatch event + mm.discovery.eventHandler.Dispatch(&NodeEvent{ + Node: node, + Event: EventJoin, + Time: time.Now().UnixNano(), + }) + } else { + // Existing node updated + if existing.Node.Timestamp < node.Timestamp { + nws.State.LastSeen = time.Now().UnixNano() + + // Mark as alive if was suspect/dead + if nws.State.Status != NodeStatusAlive { + nws.State.UpdateStatus(NodeStatusAlive) + nws.State.PingFailure = 0 + nws.State.PingSuccess++ + } + + // Notify callbacks + for _, cb := range mm.onUpdate { + go cb(node) + } + + // Dispatch event + mm.discovery.eventHandler.Dispatch(&NodeEvent{ + Node: node, + Event: EventUpdate, + Time: time.Now().UnixNano(), + }) + } + } + + mm.mu.Unlock() + return nws +} + +// RemoveNode removes a node from the membership. 
+func (mm *MembershipManager) RemoveNode(nodeID string) { + mm.mu.Lock() + + nws, exists := mm.view.Get(nodeID) + if !exists { + mm.mu.Unlock() + return + } + + mm.view.Remove(nodeID) + + // Notify callbacks + for _, cb := range mm.onLeave { + go cb(nws.Node) + } + + // Dispatch event + mm.discovery.eventHandler.Dispatch(&NodeEvent{ + Node: nws.Node, + Event: EventLeave, + Time: time.Now().UnixNano(), + }) + + mm.mu.Unlock() +} + +// RecordHeartbeat records a heartbeat for a node. +func (mm *MembershipManager) RecordHeartbeat(nodeID string) { + mm.mu.Lock() + defer mm.mu.Unlock() + + nws, exists := mm.view.Get(nodeID) + if !exists { + return + } + + nws.State.LastPing = time.Now().UnixNano() + nws.State.LastSeen = time.Now().UnixNano() + + // Reset failure count and increment success + nws.State.PingFailure = 0 + nws.State.PingSuccess++ + + // Mark as alive if was suspect + if nws.State.Status != NodeStatusAlive { + nws.State.UpdateStatus(NodeStatusAlive) + } +} + +// MarkSuspect marks a node as suspect (possibly dead). +func (mm *MembershipManager) MarkSuspect(nodeID string) { + mm.mu.Lock() + defer mm.mu.Unlock() + + nws, exists := mm.view.Get(nodeID) + if !exists { + return + } + + if nws.State.Status == NodeStatusAlive { + nws.State.UpdateStatus(NodeStatusSuspect) + nws.State.PingFailure++ + } +} + +// MarkDead marks a node as dead. +func (mm *MembershipManager) MarkDead(nodeID string) { + mm.mu.Lock() + defer mm.mu.Unlock() + + nws, exists := mm.view.Get(nodeID) + if !exists { + return + } + + if nws.State.Status != NodeStatusDead { + nws.State.UpdateStatus(NodeStatusDead) + + // Remove from view after a delay + go func() { + time.Sleep(mm.config.DeadNodeTimeout.Duration) + mm.RemoveNode(nodeID) + }() + } +} + +// CheckHealth checks the health of all members and marks dead nodes. 
+func (mm *MembershipManager) CheckHealth() { + mm.mu.RLock() + members := mm.view.List() + nodeTimeout := mm.config.NodeTimeout.Duration + deadTimeout := mm.config.DeadNodeTimeout.Duration + localNodeID := mm.discovery.LocalNode().ID + mm.mu.RUnlock() + + now := time.Now() + + for _, m := range members { + // Skip local node + if m.Node.ID == localNodeID { + continue + } + + lastSeen := time.Unix(0, m.State.LastSeen) + age := now.Sub(lastSeen) + + switch m.State.Status { + case NodeStatusAlive: + if age > nodeTimeout { + mm.MarkSuspect(m.Node.ID) + } + case NodeStatusSuspect: + if age > deadTimeout { + mm.MarkDead(m.Node.ID) + } + } + } +} + +// SelectByCapability selects members that have the required capabilities. +func (mm *MembershipManager) SelectByCapability(requiredCaps []string) []*NodeWithState { + mm.mu.RLock() + defer mm.mu.RUnlock() + + if len(requiredCaps) == 0 { + return mm.GetAvailableMembers() + } + + members := mm.view.GetAvailableNodes() + result := make([]*NodeWithState, 0) + + for _, m := range members { + if m.Node.ID == mm.discovery.LocalNode().ID { + continue + } + + // Check if node has all required capabilities + hasAll := true + for _, cap := range requiredCaps { + found := false + for _, nodeCap := range m.Node.AgentCaps { + if nodeCap == cap { + found = true + break + } + } + if !found { + hasAll = false + break + } + } + + if hasAll { + result = append(result, m) + } + } + + return result +} + +// SelectLeastLoaded selects the member with the lowest load score. +func (mm *MembershipManager) SelectLeastLoaded() *NodeWithState { + members := mm.GetAvailableMembers() + if len(members) == 0 { + return nil + } + + least := members[0] + for _, m := range members[1:] { + if m.Node.LoadScore < least.Node.LoadScore { + least = m + } + } + + return least +} + +// SelectRandom selects a random available member. 
+func (mm *MembershipManager) SelectRandom() *NodeWithState { + members := mm.GetAvailableMembers() + if len(members) == 0 { + return nil + } + + // Use crypto/rand for better random distribution + idx := rand.Intn(len(members)) + return members[idx] +} + +// GetClusterSize returns the current cluster size. +func (mm *MembershipManager) GetClusterSize() int { + mm.mu.RLock() + defer mm.mu.RUnlock() + return mm.view.Size +} + +// OnJoin registers a callback for node join events. +func (mm *MembershipManager) OnJoin(callback func(*NodeInfo)) { + mm.mu.Lock() + defer mm.mu.Unlock() + mm.onJoin = append(mm.onJoin, callback) +} + +// OnLeave registers a callback for node leave events. +func (mm *MembershipManager) OnLeave(callback func(*NodeInfo)) { + mm.mu.Lock() + defer mm.mu.Unlock() + mm.onLeave = append(mm.onLeave, callback) +} + +// OnUpdate registers a callback for node update events. +func (mm *MembershipManager) OnUpdate(callback func(*NodeInfo)) { + mm.mu.Lock() + defer mm.mu.Unlock() + mm.onUpdate = append(mm.onUpdate, callback) +} + +// StartHealthCheck starts the health check routine. +func (mm *MembershipManager) StartHealthCheck(interval time.Duration) { + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for range ticker.C { + mm.CheckHealth() + } + }() +} diff --git a/pkg/swarm/metrics.go b/pkg/swarm/metrics.go new file mode 100644 index 000000000..f795a553f --- /dev/null +++ b/pkg/swarm/metrics.go @@ -0,0 +1,286 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Swarm mode support for multi-agent coordination +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "encoding/json" + "fmt" + "sync" + "sync/atomic" + "time" +) + +// MetricsCollector collects and exports metrics for the swarm cluster. 
+type MetricsCollector struct { + mu sync.RWMutex + + // Counters (atomic for performance) + messagesSent atomic.Int64 + messagesReceived atomic.Int64 + handoffsInitiated atomic.Int64 + handoffsAccepted atomic.Int64 + handoffsRejected atomic.Int64 + handoffsFailed atomic.Int64 + electionsWon atomic.Int64 + + // Gauges (use atomic.Value for float64) + currentLoadScore atomic.Value // float64 + activeSessions atomic.Int64 + memberCount atomic.Int32 + + // Histogram data (simplified) + latencyBuckets map[string]*LatencyBucket + + startTime time.Time +} + +// LatencyBucket tracks latency distribution. +type LatencyBucket struct { + mu sync.RWMutex + count int64 + sum int64 + buckets [12]int64 // 0-1ms, 1-2ms, 2-5ms, 5-10ms, 10-20ms, 20-50ms, 50-100ms, 100-200ms, 200-500ms, 500ms-1s, 1-2s, 2s+ +} + +// NewMetricsCollector creates a new metrics collector. +func NewMetricsCollector() *MetricsCollector { + mc := &MetricsCollector{ + latencyBuckets: make(map[string]*LatencyBucket), + startTime: time.Now(), + } + return mc +} + +// Counter methods + +// MessagesSent increments the sent message counter. +func (m *MetricsCollector) MessagesSent(n int64) { + m.messagesSent.Add(n) +} + +// MessagesReceived increments the received message counter. +func (m *MetricsCollector) MessagesReceived(n int64) { + m.messagesReceived.Add(n) +} + +// HandoffInitiated increments the handoff initiated counter. +func (m *MetricsCollector) HandoffInitiated() { + m.handoffsInitiated.Add(1) +} + +// HandoffAccepted increments the handoff accepted counter. +func (m *MetricsCollector) HandoffAccepted() { + m.handoffsAccepted.Add(1) +} + +// HandoffRejected increments the handoff rejected counter. +func (m *MetricsCollector) HandoffRejected() { + m.handoffsRejected.Add(1) +} + +// HandoffFailed increments the handoff failed counter. +func (m *MetricsCollector) HandoffFailed() { + m.handoffsFailed.Add(1) +} + +// ElectionWon increments the elections won counter. 
+func (m *MetricsCollector) ElectionWon() { + m.electionsWon.Add(1) +} + +// Gauge methods + +// SetLoadScore sets the current load score. +func (m *MetricsCollector) SetLoadScore(score float64) { + m.currentLoadScore.Store(score) +} + +// SetActiveSessions sets the current active session count. +func (m *MetricsCollector) SetActiveSessions(count int64) { + m.activeSessions.Store(count) +} + +// SetMemberCount sets the current cluster member count. +func (m *MetricsCollector) SetMemberCount(count int32) { + m.memberCount.Store(count) +} + +// RecordLatency records a latency observation for the given operation. +func (m *MetricsCollector) RecordLatency(operation string, latency time.Duration) { + m.mu.Lock() + if m.latencyBuckets[operation] == nil { + m.latencyBuckets[operation] = &LatencyBucket{} + } + bucket := m.latencyBuckets[operation] + m.mu.Unlock() + + ms := latency.Milliseconds() + + bucket.mu.Lock() + bucket.count++ + bucket.sum += ms + + // Bucket the latency + switch { + case ms < 1: + bucket.buckets[0]++ + case ms < 2: + bucket.buckets[1]++ + case ms < 5: + bucket.buckets[2]++ + case ms < 10: + bucket.buckets[3]++ + case ms < 20: + bucket.buckets[4]++ + case ms < 50: + bucket.buckets[5]++ + case ms < 100: + bucket.buckets[6]++ + case ms < 200: + bucket.buckets[7]++ + case ms < 500: + bucket.buckets[8]++ + case ms < 1000: + bucket.buckets[9]++ + case ms < 2000: + bucket.buckets[10]++ + default: + bucket.buckets[11]++ + } + bucket.mu.Unlock() +} + +// GetMetrics returns the current metrics as a map. 
+func (m *MetricsCollector) GetMetrics() map[string]any {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+
+	latency := make(map[string]any)
+	for name, bucket := range m.latencyBuckets {
+		bucket.mu.RLock()
+		// Guard the average: RecordLatency releases m.mu before it
+		// increments count, so a freshly registered bucket can be
+		// observed here with count == 0 (which would yield NaN).
+		avg := 0.0
+		if bucket.count > 0 {
+			avg = float64(bucket.sum) / float64(bucket.count)
+		}
+		latency[name] = map[string]any{
+			"count":  bucket.count,
+			"avg_ms": avg,
+			"p50_ms": m.percentile(bucket, 0.50),
+			"p95_ms": m.percentile(bucket, 0.95),
+			"p99_ms": m.percentile(bucket, 0.99),
+		}
+		bucket.mu.RUnlock()
+	}
+
+	// atomic.Value.Load returns nil before the first SetLoadScore;
+	// coerce to float64 so consumers always receive a number.
+	loadScore, _ := m.currentLoadScore.Load().(float64)
+
+	return map[string]any{
+		// Counters
+		"messages_sent":      m.messagesSent.Load(),
+		"messages_received":  m.messagesReceived.Load(),
+		"handoffs_initiated": m.handoffsInitiated.Load(),
+		"handoffs_accepted":  m.handoffsAccepted.Load(),
+		"handoffs_rejected":  m.handoffsRejected.Load(),
+		"handoffs_failed":    m.handoffsFailed.Load(),
+		"elections_won":      m.electionsWon.Load(),
+
+		// Gauges
+		"load_score":      loadScore,
+		"active_sessions": m.activeSessions.Load(),
+		"member_count":    m.memberCount.Load(),
+
+		// System info
+		"uptime_seconds": time.Since(m.startTime).Seconds(),
+
+		// Latency histograms
+		"latency_ms": latency,
+	}
+}
+
+// percentile calculates an approximate percentile from the bucket data.
+// The caller must hold bucket.mu (read or write).
+func (m *MetricsCollector) percentile(bucket *LatencyBucket, p float64) float64 {
+	if bucket.count == 0 {
+		return 0
+	}
+
+	// Rank of the target observation; clamp to at least 1 so that tiny
+	// counts (count*p < 1 => target 0) do not match an empty leading
+	// bucket and wrongly report the smallest upper bound.
+	target := int64(float64(bucket.count) * p)
+	if target < 1 {
+		target = 1
+	}
+	cumulative := int64(0)
+
+	// Upper bounds for each bucket in ms
+	upperBounds := []int64{1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 1 << 62}
+
+	for i, count := range bucket.buckets {
+		cumulative += count
+		if cumulative >= target {
+			// Return the bucket's upper bound as the approximation.
+			return float64(upperBounds[i])
+		}
+	}
+
+	return 2000.0 // default max
+}
+
+// ExportJSON exports metrics as JSON.
+func (m *MetricsCollector) ExportJSON() ([]byte, error) {
+	return json.MarshalIndent(m.GetMetrics(), "", "  ")
+}
+
+// ExportPrometheus exports metrics in Prometheus text format. 
+func (m *MetricsCollector) ExportPrometheus() string { + metrics := m.GetMetrics() + var out string + + // Counters as Prometheus counters + out += "# TYPE picoclaw_messages_sent counter\n" + out += fmt.Sprintf("picoclaw_messages_sent %d\n", metrics["messages_sent"]) + + out += "\n# TYPE picoclaw_messages_received counter\n" + out += fmt.Sprintf("picoclaw_messages_received %d\n", metrics["messages_received"]) + + out += "\n# TYPE picoclaw_handoffs_initiated counter\n" + out += fmt.Sprintf("picoclaw_handoffs_initiated %d\n", metrics["handoffs_initiated"]) + + out += "\n# TYPE picoclaw_handoffs_accepted counter\n" + out += fmt.Sprintf("picoclaw_handoffs_accepted %d\n", metrics["handoffs_accepted"]) + + out += "\n# TYPE picoclaw_handoffs_rejected counter\n" + out += fmt.Sprintf("picoclaw_handoffs_rejected %d\n", metrics["handoffs_rejected"]) + + out += "\n# TYPE picoclaw_handoffs_failed counter\n" + out += fmt.Sprintf("picoclaw_handoffs_failed %d\n", metrics["handoffs_failed"]) + + out += "\n# TYPE picoclaw_elections_won counter\n" + out += fmt.Sprintf("picoclaw_elections_won %d\n", metrics["elections_won"]) + + // Gauges as Prometheus gauges + out += "\n# TYPE picoclaw_load_score gauge\n" + out += fmt.Sprintf("picoclaw_load_score %.2f\n", metrics["load_score"]) + + out += "\n# TYPE picoclaw_active_sessions gauge\n" + out += fmt.Sprintf("picoclaw_active_sessions %d\n", metrics["active_sessions"]) + + out += "\n# TYPE picoclaw_member_count gauge\n" + out += fmt.Sprintf("picoclaw_member_count %d\n", metrics["member_count"]) + + out += "\n# TYPE picoclaw_uptime_seconds gauge\n" + out += fmt.Sprintf("picoclaw_uptime_seconds %.0f\n", metrics["uptime_seconds"]) + + return out +} + +// Reset resets all metrics (useful for testing). 
+func (m *MetricsCollector) Reset() {
+	m.messagesSent.Store(0)
+	m.messagesReceived.Store(0)
+	m.handoffsInitiated.Store(0)
+	m.handoffsAccepted.Store(0)
+	m.handoffsRejected.Store(0)
+	m.handoffsFailed.Store(0)
+	m.electionsWon.Store(0)
+	// atomic.Value panics if successive Stores use different concrete
+	// types; SetLoadScore stores float64, so the reset value must be a
+	// float64 as well (a bare 0 literal would store an int and panic).
+	m.currentLoadScore.Store(float64(0))
+	m.activeSessions.Store(0)
+	m.memberCount.Store(0)
+
+	m.mu.Lock()
+	m.latencyBuckets = make(map[string]*LatencyBucket)
+	// startTime is read by GetMetrics while holding mu, so it must be
+	// written under the same lock to avoid a data race.
+	m.startTime = time.Now()
+	m.mu.Unlock()
+}
diff --git a/pkg/swarm/node.go b/pkg/swarm/node.go
new file mode 100644
index 000000000..b92bf008d
--- /dev/null
+++ b/pkg/swarm/node.go
@@ -0,0 +1,330 @@
+// PicoClaw - Ultra-lightweight personal AI agent
+// Swarm mode support for multi-agent coordination
+// License: MIT
+//
+// Copyright (c) 2026 PicoClaw contributors
+
+package swarm
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// NodeInfo represents a node in the swarm cluster.
+type NodeInfo struct {
+	ID        string            `json:"id"`         // Unique node identifier
+	Addr      string            `json:"addr"`       // Listening address
+	Port      int               `json:"port"`       // RPC port
+	AgentCaps map[string]string `json:"agent_caps"` // Agent capabilities {agent_id: capability}
+	LoadScore float64           `json:"load_score"` // Load score 0-1
+	Labels    map[string]string `json:"labels"`     // Custom labels
+	HTTPPort  int               `json:"http_port"`  // HTTP/gateway port (for Pico channel)
+	Timestamp int64             `json:"timestamp"`  // Last update time (Unix nano)
+	Version   string            `json:"version"`    // PicoClaw version
+}
+
+// IsAlive checks if the node is considered alive based on timestamp.
+func (n *NodeInfo) IsAlive(timeout time.Duration) bool {
+	if n.Timestamp == 0 {
+		return false
+	}
+	age := time.Since(time.Unix(0, n.Timestamp))
+	return age < timeout
+}
+
+// String returns a JSON representation of the node.
+func (n *NodeInfo) String() string {
+	data, _ := json.Marshal(n) // NodeInfo contains only marshalable fields
+	return string(data)
+}
+
+// GetAddress returns the full address (host:port) for RPC communication. 
+func (n *NodeInfo) GetAddress() string { + if n.Port > 0 { + return fmt.Sprintf("%s:%d", n.Addr, n.Port) + } + return n.Addr +} + +// NodeStatus represents the current status of a node. +type NodeStatus string + +const ( + NodeStatusAlive NodeStatus = "alive" + NodeStatusSuspect NodeStatus = "suspect" + NodeStatusDead NodeStatus = "dead" + NodeStatusLeft NodeStatus = "left" +) + +// NodeState represents the state of a node in the membership view. +type NodeState struct { + Node *NodeInfo `json:"node"` + Status NodeStatus `json:"status"` + StatusSince int64 `json:"status_since"` // Unix nano when status was set + LastSeen int64 `json:"last_seen"` // Unix nano of last sighting + LastPing int64 `json:"last_ping"` // Unix nano of last successful ping + PingSuccess int `json:"ping_success"` // Consecutive successful pings + PingFailure int `json:"ping_failure"` // Consecutive failed pings +} + +// IsAvailable returns true if the node is available for handoff. +func (ns *NodeState) IsAvailable() bool { + return ns.Status == NodeStatusAlive && ns.Node.LoadScore < DefaultAvailableLoadThreshold +} + +// UpdateStatus updates the node status with timestamp. +func (ns *NodeState) UpdateStatus(status NodeStatus) { + ns.Status = status + ns.StatusSince = time.Now().UnixNano() +} + +// NodeEvent represents a node state change event. +type NodeEvent struct { + Node *NodeInfo `json:"node"` + Event EventType `json:"event"` + Time int64 `json:"time"` +} + +// EventType represents the type of node event. +type EventType string + +const ( + EventJoin EventType = "join" + EventLeave EventType = "leave" + EventUpdate EventType = "update" +) + +// EventHandler is a callback function for node events. +type EventHandler func(*NodeEvent) + +// EventHandlerID is a unique identifier for a subscribed handler. +type EventHandlerID int + +// EventDispatcher manages event handlers. 
+type EventDispatcher struct { + handlers []EventHandler + mu sync.RWMutex + nextID EventHandlerID + ids map[EventHandlerID]int // handler ID -> index in handlers slice +} + +// NewEventDispatcher creates a new event dispatcher. +func NewEventDispatcher() *EventDispatcher { + return &EventDispatcher{ + handlers: make([]EventHandler, 0), + ids: make(map[EventHandlerID]int), + nextID: 1, + } +} + +// Subscribe adds a new event handler and returns its ID. +func (ed *EventDispatcher) Subscribe(handler EventHandler) EventHandlerID { + ed.mu.Lock() + defer ed.mu.Unlock() + + id := ed.nextID + ed.nextID++ + + ed.handlers = append(ed.handlers, handler) + ed.ids[id] = len(ed.handlers) - 1 + return id +} + +// Unsubscribe removes an event handler by ID. +func (ed *EventDispatcher) Unsubscribe(id EventHandlerID) { + ed.mu.Lock() + defer ed.mu.Unlock() + + idx, ok := ed.ids[id] + if !ok { + return + } + + // Remove handler + ed.handlers = append(ed.handlers[:idx], ed.handlers[idx+1:]...) + + // Update indices + delete(ed.ids, id) + for handlerID, handlerIdx := range ed.ids { + if handlerIdx > idx { + ed.ids[handlerID] = handlerIdx - 1 + } + } +} + +// Dispatch sends an event to all registered handlers. +func (ed *EventDispatcher) Dispatch(event *NodeEvent) { + ed.DispatchContext(event, nil) +} + +// DispatchContext sends an event to all registered handlers with context cancellation support. 
+func (ed *EventDispatcher) DispatchContext(event *NodeEvent, ctx context.Context) { + ed.mu.RLock() + handlers := make([]EventHandler, len(ed.handlers)) + copy(handlers, ed.handlers) + ed.mu.RUnlock() + + for _, handler := range handlers { + // Run handlers in goroutines to avoid blocking + go func(h EventHandler) { + defer func() { + if r := recover(); r != nil { + logger.ErrorCF("swarm", "handler panic recovered", map[string]any{"panic": r}) + } + }() + + // Check if context is cancelled + if ctx != nil { + select { + case <-ctx.Done(): + logger.DebugC("swarm", "handler skipped due to context cancellation") + return + default: + } + } + + h(event) + }(handler) + } +} + +// NodeStats tracks statistics about a node. +type NodeStats struct { + MessagesSent int64 `json:"messages_sent"` + MessagesReceived int64 `json:"messages_received"` + HandoffsAccepted int `json:"handoffs_accepted"` + HandoffsInitiated int `json:"handoffs_initiated"` + LastError string `json:"last_error,omitempty"` + LastErrorTime time.Time `json:"last_error_time,omitempty"` + UptimeStart time.Time `json:"uptime_start"` +} + +// NodeWithState combines a node with its state and stats. +type NodeWithState struct { + Node *NodeInfo `json:"node"` + State *NodeState `json:"state"` + Stats *NodeStats `json:"stats,omitempty"` +} + +// IsAvailable returns true if the node is available for handoff. +func (nws *NodeWithState) IsAvailable() bool { + if nws.State == nil || nws.Node == nil { + return false + } + return nws.State.Status == NodeStatusAlive && nws.Node.LoadScore < DefaultAvailableLoadThreshold +} + +// ClusterView represents the current view of the cluster. +type ClusterView struct { + Nodes map[string]*NodeWithState `json:"nodes"` + LocalNodeID string `json:"local_node_id"` + Size int `json:"size"` + Version int64 `json:"version"` // View version for conflict detection + mu sync.RWMutex +} + +// NewClusterView creates a new cluster view. 
+func NewClusterView(localNodeID string) *ClusterView {
+	return &ClusterView{
+		Nodes:       make(map[string]*NodeWithState),
+		LocalNodeID: localNodeID,
+		Version:     time.Now().UnixNano(),
+	}
+}
+
+// AddOrUpdate adds or updates a node in the view.
+func (cv *ClusterView) AddOrUpdate(node *NodeInfo) *NodeWithState {
+	cv.mu.Lock()
+	defer cv.mu.Unlock()
+
+	cv.Version++
+	now := time.Now().UnixNano()
+
+	existing, ok := cv.Nodes[node.ID]
+	if ok {
+		// Update existing node. Keep the state's embedded node pointer
+		// in sync (previously it kept referencing the stale NodeInfo)
+		// and refresh LastSeen, since an update is itself a sighting.
+		existing.Node = node
+		existing.State.Node = node
+		existing.State.LastSeen = now
+		return existing
+	}
+
+	// Add new node
+	nws := &NodeWithState{
+		Node: node,
+		State: &NodeState{
+			Node:        node,
+			Status:      NodeStatusAlive,
+			StatusSince: now,
+			LastSeen:    now,
+		},
+		Stats: &NodeStats{
+			UptimeStart: time.Now(),
+		},
+	}
+	cv.Nodes[node.ID] = nws
+	cv.Size = len(cv.Nodes)
+	return nws
+}
+
+// Remove removes a node from the view.
+func (cv *ClusterView) Remove(nodeID string) {
+	cv.mu.Lock()
+	defer cv.mu.Unlock()
+
+	cv.Version++
+	delete(cv.Nodes, nodeID)
+	cv.Size = len(cv.Nodes)
+}
+
+// Get retrieves a node from the view.
+func (cv *ClusterView) Get(nodeID string) (*NodeWithState, bool) {
+	cv.mu.RLock()
+	defer cv.mu.RUnlock()
+
+	nws, ok := cv.Nodes[nodeID]
+	return nws, ok
+}
+
+// List returns all nodes in the view.
+func (cv *ClusterView) List() []*NodeWithState {
+	cv.mu.RLock()
+	defer cv.mu.RUnlock()
+
+	result := make([]*NodeWithState, 0, len(cv.Nodes))
+	for _, nws := range cv.Nodes {
+		result = append(result, nws)
+	}
+	return result
+}
+
+// GetAliveNodes returns all alive nodes.
+func (cv *ClusterView) GetAliveNodes() []*NodeWithState {
+	cv.mu.RLock()
+	defer cv.mu.RUnlock()
+
+	result := make([]*NodeWithState, 0)
+	for _, nws := range cv.Nodes {
+		if nws.State.Status == NodeStatusAlive {
+			result = append(result, nws)
+		}
+	}
+	return result
+}
+
+// GetAvailableNodes returns all available nodes (alive and not overloaded). 
+func (cv *ClusterView) GetAvailableNodes() []*NodeWithState { + cv.mu.RLock() + defer cv.mu.RUnlock() + + result := make([]*NodeWithState, 0) + for _, nws := range cv.Nodes { + if nws.IsAvailable() { + result = append(result, nws) + } + } + return result +} diff --git a/pkg/swarm/node_payload.go b/pkg/swarm/node_payload.go new file mode 100644 index 000000000..eaee256c7 --- /dev/null +++ b/pkg/swarm/node_payload.go @@ -0,0 +1,22 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Swarm mode support for multi-agent coordination +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import picolib "github.com/sipeed/picoclaw/pkg/pico" + +// NewHandoffRequestPayload creates a NodePayload for a "handoff_request" action. +func NewHandoffRequestPayload(req *HandoffRequest) picolib.NodePayload { + return picolib.NodePayload{ + picolib.PayloadKeyAction: picolib.NodeActionHandoffRequest, + picolib.PayloadKeyRequest: req, + } +} + +// HandoffResponseReply creates a reply payload carrying a HandoffResponse. +func HandoffResponseReply(resp *HandoffResponse) picolib.NodePayload { + return picolib.NodePayload{picolib.PayloadKeyHandoffResp: resp} +} diff --git a/pkg/swarm/security.go b/pkg/swarm/security.go new file mode 100644 index 000000000..ab493b807 --- /dev/null +++ b/pkg/swarm/security.go @@ -0,0 +1,133 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Swarm mode support for multi-agent coordination +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "time" +) + +// AuthProvider handles authentication for swarm nodes. +// Uses HMAC-based shared secret authentication. +type AuthProvider struct { + sharedSecret []byte + nodeID string +} + +// NewAuthProvider creates a new authentication provider. 
+func NewAuthProvider(nodeID, sharedSecret string) *AuthProvider { + return &AuthProvider{ + sharedSecret: []byte(sharedSecret), + nodeID: nodeID, + } +} + +// SignMessage signs a message with HMAC-SHA256. +// The signature is base64 encoded for JSON transmission. +func (a *AuthProvider) SignMessage(msg any) (string, error) { + if a.sharedSecret == nil { + return "", ErrAuthenticationFailed + } + + // Serialize message to JSON + data, err := json.Marshal(msg) + if err != nil { + return "", fmt.Errorf("failed to marshal message: %w", err) + } + + // Calculate HMAC + h := hmac.New(sha256.New, a.sharedSecret) + h.Write(data) + signature := h.Sum(nil) + + // Return base64 encoded signature + return base64.StdEncoding.EncodeToString(signature), nil +} + +// VerifySignature verifies a message signature. +func (a *AuthProvider) VerifySignature(msg any, signature string) bool { + if a.sharedSecret == nil { + return false + } + + // Calculate expected signature + expected, err := a.SignMessage(msg) + if err != nil { + return false + } + + // Compare signatures + return hmac.Equal([]byte(expected), []byte(signature)) +} + +// GetNodeID returns the node ID for this auth provider. +func (a *AuthProvider) GetNodeID() string { + return a.nodeID +} + +// AuthToken represents an authentication token. +type AuthToken struct { + NodeID string `json:"node_id"` + Signature string `json:"signature"` + Timestamp int64 `json:"timestamp"` +} + +// GenerateToken creates an auth token for the given node. +func (a *AuthProvider) GenerateToken() (*AuthToken, error) { + token := &AuthToken{ + NodeID: a.nodeID, + Timestamp: time.Now().UnixNano(), + } + + signature, err := a.SignMessage(token) + if err != nil { + return nil, err + } + + token.Signature = signature + return token, nil +} + +// VerifyToken verifies an auth token. 
+func (a *AuthProvider) VerifyToken(token *AuthToken) bool { + if token == nil { + return false + } + + // Check token age (reject tokens older than 1 minute) + age := time.Since(time.Unix(0, token.Timestamp)) + if age > time.Minute { + return false + } + + // Clear signature before verification so the HMAC input matches + // what GenerateToken computed (with Signature as empty string). + sig := token.Signature + token.Signature = "" + result := a.VerifySignature(token, sig) + token.Signature = sig + return result +} + +// AuthenticateNode verifies that a node is allowed to join. +func (a *AuthProvider) AuthenticateNode(nodeID, signature string, challengeData any) bool { + if a.sharedSecret == nil { + // No authentication configured - fail closed + return false + } + + challenge := map[string]any{ + "node_id": nodeID, + "data": challengeData, + } + + return a.VerifySignature(challenge, signature) +} diff --git a/pkg/swarm/swarm_test.go b/pkg/swarm/swarm_test.go new file mode 100644 index 000000000..401958838 --- /dev/null +++ b/pkg/swarm/swarm_test.go @@ -0,0 +1,584 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Swarm mode support for multi-agent coordination +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package swarm + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNodeInfo(t *testing.T) { + node := &NodeInfo{ + ID: "test-node-1", + Addr: "192.168.1.100", + Port: 7947, + LoadScore: 0.5, + AgentCaps: map[string]string{ + "agent-1": "general", + }, + Labels: map[string]string{ + "region": "us-west", + }, + Timestamp: time.Now().UnixNano(), + } + + t.Run("IsAlive", func(t *testing.T) { + assert.True(t, node.IsAlive(time.Minute)) + assert.False(t, node.IsAlive(time.Nanosecond)) + }) + + t.Run("GetAddress", func(t *testing.T) { + addr := node.GetAddress() + assert.Equal(t, "192.168.1.100:7947", addr) + }) +} + +func TestClusterView(t *testing.T) { + view := 
NewClusterView("local-node") + + t.Run("AddOrUpdate", func(t *testing.T) { + node := &NodeInfo{ + ID: "node-1", + Addr: "192.168.1.1", + Port: 7947, + LoadScore: 0.3, + Timestamp: time.Now().UnixNano(), + } + + nws := view.AddOrUpdate(node) + require.NotNil(t, nws) + assert.Equal(t, node.ID, nws.Node.ID) + assert.Equal(t, 1, view.Size) + }) + + t.Run("Get", func(t *testing.T) { + node, ok := view.Get("node-1") + assert.True(t, ok) + assert.Equal(t, "node-1", node.Node.ID) + + _, ok = view.Get("non-existent") + assert.False(t, ok) + }) + + t.Run("GetAliveNodes", func(t *testing.T) { + nodes := view.GetAliveNodes() + assert.Equal(t, 1, len(nodes)) + }) + + t.Run("GetAvailableNodes", func(t *testing.T) { + nodes := view.GetAvailableNodes() + assert.Equal(t, 1, len(nodes)) // 0.3 < 0.9 + }) + + t.Run("Remove", func(t *testing.T) { + view.Remove("node-1") + assert.Equal(t, 0, view.Size) + }) +} + +func TestLoadMonitor(t *testing.T) { + config := &LoadMonitorConfig{ + Enabled: true, + Interval: Duration{time.Second}, + SampleSize: 10, + CPUWeight: 0.3, + MemoryWeight: 0.3, + SessionWeight: 0.4, + } + + monitor := NewLoadMonitor(config) + + t.Run("GetCurrentLoad", func(t *testing.T) { + metrics := monitor.GetCurrentLoad() + assert.NotNil(t, metrics) + assert.GreaterOrEqual(t, metrics.Score, 0.0) + assert.LessOrEqual(t, metrics.Score, 1.0) + assert.GreaterOrEqual(t, metrics.ActiveSessions, 0) + }) + + t.Run("SessionCount", func(t *testing.T) { + monitor.SetSessionCount(5) + assert.Equal(t, 5, monitor.GetSessionCount()) + + monitor.IncrementSessions() + assert.Equal(t, 6, monitor.GetSessionCount()) + + monitor.DecrementSessions() + assert.Equal(t, 5, monitor.GetSessionCount()) + }) + + t.Run("GetAverageScore", func(t *testing.T) { + avg := monitor.GetAverageScore() + assert.GreaterOrEqual(t, avg, 0.0) + assert.LessOrEqual(t, avg, 1.0) + }) +} + +func TestEventDispatcher(t *testing.T) { + ed := NewEventDispatcher() + + t.Run("SubscribeDispatch", func(t *testing.T) { + 
received := make(chan *NodeEvent, 1) + + id := ed.Subscribe(func(event *NodeEvent) { + received <- event + }) + + event := &NodeEvent{ + Node: &NodeInfo{ID: "test-node"}, + Event: EventJoin, + Time: time.Now().UnixNano(), + } + + ed.Dispatch(event) + + select { + case <-received: + // Event received + case <-time.After(time.Second): + t.Fatal("Event not received") + } + + ed.Unsubscribe(id) + }) + + t.Run("Unsubscribe", func(t *testing.T) { + received := make(chan *NodeEvent, 1) + + id := ed.Subscribe(func(event *NodeEvent) { + received <- event + }) + + ed.Unsubscribe(id) + + event := &NodeEvent{ + Node: &NodeInfo{ID: "test-node"}, + Event: EventJoin, + Time: time.Now().UnixNano(), + } + + ed.Dispatch(event) + + select { + case <-received: + t.Fatal("Should not receive event after unsubscribe") + case <-time.After(100 * time.Millisecond): + // Expected - no event received + } + }) +} + +func TestNodeWithState(t *testing.T) { + node := &NodeInfo{ + ID: "test-node", + Addr: "192.168.1.1", + Port: 7947, + LoadScore: 0.5, + Timestamp: time.Now().UnixNano(), + } + + nws := &NodeWithState{ + Node: node, + State: &NodeState{ + Status: NodeStatusAlive, + StatusSince: time.Now().UnixNano(), + LastSeen: time.Now().UnixNano(), + }, + } + + t.Run("IsAvailable", func(t *testing.T) { + assert.True(t, nws.IsAvailable()) + + // High load + nws.Node.LoadScore = 0.95 + assert.False(t, nws.IsAvailable()) + + // Not alive + nws.Node.LoadScore = 0.5 + nws.State.Status = NodeStatusDead + assert.False(t, nws.IsAvailable()) + }) +} + +func TestDuration(t *testing.T) { + t.Run("UnmarshalJSON from string", func(t *testing.T) { + d := Duration{} + err := d.UnmarshalJSON([]byte(`"5s"`)) + require.NoError(t, err) + assert.Equal(t, 5*time.Second, d.Duration) + }) + + t.Run("MarshalJSON", func(t *testing.T) { + d := Duration{5 * time.Second} + data, err := d.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, []byte(`"5s"`), data) + }) +} + +// Integration tests for discovery and gossip 
protocol + +func TestDiscoveryServiceNodeDiscovery(t *testing.T) { + t.Run("TwoNodesDiscoverEachOther", func(t *testing.T) { + // Create first node + cfg1 := &Config{ + NodeID: "node-1", + BindAddr: "127.0.0.1", + BindPort: 17946, + RPC: RPCConfig{ + Port: 17947, + }, + Discovery: DiscoveryConfig{ + GossipInterval: Duration{100 * time.Millisecond}, + NodeTimeout: Duration{500 * time.Millisecond}, + DeadNodeTimeout: Duration{2 * time.Second}, + }, + } + + ds1, err := NewDiscoveryService(cfg1) + require.NoError(t, err) + defer ds1.Stop() + + err = ds1.Start() + require.NoError(t, err) + + // Create second node + cfg2 := &Config{ + NodeID: "node-2", + BindAddr: "127.0.0.1", + BindPort: 17948, + RPC: RPCConfig{ + Port: 17949, + }, + Discovery: DiscoveryConfig{ + JoinAddrs: []string{"127.0.0.1:17946"}, + GossipInterval: Duration{100 * time.Millisecond}, + NodeTimeout: Duration{500 * time.Millisecond}, + DeadNodeTimeout: Duration{2 * time.Second}, + }, + } + + ds2, err := NewDiscoveryService(cfg2) + require.NoError(t, err) + defer ds2.Stop() + + err = ds2.Start() + require.NoError(t, err) + + // Wait for discovery + time.Sleep(500 * time.Millisecond) + + // Check that node-2 knows about node-1 + members2 := ds2.Members() + assert.GreaterOrEqual(t, len(members2), 1, "node-2 should discover node-1") + + // Check that node-1 knows about node-2 + members1 := ds1.Members() + assert.GreaterOrEqual(t, len(members1), 1, "node-1 should discover node-2") + }) + + t.Run("NodeHealthCheck", func(t *testing.T) { + cfg := &Config{ + NodeID: "health-node", + BindAddr: "127.0.0.1", + BindPort: 17950, + RPC: RPCConfig{ + Port: 17951, + }, + Discovery: DiscoveryConfig{ + GossipInterval: Duration{100 * time.Millisecond}, + NodeTimeout: Duration{300 * time.Millisecond}, + DeadNodeTimeout: Duration{1 * time.Second}, + }, + } + + ds, err := NewDiscoveryService(cfg) + require.NoError(t, err) + defer ds.Stop() + + err = ds.Start() + require.NoError(t, err) + + // Add a remote node manually + 
remoteNode := &NodeInfo{ + ID: "remote-node", + Addr: "192.168.1.100", + Port: 7947, + LoadScore: 0.5, + Timestamp: time.Now().UnixNano(), + } + + ds.membership.UpdateNode(remoteNode) + + // Check health check - should have at least the remote node + members := ds.Members() + assert.GreaterOrEqual(t, len(members), 1) + + // Find the remote node + var found *NodeWithState + for _, m := range members { + if m.Node.ID == "remote-node" { + found = m + break + } + } + require.NotNil(t, found) + assert.Equal(t, NodeStatusAlive, found.State.Status) + }) +} + +func TestHandoffCoordinator(t *testing.T) { + t.Run("CanHandleWithLoadThreshold", func(t *testing.T) { + cfg := &Config{ + NodeID: "handoff-node", + Handoff: HandoffConfig{ + Enabled: true, + LoadThreshold: 0.8, + }, + } + + ds, err := NewDiscoveryService(cfg) + require.NoError(t, err) + + hc := NewHandoffCoordinator(ds, cfg.Handoff) + defer hc.Close() + + // With low load, should be able to handle + ds.localNode.LoadScore = 0.5 + assert.True(t, hc.CanHandle("")) + + // With high load, should not be able to handle + ds.localNode.LoadScore = 0.9 + assert.False(t, hc.CanHandle("")) + }) + + t.Run("FindTargetNode", func(t *testing.T) { + cfg := &Config{ + NodeID: "coordinator-node", + } + + ds, err := NewDiscoveryService(cfg) + require.NoError(t, err) + + hc := NewHandoffCoordinator(ds, cfg.Handoff) + defer hc.Close() + + // Add some candidate nodes + node1 := &NodeInfo{ + ID: "target-1", + Addr: "192.168.1.1", + Port: 7947, + LoadScore: 0.3, + AgentCaps: map[string]string{"model": "gpt-4"}, + Timestamp: time.Now().UnixNano(), + } + ds.membership.UpdateNode(node1) + + node2 := &NodeInfo{ + ID: "target-2", + Addr: "192.168.1.2", + Port: 7947, + LoadScore: 0.7, + AgentCaps: map[string]string{"model": "gpt-4"}, + Timestamp: time.Now().UnixNano(), + } + ds.membership.UpdateNode(node2) + + // Should select the least loaded node + target, err := hc.findTargetNode(&HandoffRequest{}) + require.NoError(t, err) + assert.Equal(t, 
"target-1", target.Node.ID) + }) +} + +func TestLeaderElection(t *testing.T) { + // Helper to create a discovery service with membership for testing + newTestDiscovery := func(t *testing.T, nodeID string, port int) *DiscoveryService { + t.Helper() + cfg := &Config{ + NodeID: nodeID, + BindAddr: "127.0.0.1", + BindPort: port, + RPC: RPCConfig{Port: port + 1}, + Discovery: DiscoveryConfig{ + GossipInterval: Duration{100 * time.Millisecond}, + NodeTimeout: Duration{500 * time.Millisecond}, + DeadNodeTimeout: Duration{2 * time.Second}, + }, + } + ds, err := NewDiscoveryService(cfg) + require.NoError(t, err) + // Register the local node into membership so checkElection sees it + ds.membership.UpdateNode(ds.localNode) + return ds + } + + defaultConfig := LeaderElectionConfig{ + Enabled: true, + ElectionInterval: Duration{100 * time.Millisecond}, + LeaderHeartbeatTimeout: Duration{200 * time.Millisecond}, + } + + t.Run("SingleNodeBecomesLeader", func(t *testing.T) { + ds := newTestDiscovery(t, "node-a", 18100) + + le := NewLeaderElection(ds.localNode.ID, ds.membership, defaultConfig) + + le.checkElection() + + assert.True(t, le.IsLeader()) + assert.Equal(t, "node-a", le.GetLeader()) + }) + + t.Run("LowestIDBecomesLeader", func(t *testing.T) { + ds := newTestDiscovery(t, "node-c", 18110) + + // Add two remote alive nodes + ds.membership.UpdateNode(&NodeInfo{ + ID: "node-a", + Addr: "192.168.1.1", + Port: 7947, + LoadScore: 0.5, + Timestamp: time.Now().UnixNano(), + }) + ds.membership.UpdateNode(&NodeInfo{ + ID: "node-b", + Addr: "192.168.1.2", + Port: 7947, + LoadScore: 0.3, + Timestamp: time.Now().UnixNano(), + }) + + le := NewLeaderElection("node-c", ds.membership, defaultConfig) + + le.checkElection() + + // node-a has the lowest ID + assert.False(t, le.IsLeader()) + assert.Equal(t, "node-a", le.GetLeader()) + }) + + t.Run("DeadNodeNotElected", func(t *testing.T) { + ds := newTestDiscovery(t, "node-c", 18120) + + // Add node-a (lowest ID) but mark it dead + 
ds.membership.UpdateNode(&NodeInfo{ + ID: "node-a", + Addr: "192.168.1.1", + Port: 7947, + LoadScore: 0.2, + Timestamp: time.Now().UnixNano(), + }) + ds.membership.MarkDead("node-a") + + // Add node-b as alive + ds.membership.UpdateNode(&NodeInfo{ + ID: "node-b", + Addr: "192.168.1.2", + Port: 7947, + LoadScore: 0.3, + Timestamp: time.Now().UnixNano(), + }) + + le := NewLeaderElection("node-c", ds.membership, defaultConfig) + + le.checkElection() + + // node-a is dead, so node-b (next lowest alive ID) should be leader + assert.Equal(t, "node-b", le.GetLeader()) + assert.False(t, le.IsLeader()) + }) + + t.Run("SuspectNodeNotElected", func(t *testing.T) { + ds := newTestDiscovery(t, "node-c", 18130) + + // Add node-a (lowest ID) but mark it suspect + ds.membership.UpdateNode(&NodeInfo{ + ID: "node-a", + Addr: "192.168.1.1", + Port: 7947, + LoadScore: 0.2, + Timestamp: time.Now().UnixNano(), + }) + ds.membership.MarkSuspect("node-a") + + le := NewLeaderElection("node-c", ds.membership, defaultConfig) + + le.checkElection() + + // node-a is suspect, local node-c should be leader + assert.Equal(t, "node-c", le.GetLeader()) + assert.True(t, le.IsLeader()) + }) + + t.Run("LeaderReelectionOnLeaderDeath", func(t *testing.T) { + ds := newTestDiscovery(t, "node-b", 18140) + + // Add node-a as alive leader + ds.membership.UpdateNode(&NodeInfo{ + ID: "node-a", + Addr: "192.168.1.1", + Port: 7947, + LoadScore: 0.2, + Timestamp: time.Now().UnixNano(), + }) + + le := NewLeaderElection("node-b", ds.membership, defaultConfig) + + le.checkElection() + assert.Equal(t, "node-a", le.GetLeader()) + assert.False(t, le.IsLeader()) + + // Now mark node-a as dead + ds.membership.MarkDead("node-a") + + // monitorLeader should detect and trigger reelection + le.monitorLeader() + + // node-b should now be leader + assert.Equal(t, "node-b", le.GetLeader()) + assert.True(t, le.IsLeader()) + }) + + t.Run("NoDoubleNotification", func(t *testing.T) { + ds := newTestDiscovery(t, "node-a", 18150) + + 
le := NewLeaderElection("node-a", ds.membership, defaultConfig) + + // First election — should produce exactly one notification + le.checkElection() + assert.True(t, le.IsLeader()) + + // Drain the channel + count := 0 + for { + select { + case <-le.LeaderChanges(): + count++ + default: + goto done + } + } + done: + assert.Equal(t, 1, count, "should receive exactly one leader change notification") + }) + + t.Run("GetState", func(t *testing.T) { + ds := newTestDiscovery(t, "node-a", 18160) + + le := NewLeaderElection("node-a", ds.membership, defaultConfig) + le.checkElection() + + state := le.GetState() + assert.Equal(t, "node-a", state.LeaderID) + assert.True(t, state.IsLeader) + assert.GreaterOrEqual(t, state.MemberCount, 1) + }) +} diff --git a/pkg/tools/handoff_tool.go b/pkg/tools/handoff_tool.go new file mode 100644 index 000000000..343ce4eb9 --- /dev/null +++ b/pkg/tools/handoff_tool.go @@ -0,0 +1,140 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Swarm mode support for multi-agent coordination +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package tools + +import ( + "context" + "fmt" + + "github.com/sipeed/picoclaw/pkg/swarm" +) + +// HandoffTool implements the handoff tool for swarm mode. +type HandoffTool struct { + coordinator *swarm.HandoffCoordinator + channel string + chatID string +} + +// NewHandoffTool creates a new handoff tool. +func NewHandoffTool(coordinator *swarm.HandoffCoordinator) *HandoffTool { + return &HandoffTool{ + coordinator: coordinator, + channel: "cli", + chatID: "direct", + } +} + +// Name returns the tool name. +func (t *HandoffTool) Name() string { + return "handoff" +} + +// Description returns the tool description. +func (t *HandoffTool) Description() string { + return "Delegate this task to another agent in the swarm. Use when you cannot handle the task due to capability constraints or system overload." +} + +// Parameters returns the tool parameters schema. 
+func (t *HandoffTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "reason": map[string]any{ + "type": "string", + "enum": []string{"no_capability", "overloaded", "user_request"}, + "description": "The reason for handing off this task", + }, + "required_capability": map[string]any{ + "type": "string", + "description": "The specific capability required to handle this task", + }, + "context": map[string]any{ + "type": "string", + "description": "Additional context about why this handoff is needed", + }, + }, + } +} + +// SetContext sets the channel and chat ID for the tool. +func (t *HandoffTool) SetContext(channel, chatID string) { + t.channel = channel + t.chatID = chatID +} + +// Execute executes the handoff tool. +func (t *HandoffTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + if t.coordinator == nil { + return ErrorResult("Swarm mode is not enabled or handoff coordinator not configured").WithError( + fmt.Errorf("handoff coordinator is nil")) + } + + // Parse reason + reasonStr, _ := args["reason"].(string) + var reason swarm.HandoffReason + switch reasonStr { + case "no_capability": + reason = swarm.ReasonNoCapability + case "overloaded": + reason = swarm.ReasonOverloaded + case "user_request": + reason = swarm.ReasonUserRequest + default: + reason = swarm.ReasonNoCapability + } + + // Parse required capability + requiredCap, _ := args["required_capability"].(string) + + // Parse context + contextMsg, _ := args["context"].(string) + + // Build handoff request + req := &swarm.HandoffRequest{ + Reason: reason, + RequiredCap: requiredCap, + Metadata: make(map[string]string), + } + + if contextMsg != "" { + req.Metadata["context"] = contextMsg + } + + // Execute handoff + resp, err := t.coordinator.InitiateHandoff(ctx, req) + if err != nil { + return ErrorResult(fmt.Sprintf("Handoff failed: %v", err)).WithError(err) + } + + if !resp.Accepted { + return 
ErrorResult(fmt.Sprintf("Handoff rejected by all nodes: %s", resp.Reason)).WithError( + fmt.Errorf("handoff rejected: %s", resp.Reason)) + } + + // Build result message + resultMsg := fmt.Sprintf("Task handed off to node %s\n", resp.NodeID) + if resp.Reason != "" { + resultMsg += fmt.Sprintf("Note: %s\n", resp.Reason) + } + + return &ToolResult{ + ForLLM: resultMsg + "The target node will process this task and respond to the user.", + ForUser: "Your task has been delegated to another agent in the swarm. They will respond shortly.", + Silent: false, + IsError: false, + Async: true, // Handoff is async - target node will respond directly + } +} + +// CanHandle reports whether the local node can handle the given capability. +func (t *HandoffTool) CanHandle(requiredCap string) bool { + if t.coordinator == nil { + return true // If swarm is disabled, we can "handle" everything + } + return t.coordinator.CanHandle(requiredCap) +} diff --git a/pkg/tools/swarm_nodes.go b/pkg/tools/swarm_nodes.go new file mode 100644 index 000000000..e7dff13c8 --- /dev/null +++ b/pkg/tools/swarm_nodes.go @@ -0,0 +1,307 @@ +// Package tools provides swarm node management tools. +package tools + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/swarm" +) + +// FormatClusterStatus returns a human-readable cluster status string. +// Used by the /nodes command handler. +func FormatClusterStatus( + discovery *swarm.DiscoveryService, + load *swarm.LoadMonitor, + localID string, + verbose bool, +) string { + if discovery == nil { + return "Swarm mode is not enabled." + } + + members := discovery.Members() + if len(members) == 0 { + return "No nodes found in the swarm cluster." 
+ } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Swarm Cluster Status (%d node%s):\n\n", len(members), plural(len(members)))) + + // Sort nodes: local first, then by ID + sortedMembers := make([]*swarm.NodeWithState, len(members)) + copy(sortedMembers, members) + for i := range sortedMembers { + for j := i + 1; j < len(sortedMembers); j++ { + if sortedMembers[j].Node.ID == localID { + sortedMembers[i], sortedMembers[j] = sortedMembers[j], sortedMembers[i] + break + } + if sortedMembers[i].Node.ID > sortedMembers[j].Node.ID { + sortedMembers[i], sortedMembers[j] = sortedMembers[j], sortedMembers[i] + } + } + } + + for _, m := range sortedMembers { + node := m.Node + state := m.State + + localMark := " " + if node.ID == localID { + localMark = "*" + } + + sb.WriteString(fmt.Sprintf("%s **%s**\n", localMark, node.ID)) + + if verbose { + sb.WriteString(fmt.Sprintf(" Address: %s:%d\n", node.Addr, node.Port)) + sb.WriteString(fmt.Sprintf(" Status: %s", state.Status)) + writeStatusEmoji(&sb, state.Status) + sb.WriteString("\n") + + loadPercent := int(node.LoadScore * 100) + sb.WriteString(fmt.Sprintf(" Load: %.2f%% ", node.LoadScore*100)) + if loadPercent < 50 { + sb.WriteString("🟩") + } else if loadPercent < 80 { + sb.WriteString("🟨") + } else { + sb.WriteString("🟥") + } + sb.WriteString("\n") + + lastSeen := time.Unix(0, state.LastSeen) + age := time.Since(lastSeen) + sb.WriteString(fmt.Sprintf(" Last seen: %s ago\n", formatDuration(age))) + + if len(node.AgentCaps) > 0 { + sb.WriteString(" Capabilities:\n") + for k, v := range node.AgentCaps { + sb.WriteString(fmt.Sprintf(" - %s: %s\n", k, v)) + } + } + + if len(node.Labels) > 0 { + sb.WriteString(" Labels:\n") + for k, v := range node.Labels { + sb.WriteString(fmt.Sprintf(" - %s: %s\n", k, v)) + } + } + + if state.PingSuccess > 0 || state.PingFailure > 0 { + total := state.PingSuccess + state.PingFailure + successRate := float64(state.PingSuccess) / float64(total) * 100 + sb.WriteString(fmt.Sprintf(" Ping: 
%d/%d success (%.1f%%)\n", + state.PingSuccess, total, successRate)) + } + } else { + writeStatusEmoji(&sb, state.Status) + sb.WriteString(fmt.Sprintf(" Load: %.0f%% @ %s:%d\n", + node.LoadScore*100, node.Addr, node.Port)) + } + + sb.WriteString("\n") + } + + sb.WriteString("* = this node\n") + sb.WriteString("\nTip: Use `@node-id: message` to route a request to a specific node.") + + return sb.String() +} + +// writeStatusEmoji appends a status emoji to the builder. +func writeStatusEmoji(sb *strings.Builder, status swarm.NodeStatus) { + switch status { + case swarm.NodeStatusAlive: + sb.WriteString(" 🟢") + case swarm.NodeStatusSuspect: + sb.WriteString(" 🟡") + case swarm.NodeStatusDead: + sb.WriteString(" 🔴") + case swarm.NodeStatusLeft: + sb.WriteString(" ⚪") + } +} + +// SwarmRouteTool routes a request to a specific node in the swarm. +type SwarmRouteTool struct { + discovery *swarm.DiscoveryService + handoff *swarm.HandoffCoordinator + localID string + sendMessageFn func(ctx context.Context, targetNodeID, content, channel, chatID, senderID string) (string, error) +} + +// NewSwarmRouteTool creates a new swarm route tool. +func NewSwarmRouteTool( + discovery *swarm.DiscoveryService, + handoff *swarm.HandoffCoordinator, + localID string, +) *SwarmRouteTool { + return &SwarmRouteTool{ + discovery: discovery, + handoff: handoff, + localID: localID, + } +} + +// SetSendMessageFn sets the function to send messages to other nodes. +func (t *SwarmRouteTool) SetSendMessageFn( + fn func(ctx context.Context, targetNodeID, content, channel, chatID, senderID string) (string, error), +) { + t.sendMessageFn = fn +} + +// Name returns the tool name. +func (t *SwarmRouteTool) Name() string { + return "swarm_route" +} + +// Description returns the tool description. +func (t *SwarmRouteTool) Description() string { + return "Route this request to a specific node in the swarm cluster" +} + +// Parameters returns the tool parameters schema. 
+func (t *SwarmRouteTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "node_id": map[string]any{ + "type": "string", + "description": "The target node ID to route the request to", + }, + "message": map[string]any{ + "type": "string", + "description": "The message to send to the target node", + }, + }, + "required": []string{"node_id"}, + } +} + +// Execute executes the swarm route tool. +func (t *SwarmRouteTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + if t.discovery == nil { + return ErrorResult("Swarm mode is not enabled") + } + + nodeID, _ := args["node_id"].(string) + message, _ := args["message"].(string) + + if nodeID == "" { + return ErrorResult("node_id is required") + } + + // Check if target is this node + if nodeID == t.localID { + return &ToolResult{ + ForLLM: fmt.Sprintf("Target node %s is this node. Processing locally.", nodeID), + ForUser: fmt.Sprintf("This request is already on node %s.", nodeID), + IsError: false, + } + } + + // Find target node + members := t.discovery.Members() + var target *swarm.NodeWithState + for _, m := range members { + if m.Node.ID == nodeID { + target = m + break + } + } + + if target == nil { + return ErrorResult(fmt.Sprintf("Node %s not found in cluster", nodeID)) + } + + // Check if target is available + if target.State.Status != swarm.NodeStatusAlive { + return ErrorResult(fmt.Sprintf("Node %s is not alive (status: %s)", nodeID, target.State.Status)) + } + + if target.Node.LoadScore > 0.9 { + return ErrorResult(fmt.Sprintf("Node %s is overloaded (load: %.0f%%)", nodeID, target.Node.LoadScore*100)) + } + + // If sendMessageFn is available, use it to send the actual message + if t.sendMessageFn != nil { + logger.InfoCF("swarm", "Sending message to node via Pico channel", map[string]any{ + "target": nodeID, + "message": message, + }) + + response, err := t.sendMessageFn(ctx, nodeID, message, "", "", "") + if err != nil { + return 
ErrorResult(fmt.Sprintf("Failed to send message to node %s: %v", nodeID, err)) + } + + return &ToolResult{ + ForLLM: response, + ForUser: response, + IsError: false, + } + } + + // Fallback to placeholder message (for backward compatibility) + logger.InfoCF("swarm", "Routing request to node", map[string]any{ + "target": nodeID, + "message": message, + "addr": target.Node.GetAddress(), + }) + + return &ToolResult{ + ForLLM: fmt.Sprintf("Request routed to node %s @ %s (load: %.0f%%). "+ + "That node will process the request and respond independently. "+ + "For direct communication, you can contact that node directly at %s:%d.", + nodeID, target.Node.Addr, target.Node.LoadScore*100, target.Node.Addr, target.Node.Port), + ForUser: fmt.Sprintf("Your request has been routed to node %s. They will respond separately.", + nodeID), + IsError: false, + Async: true, + } +} + +// ParseNodeMention extracts node ID from a message like "@node-id: message" +func ParseNodeMention(message string) (nodeID, content string) { + message = strings.TrimSpace(message) + + // Check for @node-id: pattern + if strings.HasPrefix(message, "@") { + rest := message[1:] + idx := strings.IndexAny(rest, ": \n") + if idx > 0 && rest[idx] == ':' { + nodeID = rest[:idx] + content = strings.TrimSpace(rest[idx+1:]) + return nodeID, content + } + } + + return "", message +} + +// Helper functions + +func plural(n int) string { + if n == 1 { + return "" + } + return "s" +} + +func formatDuration(d time.Duration) string { + if d < time.Second { + return fmt.Sprintf("%dms", d.Milliseconds()) + } + if d < time.Minute { + return fmt.Sprintf("%.1fs", d.Seconds()) + } + if d < time.Hour { + return fmt.Sprintf("%.1fm", d.Minutes()) + } + return fmt.Sprintf("%.1fh", d.Hours()) +}