diff --git a/README.md b/README.md index aa7b0719a..53b3cc094 100644 --- a/README.md +++ b/README.md @@ -996,6 +996,18 @@ PicoClaw routes providers by protocol family: This keeps the runtime lightweight while making new OpenAI-compatible backends mostly a config operation (`api_base` + `api_key`). +### Lifecycle Hooks (Plugin-style Extensions) + +PicoClaw provides typed lifecycle hooks for observability, outbound filtering, and tool guardrails. + +- Register hooks in Go at startup with `hooks.NewHookRegistry()`. +- Attach once via `agentLoop.SetHooks(registry)` before `Run()` and handle setup errors. +- If hooks are not set, default behavior is unchanged. + +See runnable examples: [docs/hooks-plugin-examples.md](docs/hooks-plugin-examples.md) + +Roadmap for plugin system evolution: [docs/plugin-system-roadmap.md](docs/plugin-system-roadmap.md) +
Zhipu diff --git a/docs/hooks-plugin-examples.md b/docs/hooks-plugin-examples.md new file mode 100644 index 000000000..8626f8d67 --- /dev/null +++ b/docs/hooks-plugin-examples.md @@ -0,0 +1,138 @@ +# Lifecycle Hooks: Plugin-Style Examples + +This guide shows how to extend PicoClaw behavior with `pkg/hooks` without modifying core agent logic. + +For future direction (beyond current hooks foundation), see [Plugin System Roadmap](plugin-system-roadmap.md). + +Current model: +- "Plugin-style" means registering Go handlers at startup. +- Hooks are in-process (no dynamic `.so` loading). +- If no hooks are registered, the runtime follows the normal zero-cost path. + +## How Plugin Works + +PicoClaw's plugin model is a startup-time hook registry: + +1. Build a registry (`hooks.NewHookRegistry()`). +2. Register one or more handlers per lifecycle hook with priority. +3. Attach once with `agentLoop.SetHooks(registry)` before `agentLoop.Run(...)` (check error). +4. Agent loop triggers hook handlers at specific lifecycle points. + +Execution semantics: + +- Observe-only hooks (`message_received`, `after_tool_call`, `llm_input`, `llm_output`, `session_start`, `session_end`) + - run concurrently + - cannot block core behavior +- Modifying hooks (`message_sending`, `before_tool_call`) + - run sequentially by priority (lower number first) + - may mutate event data + - may cancel operation via `Cancel=true` + +Safety model: + +- Panic in one handler is recovered and logged. +- Handler errors are logged; pipeline continues unless canceled by event flag. +- With no registered hooks, agent loop behavior is unchanged. + +Lifecycle map: + +```text +Inbound message + -> message_received + -> session_start + -> llm_input + -> llm_output + -> before_tool_call (cancelable) + -> tool execute + -> after_tool_call + -> message_sending (cancelable) + -> outbound publish + -> session_end +``` + +Note: the map above is shown as a single pass for readability. In practice, the +agent loop may iterate up to `MaxToolIterations`, and `llm_input`, `llm_output`, +`before_tool_call`, and `after_tool_call` can fire multiple times. + +## Available Hooks + +| Hook | Type | Typical use | +|---|---|---| +| `message_received` | observe-only | inbound telemetry | +| `message_sending` | modifying + cancel | content filtering, safety policy | +| `before_tool_call` | modifying + cancel | tool allow/deny, arg rewriting | +| `after_tool_call` | observe-only | latency/error metrics | +| `llm_input` | observe-only | prompt size monitoring | +| `llm_output` | observe-only | response/tool-call telemetry | +| `session_start` | observe-only | session audit | +| `session_end` | observe-only | session cleanup metrics | + +## Quick Start + +```go +package main + +import ( + "context" + "strings" + + "github.com/sipeed/picoclaw/pkg/hooks" +) + +func buildHooks() *hooks.HookRegistry { + reg := hooks.NewHookRegistry() + + // 1) Guardrail: block shell tool globally. + reg.OnBeforeToolCall("block-shell", 100, func(_ context.Context, e *hooks.BeforeToolCallEvent) error { + if e.ToolName == "shell" { + e.Cancel = true + e.CancelReason = "shell tool is disabled by local policy" + } + return nil + }) + + // 2) Outbound filter: redact obvious API key patterns. + reg.OnMessageSending("redact-secrets", 50, func(_ context.Context, e *hooks.MessageSendingEvent) error { + e.Content = strings.ReplaceAll(e.Content, "sk-", "[redacted]-") + return nil + }) + + // 3) Telemetry: record tool latency or errors. + reg.OnAfterToolCall("tool-telemetry", 0, func(_ context.Context, e *hooks.AfterToolCallEvent) error { + // Send to metrics backend / logs as needed. + _ = e.ToolName + _ = e.Duration + _ = e.Result + return nil + }) + + return reg +} +``` + +Attach once during startup: + +```go +agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) +if err := agentLoop.SetHooks(buildHooks()); err != nil { + panic(err) // replace with your startup error handling +} +``` + +## Priority and Cancellation + +- Lower `priority` runs first. +- `message_sending` and `before_tool_call` are sequential and can cancel. +- Other hooks are observe-only and run concurrently. + +Recommended ordering: +- `0-49`: telemetry and logging +- `50-89`: transforms (redaction, normalization) +- `90+`: hard guardrails (block/cancel) + +## Safety Notes + +- Hook panics are recovered internally; one bad hook does not crash the loop. +- Hook errors are logged and execution continues unless `Cancel` is set. +- Observe-only hooks (`message_received`, `after_tool_call`, `llm_input`, `llm_output`, `session_start`, `session_end`) must treat events as read-only. +- Keep hook handlers fast and non-blocking to avoid latency impact. diff --git a/docs/plugin-system-roadmap.md b/docs/plugin-system-roadmap.md new file mode 100644 index 000000000..2709012db --- /dev/null +++ b/docs/plugin-system-roadmap.md @@ -0,0 +1,108 @@ +# Plugin System Roadmap + +This document defines how PicoClaw evolves from hook-based extension points to a fuller plugin system in low-risk phases. + +## Current Status (Phase 0: Foundation) + +Implemented in current hooks PR: + +- Typed lifecycle hooks (`pkg/hooks`) +- Priority-based handler ordering +- Cancellation support for modifying hooks +- Panic recovery and error isolation +- Agent-loop integration via `agentLoop.SetHooks(...)` + +Compatibility: + +- If no hooks are registered, runtime behavior is unchanged. +- No config migration is required. + +## Non-Goals in Phase 0 + +- No dynamic runtime plugin loading +- No remote plugin marketplace/distribution +- No plugin sandboxing model +- No stable external plugin ABI yet +- No Go `.so` plugin loading as default direction + +## Phase Plan + +## Phase 1: Static Plugin Contract (Compile-time) — Implemented + +Goal: define a minimal public plugin contract for Go modules. + +Implemented: + +- Add `pkg/plugin` with a small interface: + - `Name() string` + - `APIVersion() string` + - `Register(*hooks.HookRegistry) error` +- Register plugins at startup in code. +- Add compatibility metadata (`plugin.APIVersion`) and registration-time checks. + +Exit criteria (met): + +- Example plugin module builds against the contract. +- Startup validation logs loaded plugins and registration errors clearly. + +## Phase 2: Config-driven Enable/Disable + +Goal: operational control without code changes. + +Proposed: + +- Add plugin list/config in `config.json`: + - enabled/disabled flags + - optional plugin-specific settings +- Deterministic load order and conflict resolution rules. + +Exit criteria: + +- Users can toggle plugins without rebuilding. +- Clear startup diagnostics for invalid plugin config. + +## Phase 3: Developer Experience + +Goal: make third-party plugin development straightforward. + +Proposed: + +- Provide `examples/plugins/*` reference implementations. +- Publish plugin authoring guide (lifecycle map, best practices, safety constraints). +- Add plugin-focused test harness pattern for hook behavior verification. + +Exit criteria: + +- New plugin can be built from template with minimal boilerplate. +- CI examples demonstrate expected behavior and regression checks. + +## Phase 4: Optional Dynamic Loading (Separate RFC) + +Goal: support runtime-loaded plugins only if security and operability are acceptable. + +Preferred direction: + +- Runtime plugins run as subprocesses. +- Host and plugin communicate via RPC/gRPC. +- Host manages lifecycle (spawn/health/timeout/restart), not in-process dynamic loading. + +Why this direction: + +- Go native `.so` plugin loading has strict toolchain/ABI coupling with host binary. +- Subprocess RPC model reduces coupling and improves fault isolation. +- Process boundary provides a cleaner place for permissions and sandbox controls. + +Preconditions: + +- Threat model approved +- Signature/trust model defined +- Sandboxing and permission boundaries defined +- Rollback and safe-disable behavior validated +- Versioned RPC handshake and capability negotiation defined +- Process supervision policy defined (timeouts, retries, crash loop backoff) + +Until then, compile-time registration remains the recommended model. + +## Maintainer Review Notes + +The current hooks PR should be reviewed as Phase 0+1 only. It intentionally establishes extension points while avoiding high-risk runtime plugin mechanics. diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 693f2227b..4065ff022 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -20,7 +20,9 @@ import ( "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" + "github.com/sipeed/picoclaw/pkg/hooks" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/plugin" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/skills" @@ -38,6 +40,8 @@ type AgentLoop struct { summarizing sync.Map fallback *providers.FallbackChain channelManager *channels.Manager + hooks *hooks.HookRegistry + pluginManager *plugin.Manager } // processOptions configures how a message is processed @@ -118,7 +122,7 @@ func registerSharedTools( // Message tool messageTool := tools.NewMessageTool() - messageTool.SetSendCallback(func(channel, chatID, content string) error { + messageTool.SetSendCallback(func(_ context.Context, channel, chatID, content string) error { msgBus.PublishOutbound(bus.OutboundMessage{ Channel: channel, ChatID: chatID, @@ -185,7 +189,7 @@ func (al *AgentLoop) Run(ctx context.Context) error { } if !alreadySent { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.sendOutbound(ctx, bus.OutboundMessage{ Channel: msg.Channel, ChatID: msg.ChatID, Content: response, @@ -214,6 +218,100 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { al.channelManager = cm } +// SetHooks installs a hook registry. Must be called before Run starts. +func (al *AgentLoop) SetHooks(h *hooks.HookRegistry) error { + if al.running.Load() { + return fmt.Errorf("SetHooks must be called before Run starts") + } + al.hooks = h + + // Rewire MessageTool callbacks to route through sendOutbound for hook interception. + for _, agentID := range al.registry.ListAgentIDs() { + if agent, ok := al.registry.GetAgent(agentID); ok { + if tool, ok := agent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + if h == nil { + mt.SetSendCallback(func(_ context.Context, channel, chatID, content string) error { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }) + return nil + }) + continue + } + mt.SetSendCallback(func(ctx context.Context, channel, chatID, content string) error { + if sent, reason := al.sendOutbound(ctx, bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }); !sent { + if strings.TrimSpace(reason) == "" { + reason = "unspecified" + } + return fmt.Errorf("message canceled by hook: %s", reason) + } + return nil + }) + } + } + } + } + return nil +} + +// SetPluginManager installs a plugin manager and routes its hook registry into the loop. +// Must be called before Run starts. +func (al *AgentLoop) SetPluginManager(pm *plugin.Manager) error { + if pm == nil { + if err := al.SetHooks(nil); err != nil { + return err + } + al.pluginManager = nil + return nil + } + if err := al.SetHooks(pm.HookRegistry()); err != nil { + return err + } + al.pluginManager = pm + return nil +} + +// EnablePlugins is a convenience helper to build and install a plugin manager. +func (al *AgentLoop) EnablePlugins(plugins ...plugin.Plugin) error { + pm := plugin.NewManager() + if err := pm.RegisterAll(plugins...); err != nil { + return err + } + return al.SetPluginManager(pm) +} + +// sendOutbound wraps bus.PublishOutbound with the message_sending hook. +// Returns whether the message was sent and, if canceled, the cancel reason. +func (al *AgentLoop) sendOutbound(ctx context.Context, msg bus.OutboundMessage) (bool, string) { + if al.hooks != nil { + event := &hooks.MessageSendingEvent{Channel: msg.Channel, ChatID: msg.ChatID, Content: msg.Content} + al.hooks.TriggerMessageSending(ctx, event) + if event.Cancel { + reason := event.CancelReason + if reason == "" { + reason = "unspecified" + } + logger.WarnCF("hooks", "Outbound message canceled by hook", + map[string]any{ + "channel": msg.Channel, + "chat_id": msg.ChatID, + "reason": reason, + }) + return false, reason + } + msg.Content = event.Content + } + al.bus.PublishOutbound(msg) + return true, "" +} + // RecordLastChannel records the last active channel for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. func (al *AgentLoop) RecordLastChannel(channel string) error { @@ -283,6 +381,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) "session_key": msg.SessionKey, }) + // Fire message_received hook + if al.hooks != nil { + al.hooks.TriggerMessageReceived(ctx, &hooks.MessageReceivedEvent{ + Channel: msg.Channel, + SenderID: msg.SenderID, + ChatID: msg.ChatID, + Content: msg.Content, + Media: msg.Media, + Metadata: msg.Metadata, + }) + } + // Route system messages to processSystemMessage if msg.Channel == "system" { return al.processSystemMessage(ctx, msg) @@ -404,6 +514,18 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt // 1. Update tool contexts al.updateToolContexts(agent, opts.Channel, opts.ChatID) + // Fire session hooks + if al.hooks != nil { + sessionEvt := &hooks.SessionEvent{ + AgentID: agent.ID, + SessionKey: opts.SessionKey, + Channel: opts.Channel, + ChatID: opts.ChatID, + } + al.hooks.TriggerSessionStart(ctx, sessionEvt) + defer al.hooks.TriggerSessionEnd(ctx, sessionEvt) + } + // 2. Build messages (skip history for heartbeat) var history []providers.Message var summary string @@ -448,7 +570,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt // 8. Optional: send response via bus if opts.SendResponse { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.sendOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: finalContent, @@ -545,8 +667,19 @@ func (al *AgentLoop) runLLMIteration( } // Retry loop for context/token errors + llmStart := time.Now() maxRetries := 2 for retry := 0; retry <= maxRetries; retry++ { + // Fire llm_input hook (re-fires after compression so hooks see actual messages) + if al.hooks != nil { + al.hooks.TriggerLLMInput(ctx, &hooks.LLMInputEvent{ + AgentID: agent.ID, + Model: agent.Model, + Messages: messages, + Tools: providerToolDefs, + Iteration: iteration, + }) + } response, err = callLLM() if err == nil { break @@ -565,7 +698,7 @@ func (al *AgentLoop) runLLMIteration( }) if retry == 0 && !constants.IsInternalChannel(opts.Channel) { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.sendOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: "Context window exceeded. Compressing history and retrying...", @@ -584,6 +717,8 @@ func (al *AgentLoop) runLLMIteration( break } + llmDuration := time.Since(llmStart) + if err != nil { logger.ErrorCF("agent", "LLM call failed", map[string]any{ @@ -594,6 +729,18 @@ func (al *AgentLoop) runLLMIteration( return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) } + // Fire llm_output hook + if al.hooks != nil { + al.hooks.TriggerLLMOutput(ctx, &hooks.LLMOutputEvent{ + AgentID: agent.ID, + Model: agent.Model, + Content: response.Content, + ToolCalls: response.ToolCalls, + Iteration: iteration, + Duration: llmDuration, + }) + } + // Check if no tool calls - we're done if len(response.ToolCalls) == 0 { finalContent = response.Content @@ -684,18 +831,64 @@ func (al *AgentLoop) runLLMIteration( } } - toolResult := agent.Tools.ExecuteWithContext( - ctx, - tc.Name, - tc.Arguments, - opts.Channel, - opts.ChatID, - asyncCallback, - ) + // Fire before_tool_call hook + var toolResult *tools.ToolResult + toolCanceled := false + if al.hooks != nil { + args := tc.Arguments + if args == nil { + args = make(map[string]any) + } + btcEvent := &hooks.BeforeToolCallEvent{ + ToolName: tc.Name, + Args: args, + Channel: opts.Channel, + ChatID: opts.ChatID, + } + al.hooks.TriggerBeforeToolCall(ctx, btcEvent) + if btcEvent.Cancel { + toolCanceled = true + reason := btcEvent.CancelReason + if strings.TrimSpace(reason) == "" { + reason = fmt.Sprintf("tool call %q was canceled by before_tool_call hook", tc.Name) + } + toolResult = tools.ErrorResult(reason) + } + tc.Arguments = btcEvent.Args + if tc.Arguments == nil { + tc.Arguments = make(map[string]any) + } + } + + var toolDuration time.Duration + if !toolCanceled { + toolStart := time.Now() + toolResult = agent.Tools.ExecuteWithContext( + ctx, + tc.Name, + tc.Arguments, + opts.Channel, + opts.ChatID, + asyncCallback, + ) + toolDuration = time.Since(toolStart) + } + + // Fire after_tool_call hook (fires for both executed and canceled calls) + if al.hooks != nil { + al.hooks.TriggerAfterToolCall(ctx, &hooks.AfterToolCallEvent{ + ToolName: tc.Name, + Args: tc.Arguments, + Channel: opts.Channel, + ChatID: opts.ChatID, + Duration: toolDuration, + Result: toolResult, + }) + } // Send ForUser content to user immediately if not Silent if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.sendOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: toolResult.ForUser, diff --git a/pkg/agent/plugin_test.go b/pkg/agent/plugin_test.go new file mode 100644 index 000000000..115923476 --- /dev/null +++ b/pkg/agent/plugin_test.go @@ -0,0 +1,342 @@ +package agent + +import ( + "context" + "os" + "strings" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/hooks" + "github.com/sipeed/picoclaw/pkg/plugin" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +type blockingPlugin struct{} + +func (p blockingPlugin) Name() string { + return "block-outbound" +} + +func (p blockingPlugin) APIVersion() string { + return plugin.APIVersion +} + +func (p blockingPlugin) Register(r *hooks.HookRegistry) error { + r.OnMessageSending("block-all", 0, func(_ context.Context, e *hooks.MessageSendingEvent) error { + e.Cancel = true + e.CancelReason = "blocked by plugin" + return nil + }) + return nil +} + +type nilArgsProvider struct { + calls int +} + +func (p *nilArgsProvider) Chat( + _ context.Context, + _ []providers.Message, + _ []providers.ToolDefinition, + _ string, + _ map[string]any, +) (*providers.LLMResponse, error) { + if p.calls == 0 { + p.calls++ + return &providers.LLMResponse{ + Content: "", + ToolCalls: []providers.ToolCall{ + { + ID: "tc-1", + Type: "function", + Name: "nil_args_tool", + Arguments: map[string]any{"seed": "value"}, + }, + }, + }, nil + } + p.calls++ + return &providers.LLMResponse{ + Content: "done", + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (p *nilArgsProvider) GetDefaultModel() string { + return "test-model" +} + +type nilArgsCaptureTool struct { + receivedNil bool +} + +func (t *nilArgsCaptureTool) Name() string { + return "nil_args_tool" +} + +func (t *nilArgsCaptureTool) Description() string { + return "captures whether args are nil" +} + +func (t *nilArgsCaptureTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (t *nilArgsCaptureTool) Execute(_ context.Context, args map[string]any) *tools.ToolResult { + if args == nil { + t.receivedNil = true + } + return tools.SilentResult("ok") +} + +func TestSetPluginManagerInstallsHookRegistry(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &mockProvider{}) + + pm := plugin.NewManager() + if err := pm.Register(blockingPlugin{}); err != nil { + t.Fatalf("register plugin: %v", err) + } + + if err := al.SetPluginManager(pm); err != nil { + t.Fatalf("SetPluginManager: %v", err) + } + + if al.pluginManager == nil { + t.Fatal("expected plugin manager to be set") + } + if al.hooks != pm.HookRegistry() { + t.Fatal("expected agent loop hooks to use plugin manager registry") + } + + sent, reason := al.sendOutbound(context.Background(), bus.OutboundMessage{ + Channel: "cli", + ChatID: "direct", + Content: "hello", + }) + if sent { + t.Fatal("expected outbound message to be blocked by plugin") + } + if reason == "" { + t.Fatal("expected cancel reason to be propagated") + } +} + +func TestSetHooksReturnsErrorWhenRunning(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &mockProvider{}) + al.running.Store(true) + + if err := al.SetHooks(hooks.NewHookRegistry()); err == nil { + t.Fatal("expected error when calling SetHooks while running") + } +} + +func TestSetPluginManagerDoesNotPartiallyUpdateOnError(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &mockProvider{}) + al.running.Store(true) + + pm := plugin.NewManager() + if err := pm.Register(blockingPlugin{}); err != nil { + t.Fatalf("register plugin: %v", err) + } + + if err := al.SetPluginManager(pm); err == nil { + t.Fatal("expected SetPluginManager to fail while running") + } + if al.pluginManager != nil { + t.Fatal("expected plugin manager to remain unchanged on SetPluginManager failure") + } +} + +func TestBeforeToolCallHooksCannotLeaveToolArgsNil(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &nilArgsProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + captureTool := &nilArgsCaptureTool{} + al.RegisterTool(captureTool) + + r := hooks.NewHookRegistry() + r.OnBeforeToolCall("force-nil-args", 0, func(_ context.Context, e *hooks.BeforeToolCallEvent) error { + if e.ToolName == "nil_args_tool" { + e.Args = nil + } + return nil + }) + if setErr := al.SetHooks(r); setErr != nil { + t.Fatalf("SetHooks: %v", setErr) + } + + resp, err := al.ProcessDirectWithChannel(context.Background(), "run nil args test", "s1", "cli", "direct") + if err != nil { + t.Fatalf("ProcessDirectWithChannel: %v", err) + } + if resp != "done" { + t.Fatalf("expected final response 'done', got %q", resp) + } + if captureTool.receivedNil { + t.Fatal("expected tool args to be reinitialized to non-nil map") + } +} + +func TestSetHooksNilRestoresDirectMessageCallback(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-plugin-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &mockProvider{}) + agent := al.registry.GetDefaultAgent() + if agent == nil { + t.Fatal("expected default agent") + } + tool, ok := agent.Tools.Get("message") + if !ok { + t.Fatal("expected message tool") + } + mt, ok := tool.(*tools.MessageTool) + if !ok { + t.Fatal("expected message tool type") + } + + reg := hooks.NewHookRegistry() + reg.OnMessageSending("block-all", 0, func(_ context.Context, e *hooks.MessageSendingEvent) error { + e.Cancel = true + e.CancelReason = "blocked-by-hook" + return nil + }) + if err := al.SetHooks(reg); err != nil { + t.Fatalf("SetHooks(reg): %v", err) + } + + blocked := mt.Execute(context.Background(), map[string]any{ + "content": "first", + "channel": "cli", + "chat_id": "direct", + }) + if !blocked.IsError { + t.Fatal("expected message tool call to fail while hooks are active") + } + if blocked.Err == nil || !strings.Contains(blocked.Err.Error(), "blocked-by-hook") { + t.Fatalf("expected hook cancel reason in error, got %#v", blocked.Err) + } + + ctxNoMsg, cancelNoMsg := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancelNoMsg() + if _, got := msgBus.SubscribeOutbound(ctxNoMsg); got { + t.Fatal("did not expect outbound message while hook cancellation is active") + } + + if err := al.SetHooks(nil); err != nil { + t.Fatalf("SetHooks(nil): %v", err) + } + + delivered := mt.Execute(context.Background(), map[string]any{ + "content": "second", + "channel": "cli", + "chat_id": "direct", + }) + if delivered.IsError { + t.Fatalf("expected message tool to succeed after SetHooks(nil), got %#v", delivered) + } + + ctxMsg, cancelMsg := context.WithTimeout(context.Background(), time.Second) + defer cancelMsg() + msg, got := msgBus.SubscribeOutbound(ctxMsg) + if !got { + t.Fatal("expected outbound message after SetHooks(nil)") + } + if msg.Content != "second" || msg.Channel != "cli" || msg.ChatID != "direct" { + t.Fatalf("unexpected outbound message: %#v", msg) + } +} diff --git a/pkg/hooks/hooks.go b/pkg/hooks/hooks.go new file mode 100644 index 000000000..3a6a0a675 --- /dev/null +++ b/pkg/hooks/hooks.go @@ -0,0 +1,463 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package hooks + +import ( + "context" + "fmt" + "reflect" + "sync" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// HookHandler is the callback signature for all hooks. +type HookHandler[T any] func(ctx context.Context, event *T) error + +// HookRegistration tracks a handler with its priority and name. +type HookRegistration[T any] struct { + Handler HookHandler[T] + Priority int // Lower = runs first + Name string +} + +// HookRegistry manages all lifecycle hooks. +type HookRegistry struct { + messageReceived []HookRegistration[MessageReceivedEvent] + messageSending []HookRegistration[MessageSendingEvent] + beforeToolCall []HookRegistration[BeforeToolCallEvent] + afterToolCall []HookRegistration[AfterToolCallEvent] + llmInput []HookRegistration[LLMInputEvent] + llmOutput []HookRegistration[LLMOutputEvent] + sessionStart []HookRegistration[SessionEvent] + sessionEnd []HookRegistration[SessionEvent] + mu sync.RWMutex +} + +// NewHookRegistry creates an empty hook registry. +func NewHookRegistry() *HookRegistry { + return &HookRegistry{} +} + +// insertSorted inserts a registration into a new slice sorted by priority. +// Always allocates a new backing array so concurrent readers of the old slice are safe. +func insertSorted[T any](slice []HookRegistration[T], reg HookRegistration[T]) []HookRegistration[T] { + i := 0 + for i < len(slice) && slice[i].Priority <= reg.Priority { + i++ + } + result := make([]HookRegistration[T], len(slice)+1) + copy(result, slice[:i]) + result[i] = reg + copy(result[i+1:], slice[i:]) + return result +} + +// Registration methods + +func (r *HookRegistry) OnMessageReceived(name string, priority int, handler HookHandler[MessageReceivedEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.messageReceived = insertSorted(r.messageReceived, HookRegistration[MessageReceivedEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnMessageSending(name string, priority int, handler HookHandler[MessageSendingEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.messageSending = insertSorted(r.messageSending, HookRegistration[MessageSendingEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnBeforeToolCall(name string, priority int, handler HookHandler[BeforeToolCallEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.beforeToolCall = insertSorted(r.beforeToolCall, HookRegistration[BeforeToolCallEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnAfterToolCall(name string, priority int, handler HookHandler[AfterToolCallEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.afterToolCall = insertSorted(r.afterToolCall, HookRegistration[AfterToolCallEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnLLMInput(name string, priority int, handler HookHandler[LLMInputEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.llmInput = insertSorted(r.llmInput, HookRegistration[LLMInputEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnLLMOutput(name string, priority int, handler HookHandler[LLMOutputEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.llmOutput = insertSorted(r.llmOutput, HookRegistration[LLMOutputEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnSessionStart(name string, priority int, handler HookHandler[SessionEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.sessionStart = insertSorted(r.sessionStart, HookRegistration[SessionEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +func (r *HookRegistry) OnSessionEnd(name string, priority int, handler HookHandler[SessionEvent]) { + r.mu.Lock() + defer r.mu.Unlock() + r.sessionEnd = insertSorted(r.sessionEnd, HookRegistration[SessionEvent]{ + Handler: handler, Priority: priority, Name: name, + }) +} + +// Trigger methods — void hooks + +func cloneMapStringString(src map[string]string) map[string]string { + if src == nil { + return nil + } + dst := make(map[string]string, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func cloneMapStringAny(src map[string]any) map[string]any { + if src == nil { + return nil + } + dst := make(map[string]any, len(src)) + for k, v := range src { + dst[k] = cloneAny(v) + } + return dst +} + +func cloneAny(v any) any { + if v == nil { + return nil + } + cloned := cloneReflectValue(reflect.ValueOf(v)) + if !cloned.IsValid() { + return nil + } + return cloned.Interface() +} + +func cloneReflectValue(v reflect.Value) reflect.Value { + if !v.IsValid() { + return v + } + + switch v.Kind() { + case reflect.Pointer: + if v.IsNil() { + return reflect.Zero(v.Type()) + } + out := reflect.New(v.Type().Elem()) + out.Elem().Set(cloneReflectValue(v.Elem())) + return out + case reflect.Interface: + if v.IsNil() { + return reflect.Zero(v.Type()) + } + out := reflect.New(v.Type()).Elem() + out.Set(cloneReflectValue(v.Elem())) + return out + case reflect.Map: + if v.IsNil() { + return reflect.Zero(v.Type()) + } + out := reflect.MakeMapWithSize(v.Type(), v.Len()) + iter := v.MapRange() + for iter.Next() { + out.SetMapIndex(iter.Key(), cloneReflectValue(iter.Value())) + } + return out + case reflect.Slice: + if v.IsNil() { + return reflect.Zero(v.Type()) + } + out := reflect.MakeSlice(v.Type(), v.Len(), v.Len()) + for i := range v.Len() { + out.Index(i).Set(cloneReflectValue(v.Index(i))) + } + return out + case reflect.Array: + out := reflect.New(v.Type()).Elem() + for i := range v.Len() { + out.Index(i).Set(cloneReflectValue(v.Index(i))) + } + return out + default: + return v + } +} + +func cloneToolCall(tc providers.ToolCall) providers.ToolCall { + out := tc + out.Arguments = cloneMapStringAny(tc.Arguments) + if tc.Function != nil { + f := *tc.Function + out.Function = &f + } + if tc.ExtraContent != nil { + ec := *tc.ExtraContent + if tc.ExtraContent.Google != nil { + g := *tc.ExtraContent.Google + ec.Google = &g + } + out.ExtraContent = &ec + } + return out +} + +func cloneMessage(msg providers.Message) providers.Message { + out := msg + if msg.ToolCalls != nil { + out.ToolCalls = make([]providers.ToolCall, len(msg.ToolCalls)) + for i := range msg.ToolCalls { + out.ToolCalls[i] = cloneToolCall(msg.ToolCalls[i]) + } + } + if msg.SystemParts != nil { + out.SystemParts = make([]providers.ContentBlock, len(msg.SystemParts)) + for i := range msg.SystemParts { + part := msg.SystemParts[i] + if part.CacheControl != nil { + cc := *part.CacheControl + part.CacheControl = &cc + } + out.SystemParts[i] = part + } + } + return out +} + +func cloneToolDefinition(td providers.ToolDefinition) providers.ToolDefinition { + out := td + out.Function = td.Function + out.Function.Parameters = cloneMapStringAny(td.Function.Parameters) + return out +} + +func cloneVoidEvent[T any](event *T) *T { + if event == nil { + return nil + } + + switch e := any(event).(type) { + case *MessageReceivedEvent: + c := *e + if e.Media != nil { + c.Media = append([]string(nil), e.Media...) + } + c.Metadata = cloneMapStringString(e.Metadata) + return any(&c).(*T) + case *AfterToolCallEvent: + c := *e + c.Args = cloneMapStringAny(e.Args) + if e.Result != nil { + r := *e.Result + c.Result = &r + } + return any(&c).(*T) + case *LLMInputEvent: + c := *e + if e.Messages != nil { + c.Messages = make([]providers.Message, len(e.Messages)) + for i := range e.Messages { + c.Messages[i] = cloneMessage(e.Messages[i]) + } + } + if e.Tools != nil { + c.Tools = make([]providers.ToolDefinition, len(e.Tools)) + for i := range e.Tools { + c.Tools[i] = cloneToolDefinition(e.Tools[i]) + } + } + return any(&c).(*T) + case *LLMOutputEvent: + c := *e + if e.ToolCalls != nil { + c.ToolCalls = make([]providers.ToolCall, len(e.ToolCalls)) + for i := range e.ToolCalls { + c.ToolCalls[i] = cloneToolCall(e.ToolCalls[i]) + } + } + return any(&c).(*T) + case *SessionEvent: + c := *e + return any(&c).(*T) + default: + c := *event + return &c + } +} + +// triggerVoid runs all handlers concurrently and waits for completion. +// Each handler receives a cloned event to avoid shared-state mutation races. +// Errors are logged but do not propagate to the caller. +func triggerVoid[T any](ctx context.Context, hooks []HookRegistration[T], event *T, hookName string) { + if len(hooks) == 0 { + return + } + var wg sync.WaitGroup + for _, h := range hooks { + wg.Add(1) + go func(reg HookRegistration[T]) { + defer wg.Done() + eventCopy := cloneVoidEvent(event) + defer func() { + if r := recover(); r != nil { + logger.ErrorCF("hooks", "Hook panic", + map[string]any{ + "hook": hookName, + "handler": reg.Name, + "panic": fmt.Sprintf("%v", r), + }) + } + }() + if err := reg.Handler(ctx, eventCopy); err != nil { + logger.WarnCF("hooks", "Hook error", + map[string]any{ + "hook": hookName, + "handler": reg.Name, + "error": err.Error(), + }) + } + }(h) + } + wg.Wait() +} + +// triggerModifying runs handlers sequentially by priority, stopping if Cancel is set. +// The cancelCheck function inspects the event to determine if Cancel was set. +func triggerModifying[T any]( + ctx context.Context, + hooks []HookRegistration[T], + event *T, + hookName string, + cancelCheck func(*T) bool, +) { + if len(hooks) == 0 { + return + } + for _, h := range hooks { + func() { + defer func() { + if r := recover(); r != nil { + logger.ErrorCF("hooks", "Hook panic", + map[string]any{ + "hook": hookName, + "handler": h.Name, + "panic": fmt.Sprintf("%v", r), + }) + } + }() + if err := h.Handler(ctx, event); err != nil { + logger.WarnCF("hooks", "Hook error", + map[string]any{ + "hook": hookName, + "handler": h.Name, + "error": err.Error(), + }) + } + }() + if cancelCheck(event) { + logger.InfoCF("hooks", "Hook canceled operation", + map[string]any{ + "hook": hookName, + "handler": h.Name, + }) + return + } + } +} + +// TriggerMessageReceived fires all message_received handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerMessageReceived(ctx context.Context, event *MessageReceivedEvent) { + r.mu.RLock() + hooks := r.messageReceived + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "message_received") +} + +func (r *HookRegistry) TriggerMessageSending(ctx context.Context, event *MessageSendingEvent) { + r.mu.RLock() + hooks := r.messageSending + r.mu.RUnlock() + triggerModifying(ctx, hooks, event, "message_sending", func(e *MessageSendingEvent) bool { + return e.Cancel + }) +} + +func (r *HookRegistry) TriggerBeforeToolCall(ctx context.Context, event *BeforeToolCallEvent) { + r.mu.RLock() + hooks := r.beforeToolCall + r.mu.RUnlock() + triggerModifying(ctx, hooks, event, "before_tool_call", func(e *BeforeToolCallEvent) bool { + return e.Cancel + }) +} + +// TriggerAfterToolCall fires all after_tool_call handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerAfterToolCall(ctx context.Context, event *AfterToolCallEvent) { + r.mu.RLock() + hooks := r.afterToolCall + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "after_tool_call") +} + +// TriggerLLMInput fires all llm_input handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerLLMInput(ctx context.Context, event *LLMInputEvent) { + r.mu.RLock() + hooks := r.llmInput + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "llm_input") +} + +// TriggerLLMOutput fires all llm_output handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerLLMOutput(ctx context.Context, event *LLMOutputEvent) { + r.mu.RLock() + hooks := r.llmOutput + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "llm_output") +} + +// TriggerSessionStart fires all session_start handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerSessionStart(ctx context.Context, event *SessionEvent) { + r.mu.RLock() + hooks := r.sessionStart + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "session_start") +} + +// TriggerSessionEnd fires all session_end handlers concurrently. +// Handler mutations are isolated per hook invocation and are not propagated. +func (r *HookRegistry) TriggerSessionEnd(ctx context.Context, event *SessionEvent) { + r.mu.RLock() + hooks := r.sessionEnd + r.mu.RUnlock() + triggerVoid(ctx, hooks, event, "session_end") +} diff --git a/pkg/hooks/hooks_test.go b/pkg/hooks/hooks_test.go new file mode 100644 index 000000000..8ca3e4c2c --- /dev/null +++ b/pkg/hooks/hooks_test.go @@ -0,0 +1,567 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package hooks + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +func TestNewHookRegistry(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + // Triggering all hooks on an empty registry should not panic. + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "hello"}) + r.TriggerMessageSending(ctx, &MessageSendingEvent{Content: "hello"}) + r.TriggerBeforeToolCall(ctx, &BeforeToolCallEvent{ToolName: "t"}) + r.TriggerAfterToolCall(ctx, &AfterToolCallEvent{ToolName: "t"}) + r.TriggerLLMInput(ctx, &LLMInputEvent{AgentID: "a"}) + r.TriggerLLMOutput(ctx, &LLMOutputEvent{AgentID: "a"}) + r.TriggerSessionStart(ctx, &SessionEvent{AgentID: "a"}) + r.TriggerSessionEnd(ctx, &SessionEvent{AgentID: "a"}) +} + +func TestVoidHookExecution(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var called atomic.Bool + r.OnMessageReceived("test", 0, func(_ context.Context, e *MessageReceivedEvent) error { + called.Store(true) + if e.Content != "ping" { + t.Errorf("Expected content 'ping', got '%s'", e.Content) + } + return nil + }) + + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "ping"}) + + if !called.Load() { + t.Error("Expected handler to be called") + } +} + +func TestVoidHooksConcurrent(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var count atomic.Int32 + started := make(chan struct{}, 5) + release := make(chan struct{}) + done := make(chan struct{}) + + for i := range 5 { + r.OnMessageReceived("hook-"+string(rune('A'+i)), i, func(_ context.Context, _ *MessageReceivedEvent) error { + started <- struct{}{} + <-release + count.Add(1) + return nil + }) + } + + go func() { + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "test"}) + close(done) + }() + + // All 5 handlers must reach the barrier concurrently. + for i := range 5 { + select { + case <-started: + case <-time.After(1 * time.Second): + t.Fatalf("timeout waiting for handler %d to start", i+1) + } + } + + // Release all handlers. + close(release) + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for handlers to complete") + } + + if count.Load() != 5 { + t.Errorf("Expected 5 handlers called, got %d", count.Load()) + } +} + +func TestVoidHooksReceiveIsolatedMessageReceivedEvents(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnMessageReceived("mutator-a", 0, func(_ context.Context, e *MessageReceivedEvent) error { + e.Content = "changed-a" + e.Media[0] = "changed-media-a" + e.Metadata["k"] = "changed-a" + e.Metadata["new-a"] = "x" + return nil + }) + r.OnMessageReceived("mutator-b", 1, func(_ context.Context, e *MessageReceivedEvent) error { + e.Content = "changed-b" + e.Media = append(e.Media, "extra") + e.Metadata["k"] = "changed-b" + e.Metadata["new-b"] = "y" + return nil + }) + + event := &MessageReceivedEvent{ + Content: "original", + Media: []string{"m1"}, + Metadata: map[string]string{"k": "v"}, + } + r.TriggerMessageReceived(ctx, event) + + if event.Content != "original" { + t.Fatalf("expected original content to remain unchanged, got %q", event.Content) + } + if len(event.Media) != 1 || event.Media[0] != "m1" { + t.Fatalf("expected original media to remain unchanged, got %#v", event.Media) + } + if got := event.Metadata["k"]; got != "v" { + t.Fatalf("expected metadata[k] to remain v, got %q", got) + } + if _, ok := event.Metadata["new-a"]; ok { + t.Fatal("unexpected mutation leaked from hook mutator-a") + } + if _, ok := event.Metadata["new-b"]; ok { + t.Fatal("unexpected mutation leaked from hook mutator-b") + } +} + +func TestVoidHooksReceiveIsolatedAfterToolCallEvents(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnAfterToolCall("mutator-a", 0, func(_ context.Context, e *AfterToolCallEvent) error { + e.Args["k"] = "changed-a" + e.Result.ForLLM = "mutated-a" + return nil + }) + r.OnAfterToolCall("mutator-b", 1, func(_ context.Context, e *AfterToolCallEvent) error { + e.Args["k"] = "changed-b" + e.Args["new"] = "v" + e.Result.ForUser = "mutated-b" + return nil + }) + + event := &AfterToolCallEvent{ + ToolName: "shell", + Args: map[string]any{"k": "original"}, + Result: &tools.ToolResult{ + ForLLM: "for-llm", + ForUser: "for-user", + }, + } + + // Use a local copy so we can compare immutable expectations. + r.TriggerAfterToolCall(ctx, event) + + if got := event.Args["k"]; got != "original" { + t.Fatalf("expected args[k] to remain original, got %#v", got) + } + if _, ok := event.Args["new"]; ok { + t.Fatal("unexpected args mutation leaked from hook") + } + if event.Result.ForLLM != "for-llm" { + t.Fatalf("expected original result.ForLLM to remain unchanged, got %q", event.Result.ForLLM) + } + if event.Result.ForUser != "for-user" { + t.Fatalf("expected original result.ForUser to remain unchanged, got %q", event.Result.ForUser) + } +} + +func TestVoidHooksReceiveIsolatedLLMInputToolSchema(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnLLMInput("mutator", 0, func(_ context.Context, e *LLMInputEvent) error { + required, ok := e.Tools[0].Function.Parameters["required"].([]string) + if !ok { + t.Fatal("required should be []string") + } + required[0] = "mutated" + e.Tools[0].Function.Parameters["required"] = append(required, "extra") + return nil + }) + + event := &LLMInputEvent{ + AgentID: "a1", + Model: "m1", + Tools: []providers.ToolDefinition{ + { + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: "message", + Parameters: map[string]any{ + "type": "object", + "required": []string{"content"}, + }, + }, + }, + }, + } + + r.TriggerLLMInput(ctx, event) + + required, ok := event.Tools[0].Function.Parameters["required"].([]string) + if !ok { + t.Fatal("required should remain []string") + } + if len(required) != 1 || required[0] != "content" { + t.Fatalf("expected required to remain unchanged, got %#v", required) + } +} + +func TestModifyingHookPriority(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var mu sync.Mutex + var order []string + + // Register in reverse priority order to verify sorting. + r.OnMessageSending("third", 30, func(_ context.Context, _ *MessageSendingEvent) error { + mu.Lock() + order = append(order, "third") + mu.Unlock() + return nil + }) + r.OnMessageSending("first", 10, func(_ context.Context, _ *MessageSendingEvent) error { + mu.Lock() + order = append(order, "first") + mu.Unlock() + return nil + }) + r.OnMessageSending("second", 20, func(_ context.Context, _ *MessageSendingEvent) error { + mu.Lock() + order = append(order, "second") + mu.Unlock() + return nil + }) + + r.TriggerMessageSending(ctx, &MessageSendingEvent{Content: "hi"}) + + if len(order) != 3 { + t.Fatalf("Expected 3 handlers, got %d", len(order)) + } + if order[0] != "first" || order[1] != "second" || order[2] != "third" { + t.Errorf("Expected [first second third], got %v", order) + } +} + +func TestModifyingHookCancel(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var secondCalled bool + + r.OnMessageSending("canceler", 10, func(_ context.Context, e *MessageSendingEvent) error { + e.Cancel = true + e.CancelReason = "blocked" + return nil + }) + r.OnMessageSending("after-cancel", 20, func(_ context.Context, _ *MessageSendingEvent) error { + secondCalled = true + return nil + }) + + event := &MessageSendingEvent{Content: "hi"} + r.TriggerMessageSending(ctx, event) + + if !event.Cancel { + t.Error("Expected Cancel to be true") + } + if secondCalled { + t.Error("Expected second handler NOT to be called after cancel") + } +} + +func TestBeforeToolCallModification(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnBeforeToolCall("modifier", 10, func(_ context.Context, e *BeforeToolCallEvent) error { + e.Args["injected"] = "value" + return nil + }) + + event := &BeforeToolCallEvent{ + ToolName: "search", + Args: map[string]any{"query": "test"}, + } + r.TriggerBeforeToolCall(ctx, event) + + if event.Args["injected"] != "value" { + t.Error("Expected injected arg to persist") + } + if event.Args["query"] != "test" { + t.Error("Expected original arg to remain") + } +} + +func TestMessageSendingFilter(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + r.OnMessageSending("rewriter", 10, func(_ context.Context, e *MessageSendingEvent) error { + e.Content = "[filtered] " + e.Content + return nil + }) + + event := &MessageSendingEvent{Content: "hello world"} + r.TriggerMessageSending(ctx, event) + + if event.Content != "[filtered] hello world" { + t.Errorf("Expected '[filtered] hello world', got '%s'", event.Content) + } +} + +func TestZeroCostWhenEmpty(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + // This is primarily a safety/smoke test — no panics, no allocations of note. + for range 100 { + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{}) + r.TriggerMessageSending(ctx, &MessageSendingEvent{}) + r.TriggerBeforeToolCall(ctx, &BeforeToolCallEvent{}) + r.TriggerAfterToolCall(ctx, &AfterToolCallEvent{}) + r.TriggerLLMInput(ctx, &LLMInputEvent{}) + r.TriggerLLMOutput(ctx, &LLMOutputEvent{}) + r.TriggerSessionStart(ctx, &SessionEvent{}) + r.TriggerSessionEnd(ctx, &SessionEvent{}) + } +} + +func TestLLMInputOutput(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var inputCalled, outputCalled atomic.Bool + + r.OnLLMInput("input-hook", 0, func(_ context.Context, e *LLMInputEvent) error { + if e.Model != "gpt-4" { + t.Errorf("Expected model 'gpt-4', got '%s'", e.Model) + } + inputCalled.Store(true) + return nil + }) + + r.OnLLMOutput("output-hook", 0, func(_ context.Context, e *LLMOutputEvent) error { + if e.Content != "response" { + t.Errorf("Expected content 'response', got '%s'", e.Content) + } + outputCalled.Store(true) + return nil + }) + + r.TriggerLLMInput(ctx, &LLMInputEvent{AgentID: "a1", Model: "gpt-4", Iteration: 1}) + r.TriggerLLMOutput(ctx, &LLMOutputEvent{AgentID: "a1", Model: "gpt-4", Content: "response", Iteration: 1}) + + if !inputCalled.Load() { + t.Error("Expected LLM input hook to be called") + } + if !outputCalled.Load() { + t.Error("Expected LLM output hook to be called") + } +} + +func TestSessionStartEnd(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var startCalled, endCalled atomic.Bool + + r.OnSessionStart("start-hook", 0, func(_ context.Context, e *SessionEvent) error { + if e.SessionKey != "sess-1" { + t.Errorf("Expected session key 'sess-1', got '%s'", e.SessionKey) + } + startCalled.Store(true) + return nil + }) + + r.OnSessionEnd("end-hook", 0, func(_ context.Context, e *SessionEvent) error { + if e.SessionKey != "sess-1" { + t.Errorf("Expected session key 'sess-1', got '%s'", e.SessionKey) + } + endCalled.Store(true) + return nil + }) + + event := &SessionEvent{AgentID: "a1", SessionKey: "sess-1", Channel: "test", ChatID: "c1"} + r.TriggerSessionStart(ctx, event) + r.TriggerSessionEnd(ctx, event) + + if !startCalled.Load() { + t.Error("Expected session start hook to be called") + } + if !endCalled.Load() { + t.Error("Expected session end hook to be called") + } +} + +func TestConcurrentRegistrationAndTrigger(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var wg sync.WaitGroup + + // Goroutines registering hooks. + for i := range 10 { + wg.Add(1) + go func(idx int) { + defer wg.Done() + r.OnMessageReceived( + fmt.Sprintf("reg-hook-%d", idx), + idx, + func(_ context.Context, _ *MessageReceivedEvent) error { + return nil + }, + ) + }(i) + } + + // Goroutines triggering hooks concurrently. + for range 10 { + wg.Add(1) + go func() { + defer wg.Done() + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "race"}) + }() + } + + wg.Wait() +} + +func TestInsertSorted(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var order []int + + // Register with priorities: 50, 10, 30, 20, 40 + priorities := []int{50, 10, 30, 20, 40} + for _, p := range priorities { + r.OnBeforeToolCall(fmt.Sprintf("p-%d", p), p, func(_ context.Context, _ *BeforeToolCallEvent) error { + order = append(order, p) + return nil + }) + } + + r.TriggerBeforeToolCall(ctx, &BeforeToolCallEvent{ToolName: "test", Args: map[string]any{}}) + + expected := []int{10, 20, 30, 40, 50} + if len(order) != len(expected) { + t.Fatalf("Expected %d handlers, got %d", len(expected), len(order)) + } + for i, v := range expected { + if order[i] != v { + t.Errorf("Position %d: expected priority %d, got %d", i, v, order[i]) + } + } +} + +func TestAfterToolCallExecution(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + var called bool + var capturedName string + r.OnAfterToolCall("logger", 0, func(_ context.Context, event *AfterToolCallEvent) error { + called = true + capturedName = event.ToolName + return nil + }) + + r.TriggerAfterToolCall(ctx, &AfterToolCallEvent{ + ToolName: "shell", + Args: map[string]any{"cmd": "ls"}, + Channel: "telegram", + ChatID: "123", + }) + + if !called { + t.Error("Expected after_tool_call handler to be called") + } + if capturedName != "shell" { + t.Errorf("Expected ToolName 'shell', got '%s'", capturedName) + } +} + +func TestHandlerErrorsSwallowed(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + // Test void hooks: error in one handler doesn't prevent others from running + var secondCalled bool + r.OnMessageReceived("erroring", 10, func(_ context.Context, _ *MessageReceivedEvent) error { + return fmt.Errorf("handler error") + }) + r.OnMessageReceived("observer", 20, func(_ context.Context, _ *MessageReceivedEvent) error { + secondCalled = true + return nil + }) + + r.TriggerMessageReceived(ctx, &MessageReceivedEvent{Content: "test"}) + if !secondCalled { + t.Error("Expected second void handler to run despite first handler's error") + } + + // Test modifying hooks: error doesn't stop chain (only Cancel does) + var modifySecondCalled bool + r.OnMessageSending("erroring", 10, func(_ context.Context, _ *MessageSendingEvent) error { + return fmt.Errorf("handler error") + }) + r.OnMessageSending("modifier", 20, func(_ context.Context, _ *MessageSendingEvent) error { + modifySecondCalled = true + return nil + }) + + r.TriggerMessageSending(ctx, &MessageSendingEvent{Content: "test"}) + if !modifySecondCalled { + t.Error("Expected second modifying handler to run despite first handler's error") + } +} + +func TestPanicRecovery(t *testing.T) { + r := NewHookRegistry() + ctx := context.Background() + + // Void hook: panic in one handler shouldn't crash, other handlers should still run + var safeHandlerCalled bool + r.OnLLMInput("panicker", 10, func(_ context.Context, _ *LLMInputEvent) error { + panic("boom") + }) + r.OnLLMInput("safe", 10, func(_ context.Context, _ *LLMInputEvent) error { + safeHandlerCalled = true + return nil + }) + + // Should not panic + r.TriggerLLMInput(ctx, &LLMInputEvent{AgentID: "test"}) + if !safeHandlerCalled { + t.Error("Expected safe handler to run despite panicking sibling") + } + + // Modifying hook: panic in handler shouldn't crash + r.OnBeforeToolCall("panicker", 10, func(_ context.Context, _ *BeforeToolCallEvent) error { + panic("boom") + }) + + // Should not panic + r.TriggerBeforeToolCall(ctx, &BeforeToolCallEvent{ToolName: "test"}) +} diff --git a/pkg/hooks/types.go b/pkg/hooks/types.go new file mode 100644 index 000000000..4a0f6697d --- /dev/null +++ b/pkg/hooks/types.go @@ -0,0 +1,82 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package hooks + +import ( + "time" + + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// MessageReceivedEvent is fired when an inbound message is consumed from the bus. +type MessageReceivedEvent struct { + Channel string + SenderID string + ChatID string + Content string + Media []string + Metadata map[string]string +} + +// MessageSendingEvent is fired before an outbound message is published. +// Handlers can modify Content or set Cancel to block delivery. +type MessageSendingEvent struct { + Channel string + ChatID string + Content string // Modifiable + Cancel bool + CancelReason string +} + +// BeforeToolCallEvent is fired before a tool is executed. +// Handlers can modify Args, or set Cancel to block execution. +type BeforeToolCallEvent struct { + ToolName string + Args map[string]any // Modifiable; guaranteed non-nil when triggered via AgentLoop. + Channel string + ChatID string + Cancel bool + CancelReason string // Message returned to LLM when canceled +} + +// AfterToolCallEvent is fired after a tool completes execution. +type AfterToolCallEvent struct { + ToolName string + Args map[string]any + Channel string + ChatID string + Duration time.Duration + Result *tools.ToolResult +} + +// LLMInputEvent is fired before the LLM provider is called. +type LLMInputEvent struct { + AgentID string + Model string + Messages []providers.Message + Tools []providers.ToolDefinition + Iteration int +} + +// LLMOutputEvent is fired after the LLM provider responds. +type LLMOutputEvent struct { + AgentID string + Model string + Content string + ToolCalls []providers.ToolCall + Iteration int + Duration time.Duration +} + +// SessionEvent is fired at session start and end. +type SessionEvent struct { + AgentID string + SessionKey string + Channel string + ChatID string +} diff --git a/pkg/plugin/demoplugin/policy_demo.go b/pkg/plugin/demoplugin/policy_demo.go new file mode 100644 index 000000000..255907d38 --- /dev/null +++ b/pkg/plugin/demoplugin/policy_demo.go @@ -0,0 +1,328 @@ +package demoplugin + +import ( + "context" + "fmt" + "math" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/hooks" + "github.com/sipeed/picoclaw/pkg/plugin" +) + +// PolicyDemoConfig controls the demo plugin behavior. +type PolicyDemoConfig struct { + BlockedTools []string + RedactPrefixes []string + ChannelToolAllowlist map[string][]string + DenyOutboundPatterns []string + MaxToolTimeoutSecond int +} + +// PolicyDemoStats provides basic evidence that hook paths were executed. +type PolicyDemoStats struct { + BeforeToolCalls int + BlockedToolCalls int + MessageSends int + RedactedMessages int + BlockedMessages int + SessionStarts int + SessionEnds int + AfterToolCalls int + TotalToolDuration time.Duration +} + +// PolicyDemoPlugin demonstrates why plugins are needed: it enforces runtime policy +// at tool-call and outbound-message lifecycle points and collects audit metrics. +type PolicyDemoPlugin struct { + blockedTools map[string]struct{} + prefixes []string + channelAllowlist map[string]map[string]struct{} + denyPatterns []string + maxTimeout int + + mu sync.Mutex + stats PolicyDemoStats +} + +func NewPolicyDemoPlugin(cfg PolicyDemoConfig) *PolicyDemoPlugin { + blocked := make(map[string]struct{}, len(cfg.BlockedTools)) + for _, t := range cfg.BlockedTools { + t = normalizeLower(t) + if t == "" { + continue + } + blocked[t] = struct{}{} + } + + prefixes := make([]string, 0, len(cfg.RedactPrefixes)) + for _, p := range cfg.RedactPrefixes { + p = strings.TrimSpace(p) + if p == "" { + continue + } + prefixes = append(prefixes, p) + } + + allowlist := make(map[string]map[string]struct{}, len(cfg.ChannelToolAllowlist)) + for channel, tools := range cfg.ChannelToolAllowlist { + channel = normalizeLower(channel) + if channel == "" { + continue + } + toolSet := make(map[string]struct{}, len(tools)) + for _, t := range tools { + t = normalizeLower(t) + if t == "" { + continue + } + toolSet[t] = struct{}{} + } + allowlist[channel] = toolSet + } + + patterns := make([]string, 0, len(cfg.DenyOutboundPatterns)) + for _, p := range cfg.DenyOutboundPatterns { + p = strings.TrimSpace(p) + if p == "" { + continue + } + patterns = append(patterns, p) + } + + maxTimeout := cfg.MaxToolTimeoutSecond + if maxTimeout < 0 { + maxTimeout = 0 + } + + return &PolicyDemoPlugin{ + blockedTools: blocked, + prefixes: prefixes, + channelAllowlist: allowlist, + denyPatterns: patterns, + maxTimeout: maxTimeout, + } +} + +func (p *PolicyDemoPlugin) Name() string { + return "policy-demo" +} + +func (p *PolicyDemoPlugin) APIVersion() string { + return plugin.APIVersion +} + +func (p *PolicyDemoPlugin) Snapshot() PolicyDemoStats { + p.mu.Lock() + defer p.mu.Unlock() + return p.stats +} + +func (p *PolicyDemoPlugin) Register(r *hooks.HookRegistry) error { + r.OnBeforeToolCall("policy-demo-tool-policy", 100, func(_ context.Context, e *hooks.BeforeToolCallEvent) error { + tool := normalizeLower(e.ToolName) + p.incBeforeToolCalls() + + if _, blocked := p.blockedTools[tool]; blocked { + e.Cancel = true + e.CancelReason = "blocked by policy-demo plugin" + p.incBlockedToolCalls() + return nil + } + + channel := normalizeLower(e.Channel) + if allow, ok := p.channelAllowlist[channel]; ok { + if _, allowed := allow[tool]; !allowed { + e.Cancel = true + e.CancelReason = fmt.Sprintf("tool %q is not allowed on channel %q", e.ToolName, e.Channel) + p.incBlockedToolCalls() + return nil + } + } + + if p.maxTimeout > 0 { + clampArgNumber(e.Args, "timeout", p.maxTimeout) + clampArgNumber(e.Args, "timeout_seconds", p.maxTimeout) + } + return nil + }) + + r.OnMessageSending("policy-demo-redact-and-guard", 50, func(_ context.Context, e *hooks.MessageSendingEvent) error { + p.incMessageSends() + + for _, pattern := range p.denyPatterns { + if strings.Contains(e.Content, pattern) { + e.Cancel = true + e.CancelReason = "blocked by policy-demo outbound guard" + p.incBlockedMessages() + return nil + } + } + + content := e.Content + redacted := false + for _, prefix := range p.prefixes { + next := strings.ReplaceAll(content, prefix, "[redacted]-") + if next != content { + redacted = true + } + content = next + } + e.Content = content + if redacted { + p.incRedactedMessages() + } + return nil + }) + + r.OnSessionStart("policy-demo-session-start-audit", 0, func(_ context.Context, _ *hooks.SessionEvent) error { + p.incSessionStarts() + return nil + }) + + r.OnSessionEnd("policy-demo-session-end-audit", 0, func(_ context.Context, _ *hooks.SessionEvent) error { + p.incSessionEnds() + return nil + }) + + r.OnAfterToolCall("policy-demo-after-tool-audit", 0, func(_ context.Context, e *hooks.AfterToolCallEvent) error { + p.incAfterToolCall(e.Duration) + return nil + }) + + return nil +} + +func normalizeLower(s string) string { + return strings.ToLower(strings.TrimSpace(s)) +} + +func clampArgNumber(args map[string]any, key string, limit int) { + if args == nil || limit <= 0 { + return + } + v, ok := args[key] + if !ok { + return + } + n, ok := toInt(v) + if !ok { + return + } + if n > limit { + args[key] = limit + } +} + +func toInt(v any) (int, bool) { + maxInt := int(^uint(0) >> 1) + maxIntU64 := uint64(maxInt) + maxInt64 := int64(maxInt) + minInt64 := -maxInt64 - 1 + + switch n := v.(type) { + case int: + return n, true + case int8: + return int(n), true + case int16: + return int(n), true + case int32: + return int(n), true + case int64: + if n < minInt64 || n > maxInt64 { + return 0, false + } + return int(n), true + case uint: + if uint64(n) > maxIntU64 { + return 0, false + } + return int(n), true + case uint8: + return int(n), true + case uint16: + return int(n), true + case uint32: + if uint64(n) > maxIntU64 { + return 0, false + } + return int(n), true + case uint64: + if n > maxIntU64 { + return 0, false + } + return int(n), true + case float32: + // Truncation is intentional for timeout normalization. + if math.IsNaN(float64(n)) || math.IsInf(float64(n), 0) { + return 0, false + } + if n < float32(minInt64) || n > float32(maxInt64) { + return 0, false + } + return int(n), true + case float64: + // Truncation is intentional for timeout normalization. + if math.IsNaN(n) || math.IsInf(n, 0) { + return 0, false + } + if n < float64(minInt64) || n > float64(maxInt64) { + return 0, false + } + return int(n), true + default: + return 0, false + } +} + +func (p *PolicyDemoPlugin) incBeforeToolCalls() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.BeforeToolCalls++ +} + +func (p *PolicyDemoPlugin) incBlockedToolCalls() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.BlockedToolCalls++ +} + +func (p *PolicyDemoPlugin) incMessageSends() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.MessageSends++ +} + +func (p *PolicyDemoPlugin) incRedactedMessages() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.RedactedMessages++ +} + +func (p *PolicyDemoPlugin) incBlockedMessages() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.BlockedMessages++ +} + +func (p *PolicyDemoPlugin) incSessionStarts() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.SessionStarts++ +} + +func (p *PolicyDemoPlugin) incSessionEnds() { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.SessionEnds++ +} + +func (p *PolicyDemoPlugin) incAfterToolCall(d time.Duration) { + p.mu.Lock() + defer p.mu.Unlock() + p.stats.AfterToolCalls++ + p.stats.TotalToolDuration += d +} diff --git a/pkg/plugin/demoplugin/policy_demo_test.go b/pkg/plugin/demoplugin/policy_demo_test.go new file mode 100644 index 000000000..aec08d8ea --- /dev/null +++ b/pkg/plugin/demoplugin/policy_demo_test.go @@ -0,0 +1,194 @@ +package demoplugin + +import ( + "context" + "math" + "strconv" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/hooks" + "github.com/sipeed/picoclaw/pkg/plugin" +) + +func TestPolicyDemoPluginBlocksConfiguredTool(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{ + BlockedTools: []string{"shell"}, + }) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + e := &hooks.BeforeToolCallEvent{ToolName: "shell", Args: map[string]any{}, Channel: "cli"} + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), e) + + if !e.Cancel { + t.Fatal("expected tool call to be canceled") + } + if e.CancelReason == "" { + t.Fatal("expected cancel reason") + } + + stats := p.Snapshot() + if stats.BeforeToolCalls != 1 || stats.BlockedToolCalls != 1 { + t.Fatalf("unexpected stats: %+v", stats) + } +} + +func TestPolicyDemoPluginRedactsOutboundContent(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{ + RedactPrefixes: []string{"sk-"}, + }) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + e := &hooks.MessageSendingEvent{Content: "token=sk-abc123"} + pm.HookRegistry().TriggerMessageSending(context.Background(), e) + + if e.Cancel { + t.Fatal("did not expect cancellation") + } + if e.Content != "token=[redacted]-abc123" { + t.Fatalf("unexpected redaction result: %q", e.Content) + } + + stats := p.Snapshot() + if stats.MessageSends != 1 || stats.RedactedMessages != 1 { + t.Fatalf("unexpected stats: %+v", stats) + } +} + +func TestPolicyDemoPluginChannelAllowlist(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{ + ChannelToolAllowlist: map[string][]string{ + "telegram": {"web_search"}, + }, + }) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + blocked := &hooks.BeforeToolCallEvent{ToolName: "shell", Args: map[string]any{}, Channel: "telegram"} + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), blocked) + if !blocked.Cancel { + t.Fatal("expected tool to be blocked by channel allowlist") + } + + allowed := &hooks.BeforeToolCallEvent{ToolName: "web_search", Args: map[string]any{}, Channel: "telegram"} + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), allowed) + if allowed.Cancel { + t.Fatalf("did not expect allowlisted tool to be blocked: %s", allowed.CancelReason) + } +} + +func TestPolicyDemoPluginOutboundGuard(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{ + DenyOutboundPatterns: []string{"4111-1111-1111-1111", "@corp.internal"}, + }) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + e := &hooks.MessageSendingEvent{Content: "card=4111-1111-1111-1111"} + pm.HookRegistry().TriggerMessageSending(context.Background(), e) + if !e.Cancel { + t.Fatal("expected outbound message to be blocked") + } + if e.CancelReason == "" { + t.Fatal("expected block reason") + } + + stats := p.Snapshot() + if stats.BlockedMessages != 1 { + t.Fatalf("expected blocked message count to be 1, got %+v", stats) + } +} + +func TestPolicyDemoPluginNormalizesTimeoutArg(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{MaxToolTimeoutSecond: 30}) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + e := &hooks.BeforeToolCallEvent{ + ToolName: "web_fetch", + Channel: "cli", + Args: map[string]any{ + "timeout": 120, + "timeout_seconds": 90.0, + }, + } + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), e) + + if got, ok := e.Args["timeout"].(int); !ok || got != 30 { + t.Fatalf("expected timeout to be clamped to 30, got %#v", e.Args["timeout"]) + } + if got, ok := e.Args["timeout_seconds"].(int); !ok || got != 30 { + t.Fatalf("expected timeout_seconds to be clamped to 30, got %#v", e.Args["timeout_seconds"]) + } +} + +func TestPolicyDemoPluginAuditHooks(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{}) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + pm.HookRegistry().TriggerSessionStart(context.Background(), &hooks.SessionEvent{AgentID: "a1", SessionKey: "s1"}) + pm.HookRegistry(). + TriggerAfterToolCall(context.Background(), &hooks.AfterToolCallEvent{ToolName: "web_search", Duration: 45 * time.Millisecond}) + pm.HookRegistry().TriggerSessionEnd(context.Background(), &hooks.SessionEvent{AgentID: "a1", SessionKey: "s1"}) + + stats := p.Snapshot() + if stats.SessionStarts != 1 || stats.SessionEnds != 1 { + t.Fatalf("unexpected session stats: %+v", stats) + } + if stats.AfterToolCalls != 1 || stats.TotalToolDuration != 45*time.Millisecond { + t.Fatalf("unexpected after_tool_call stats: %+v", stats) + } +} + +func TestPolicyDemoPluginNoConfigNoEffect(t *testing.T) { + pm := plugin.NewManager() + p := NewPolicyDemoPlugin(PolicyDemoConfig{}) + if err := pm.Register(p); err != nil { + t.Fatalf("register plugin: %v", err) + } + + toolEvent := &hooks.BeforeToolCallEvent{ToolName: "shell", Args: map[string]any{}, Channel: "telegram"} + pm.HookRegistry().TriggerBeforeToolCall(context.Background(), toolEvent) + if toolEvent.Cancel { + t.Fatal("did not expect cancellation with empty config") + } + + msgEvent := &hooks.MessageSendingEvent{Content: "token=sk-abc123"} + pm.HookRegistry().TriggerMessageSending(context.Background(), msgEvent) + if msgEvent.Content != "token=sk-abc123" { + t.Fatalf("did not expect content rewrite, got %q", msgEvent.Content) + } +} + +func TestToIntRejectsInt64OverflowOn32Bit(t *testing.T) { + if strconv.IntSize != 32 { + t.Skip("overflow scenario is specific to 32-bit int") + } + if _, ok := toInt(int64(1 << 40)); ok { + t.Fatal("expected overflow conversion to fail on 32-bit int") + } +} + +func TestToIntRejectsInvalidFloatValues(t *testing.T) { + cases := []float64{math.NaN(), math.Inf(1), math.Inf(-1), math.MaxFloat64} + for _, v := range cases { + if _, ok := toInt(v); ok { + t.Fatalf("expected float conversion to fail for %v", v) + } + } +} diff --git a/pkg/plugin/manager.go b/pkg/plugin/manager.go new file mode 100644 index 000000000..edc2408ed --- /dev/null +++ b/pkg/plugin/manager.go @@ -0,0 +1,99 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package plugin + +import ( + "errors" + "fmt" + "slices" + "strings" + "sync" + + "github.com/sipeed/picoclaw/pkg/hooks" +) + +// APIVersion identifies the compile-time plugin contract version. +const APIVersion = "v1alpha1" + +// Plugin is the Phase-1 compile-time contract for PicoClaw extensions. +type Plugin interface { + Name() string + APIVersion() string + Register(registry *hooks.HookRegistry) error +} + +// Manager owns a shared hook registry and loaded plugin metadata. +type Manager struct { + mu sync.RWMutex + registry *hooks.HookRegistry + names []string + seen map[string]struct{} +} + +// NewManager creates an empty plugin manager with a fresh hook registry. +func NewManager() *Manager { + return &Manager{ + registry: hooks.NewHookRegistry(), + seen: make(map[string]struct{}), + } +} + +// HookRegistry returns the shared registry where plugins register hooks. +func (m *Manager) HookRegistry() *hooks.HookRegistry { + return m.registry +} + +// Names returns loaded plugin names in registration order. +func (m *Manager) Names() []string { + m.mu.RLock() + defer m.mu.RUnlock() + return slices.Clone(m.names) +} + +// Register loads one plugin into the shared hook registry. +func (m *Manager) Register(p Plugin) error { + if p == nil { + return errors.New("plugin is nil") + } + name := strings.TrimSpace(p.Name()) + if name == "" { + return errors.New("plugin name is required") + } + if got := strings.TrimSpace(p.APIVersion()); got != APIVersion { + if got == "" { + got = "" + } + return fmt.Errorf( + "plugin %q api version mismatch: got %s, want %s", + name, + got, + APIVersion, + ) + } + + m.mu.Lock() + defer m.mu.Unlock() + if _, exists := m.seen[name]; exists { + return fmt.Errorf("plugin %q already registered", name) + } + if err := p.Register(m.registry); err != nil { + return fmt.Errorf("register plugin %q: %w", name, err) + } + m.seen[name] = struct{}{} + m.names = append(m.names, name) + return nil +} + +// RegisterAll loads plugins sequentially. +func (m *Manager) RegisterAll(plugins ...Plugin) error { + for _, p := range plugins { + if err := m.Register(p); err != nil { + return err + } + } + return nil +} diff --git a/pkg/plugin/manager_test.go b/pkg/plugin/manager_test.go new file mode 100644 index 000000000..8b8e431e7 --- /dev/null +++ b/pkg/plugin/manager_test.go @@ -0,0 +1,131 @@ +package plugin + +import ( + "context" + "errors" + "testing" + + "github.com/sipeed/picoclaw/pkg/hooks" +) + +type testPlugin struct { + name string + apiVersion string + registerFn func(*hooks.HookRegistry) error +} + +func (p testPlugin) Name() string { + return p.name +} + +func (p testPlugin) Register(r *hooks.HookRegistry) error { + if p.registerFn != nil { + return p.registerFn(r) + } + return nil +} + +func (p testPlugin) APIVersion() string { + if p.apiVersion == "" { + return APIVersion + } + return p.apiVersion +} + +func TestNewManager(t *testing.T) { + m := NewManager() + if m == nil { + t.Fatal("expected manager") + } + if m.HookRegistry() == nil { + t.Fatal("expected non-nil hook registry") + } + if len(m.Names()) != 0 { + t.Fatalf("expected empty names, got %v", m.Names()) + } +} + +func TestRegisterPluginAndTriggerHook(t *testing.T) { + m := NewManager() + called := false + p := testPlugin{ + name: "audit", + registerFn: func(r *hooks.HookRegistry) error { + r.OnSessionStart("audit-session", 0, func(_ context.Context, _ *hooks.SessionEvent) error { + called = true + return nil + }) + return nil + }, + } + + if err := m.Register(p); err != nil { + t.Fatalf("Register() error = %v", err) + } + if got := m.Names(); len(got) != 1 || got[0] != "audit" { + t.Fatalf("unexpected names: %v", got) + } + + m.HookRegistry().TriggerSessionStart(context.Background(), &hooks.SessionEvent{ + AgentID: "a1", + SessionKey: "s1", + }) + if !called { + t.Fatal("expected plugin hook to be called") + } +} + +func TestRegisterRejectsNilPlugin(t *testing.T) { + m := NewManager() + if err := m.Register(nil); err == nil { + t.Fatal("expected error for nil plugin") + } +} + +func TestRegisterRejectsEmptyName(t *testing.T) { + m := NewManager() + if err := m.Register(testPlugin{}); err == nil { + t.Fatal("expected error for empty name") + } +} + +func TestRegisterRejectsDuplicateName(t *testing.T) { + m := NewManager() + p := testPlugin{name: "dup"} + if err := m.Register(p); err != nil { + t.Fatalf("unexpected first register error: %v", err) + } + if err := m.Register(p); err == nil { + t.Fatal("expected duplicate name error") + } +} + +func TestRegisterPropagatesPluginError(t *testing.T) { + m := NewManager() + want := errors.New("register failed") + p := testPlugin{ + name: "bad", + registerFn: func(_ *hooks.HookRegistry) error { + return want + }, + } + err := m.Register(p) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, want) { + t.Fatalf("expected wrapped error %v, got %v", want, err) + } +} + +func TestRegisterRejectsPluginVersionMismatch(t *testing.T) { + m := NewManager() + p := testPlugin{ + name: "old-plugin", + apiVersion: "v0", + } + err := m.Register(p) + if err == nil { + t.Fatal("expected version mismatch error") + } +} diff --git a/pkg/tools/message.go b/pkg/tools/message.go index 15ef4ff73..6391ee8ab 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -5,10 +5,10 @@ import ( "fmt" ) -type SendCallback func(channel, chatID, content string) error +type SendCallbackWithContext func(ctx context.Context, channel, chatID, content string) error type MessageTool struct { - sendCallback SendCallback + sendCallback SendCallbackWithContext defaultChannel string defaultChatID string sentInRound bool // Tracks whether a message was sent in the current processing round @@ -58,7 +58,7 @@ func (t *MessageTool) HasSentInRound() bool { return t.sentInRound } -func (t *MessageTool) SetSendCallback(callback SendCallback) { +func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) { t.sendCallback = callback } @@ -86,7 +86,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes return &ToolResult{ForLLM: "Message sending not configured", IsError: true} } - if err := t.sendCallback(channel, chatID, content); err != nil { + if err := t.sendCallback(ctx, channel, chatID, content); err != nil { return &ToolResult{ ForLLM: fmt.Sprintf("sending message: %v", err), IsError: true, diff --git a/pkg/tools/message_test.go b/pkg/tools/message_test.go index 717c1117b..a111e7a83 100644 --- a/pkg/tools/message_test.go +++ b/pkg/tools/message_test.go @@ -11,7 +11,7 @@ func TestMessageTool_Execute_Success(t *testing.T) { tool.SetContext("test-channel", "test-chat-id") var sentChannel, sentChatID, sentContent string - tool.SetSendCallback(func(channel, chatID, content string) error { + tool.SetSendCallback(func(_ context.Context, channel, chatID, content string) error { sentChannel = channel sentChatID = chatID sentContent = content @@ -63,7 +63,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { tool.SetContext("default-channel", "default-chat-id") var sentChannel, sentChatID string - tool.SetSendCallback(func(channel, chatID, content string) error { + tool.SetSendCallback(func(_ context.Context, channel, chatID, content string) error { sentChannel = channel sentChatID = chatID return nil @@ -94,12 +94,38 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { } } +func TestMessageTool_Execute_PropagatesContext(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("test-channel", "test-chat-id") + + type keyType string + const key keyType = "k" + ctx := context.WithValue(context.Background(), key, "v") + + seen := "" + tool.SetSendCallback(func(cbCtx context.Context, channel, chatID, content string) error { + val, _ := cbCtx.Value(key).(string) + seen = val + return nil + }) + + result := tool.Execute(ctx, map[string]any{ + "content": "context test", + }) + if result.IsError { + t.Fatalf("unexpected error: %v", result.ForLLM) + } + if seen != "v" { + t.Fatalf("expected propagated context value 'v', got %q", seen) + } +} + func TestMessageTool_Execute_SendFailure(t *testing.T) { tool := NewMessageTool() tool.SetContext("test-channel", "test-chat-id") sendErr := errors.New("network error") - tool.SetSendCallback(func(channel, chatID, content string) error { + tool.SetSendCallback(func(_ context.Context, channel, chatID, content string) error { return sendErr }) @@ -153,7 +179,7 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) { tool := NewMessageTool() // No SetContext called, so defaultChannel and defaultChatID are empty - tool.SetSendCallback(func(channel, chatID, content string) error { + tool.SetSendCallback(func(_ context.Context, channel, chatID, content string) error { return nil })