From 2b9e74007fc3b39a1fc5fd617deca2fc7b25dd38 Mon Sep 17 00:00:00 2001 From: lxowalle Date: Fri, 27 Feb 2026 03:27:26 +0800 Subject: [PATCH 1/2] * refactor tools system --- cmd/picoclaw/internal/gateway/helpers.go | 20 +- pkg/agent/instance.go | 8 +- pkg/agent/loop.go | 79 +-- pkg/config/config.go | 45 +- pkg/config/defaults.go | 43 +- pkg/tools/append_file/append_file.go | 77 +++ pkg/tools/append_file/append_file_test.go | 103 ++++ pkg/tools/{ => common}/filesystem.go | 204 +------- pkg/tools/{ => common}/result.go | 2 +- pkg/tools/{base.go => common/types.go} | 17 +- pkg/tools/common/web.go | 64 +++ pkg/tools/{ => cron}/cron.go | 48 +- pkg/tools/{edit.go => edit_file/edit_file.go} | 97 +--- .../edit_file_test.go} | 129 +---- pkg/tools/{shell.go => exec/exec.go} | 21 +- .../exec_process_unix.go} | 2 +- .../exec_process_windows.go} | 2 +- .../{shell_test.go => exec/exec_test.go} | 2 +- .../exec_timeout_unix_test.go} | 2 +- pkg/tools/filesystem_test.go | 335 +----------- .../find_skills.go} | 13 +- .../find_skills_test.go} | 2 +- pkg/tools/{ => i2c}/i2c.go | 32 +- pkg/tools/{ => i2c}/i2c_linux.go | 54 +- pkg/tools/{ => i2c}/i2c_other.go | 2 +- .../install_skill.go} | 21 +- .../install_skill_test.go} | 2 +- pkg/tools/list_dir/list_dir.go | 70 +++ pkg/tools/list_dir/list_dir_test.go | 73 +++ pkg/tools/{ => message}/message.go | 16 +- pkg/tools/{ => message}/message_test.go | 6 +- pkg/tools/read_file/read_file.go | 55 ++ pkg/tools/read_file/read_file_test.go | 136 +++++ pkg/tools/registry.go | 119 ++++- pkg/tools/registry_test.go | 30 +- pkg/tools/{ => spi}/spi.go | 20 +- pkg/tools/{ => spi}/spi_linux.go | 42 +- pkg/tools/spi/spi_other.go | 15 + pkg/tools/spi_other.go | 13 - pkg/tools/{ => subagent}/spawn.go | 20 +- pkg/tools/{ => subagent}/spawn_test.go | 2 +- pkg/tools/{ => subagent}/subagent.go | 48 +- .../{ => subagent}/subagent_tool_test.go | 2 +- pkg/tools/types.go | 58 --- pkg/tools/types_export.go | 38 ++ pkg/tools/web_fetch/web_fetch.go | 194 +++++++ .../web_fetch_test.go} | 172 +----- .../{web.go => web_search/web_search.go} | 265 +--------- pkg/tools/web_search/web_search_test.go | 159 ++++++ pkg/tools/write_file/write_file.go | 65 +++ pkg/tools/write_file/write_file_test.go | 491 ++++++++++++++++++ 51 files changed, 2056 insertions(+), 1479 deletions(-) create mode 100644 pkg/tools/append_file/append_file.go create mode 100644 pkg/tools/append_file/append_file_test.go rename pkg/tools/{ => common}/filesystem.go (58%) rename pkg/tools/{ => common}/result.go (99%) rename pkg/tools/{base.go => common/types.go} (89%) create mode 100644 pkg/tools/common/web.go rename pkg/tools/{ => cron}/cron.go (85%) rename pkg/tools/{edit.go => edit_file/edit_file.go} (52%) rename pkg/tools/{edit_test.go => edit_file/edit_file_test.go} (69%) rename pkg/tools/{shell.go => exec/exec.go} (94%) rename pkg/tools/{shell_process_unix.go => exec/exec_process_unix.go} (97%) rename pkg/tools/{shell_process_windows.go => exec/exec_process_windows.go} (96%) rename pkg/tools/{shell_test.go => exec/exec_test.go} (99%) rename pkg/tools/{shell_timeout_unix_test.go => exec/exec_timeout_unix_test.go} (99%) rename pkg/tools/{skills_search.go => find_skills/find_skills.go} (88%) rename pkg/tools/{skills_search_test.go => find_skills/find_skills_test.go} (99%) rename pkg/tools/{ => i2c}/i2c.go (78%) rename pkg/tools/{ => i2c}/i2c_linux.go (76%) rename pkg/tools/{ => i2c}/i2c_other.go (97%) rename pkg/tools/{skills_install.go => install_skill/install_skill.go} (89%) rename pkg/tools/{skills_install_test.go => install_skill/install_skill_test.go} (99%) create mode 100644 pkg/tools/list_dir/list_dir.go create mode 100644 pkg/tools/list_dir/list_dir_test.go rename pkg/tools/{ => message}/message.go (85%) rename pkg/tools/{ => message}/message_test.go (98%) create mode 100644 pkg/tools/read_file/read_file.go create mode 100644 pkg/tools/read_file/read_file_test.go rename pkg/tools/{ => spi}/spi.go (87%) rename pkg/tools/{ => spi}/spi_linux.go (76%) create mode 100644 pkg/tools/spi/spi_other.go delete mode 100644 pkg/tools/spi_other.go rename pkg/tools/{ => subagent}/spawn.go (79%) rename pkg/tools/{ => subagent}/spawn_test.go (99%) rename pkg/tools/{ => subagent}/subagent.go (87%) rename pkg/tools/{ => subagent}/subagent_tool_test.go (99%) delete mode 100644 pkg/tools/types.go create mode 100644 pkg/tools/types_export.go create mode 100644 pkg/tools/web_fetch/web_fetch.go rename pkg/tools/{web_test.go => web_fetch/web_fetch_test.go} (70%) rename pkg/tools/{web.go => web_search/web_search.go} (63%) create mode 100644 pkg/tools/web_search/web_search_test.go create mode 100644 pkg/tools/write_file/write_file.go create mode 100644 pkg/tools/write_file/write_file_test.go diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index a06625dc9..8734352d8 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -24,6 +24,7 @@ import ( "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" + cron_tool "github.com/sipeed/picoclaw/pkg/tools/cron" "github.com/sipeed/picoclaw/pkg/voice" ) @@ -232,14 +233,15 @@ func setupCronTool( cronService := cron.NewCronService(cronStorePath, nil) // Create and register CronTool - cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg) - agentLoop.RegisterTool(cronTool) - - // Set the onJob handler - cronService.SetOnJob(func(job *cron.CronJob) (string, error) { - result := cronTool.ExecuteJob(context.Background(), job) - return result, nil - }) - + if cfg.Tools.Cron.Enabled { + cronTool := cron_tool.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg) + agentLoop.RegisterTool(cronTool) + + // Set the onJob handler + cronService.SetOnJob(func(job *cron.CronJob) (string, error) { + result := cronTool.ExecuteJob(context.Background(), job) + return result, nil + }) + } return cronService } diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index a6fd365c7..1b20a2aac 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -47,13 +47,7 @@ func NewAgentInstance( fallbacks := resolveAgentFallbacks(agentCfg, defaults) restrict := defaults.RestrictToWorkspace - toolsRegistry := tools.NewToolRegistry() - toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict)) - toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict)) - toolsRegistry.Register(tools.NewListDirTool(workspace, restrict)) - toolsRegistry.Register(tools.NewExecToolWithConfig(workspace, restrict, cfg)) - toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict)) - toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict)) + toolsRegistry := tools.NewToolRegistry(cfg, workspace, restrict) sessionsDir := filepath.Join(workspace, "sessions") sessionsManager := session.NewSessionManager(sessionsDir) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 693f2227b..5ed27ad05 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -23,9 +23,10 @@ import ( "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" - "github.com/sipeed/picoclaw/pkg/skills" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" + "github.com/sipeed/picoclaw/pkg/tools/message" + "github.com/sipeed/picoclaw/pkg/tools/subagent" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -92,63 +93,31 @@ func registerSharedTools( continue } - // Web tools - if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{ - BraveAPIKey: cfg.Tools.Web.Brave.APIKey, - BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, - BraveEnabled: cfg.Tools.Web.Brave.Enabled, - TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey, - TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL, - TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults, - TavilyEnabled: cfg.Tools.Web.Tavily.Enabled, - DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, - DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, - PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, - PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, - PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, - Proxy: cfg.Tools.Web.Proxy, - }); searchTool != nil { - agent.Tools.Register(searchTool) - } - agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy)) - - // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms - agent.Tools.Register(tools.NewI2CTool()) - agent.Tools.Register(tools.NewSPITool()) - // Message tool - messageTool := tools.NewMessageTool() - messageTool.SetSendCallback(func(channel, chatID, content string) error { - msgBus.PublishOutbound(bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, - Content: content, + if cfg.Tools.Message.Enabled { + messageTool := message.NewMessageTool() + messageTool.SetSendCallback(func(channel, chatID, content string) error { + msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }) + return nil }) - return nil - }) - agent.Tools.Register(messageTool) - - // Skill discovery and installation tools - registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{ - MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches, - ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub), - }) - searchCache := skills.NewSearchCache( - cfg.Tools.Skills.SearchCache.MaxSize, - time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second, - ) - agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache)) - agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace)) + agent.Tools.Register(messageTool) + } // Spawn tool with allowlist checker - subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus) - subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) - spawnTool := tools.NewSpawnTool(subagentManager) - currentAgentID := agentID - spawnTool.SetAllowlistChecker(func(targetAgentID string) bool { - return registry.CanSpawnSubagent(currentAgentID, targetAgentID) - }) - agent.Tools.Register(spawnTool) + if cfg.Tools.Spawn.Enabled { + subagentManager := subagent.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus) + subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) + spawnTool := subagent.NewSpawnTool(subagentManager) + currentAgentID := agentID + spawnTool.SetAllowlistChecker(func(targetAgentID string) bool { + return registry.CanSpawnSubagent(currentAgentID, targetAgentID) + }) + agent.Tools.Register(spawnTool) + } } } @@ -178,7 +147,7 @@ func (al *AgentLoop) Run(ctx context.Context) error { defaultAgent := al.registry.GetDefaultAgent() if defaultAgent != nil { if tool, ok := defaultAgent.Tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { + if mt, ok := tool.(*message.MessageTool); ok { alreadySent = mt.HasSentInRound() } } diff --git a/pkg/config/config.go b/pkg/config/config.go index ca5803c35..79a4dcae2 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -452,6 +452,7 @@ type PerplexityConfig struct { } type WebToolsConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_ENABLED"` Brave BraveConfig `json:"brave"` Tavily TavilyConfig `json:"tavily"` DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"` @@ -461,19 +462,53 @@ type WebToolsConfig struct { Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"` } -type CronToolsConfig struct { - ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout +type CronToolConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_CRON_ENABLED"` + ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout +} + +type ToolConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_ENABLED"` // Default env var, can be overridden per tool } type ExecConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_EXEC_ENABLED"` EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"` CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"` } type ToolsConfig struct { - Web WebToolsConfig `json:"web"` - Cron CronToolsConfig `json:"cron"` - Exec ExecConfig `json:"exec"` + // Web tools + Web WebToolsConfig `json:"web"` + + // Cron tools + Cron CronToolConfig `json:"cron"` + + // File tools + ReadFile ToolConfig `json:"read_file" env:"PICOCLAW_TOOLS_READ_FILE_ENABLED"` + WriteFile ToolConfig `json:"write_file" env:"PICOCLAW_TOOLS_WRITE_FILE_ENABLED"` + EditFile ToolConfig `json:"edit_file" env:"PICOCLAW_TOOLS_EDIT_FILE_ENABLED"` + AppendFile ToolConfig `json:"append_file" env:"PICOCLAW_TOOLS_APPEND_FILE_ENABLED"` + ListDir ToolConfig `json:"list_dir" env:"PICOCLAW_TOOLS_LIST_DIR_ENABLED"` + + // Exec tool + Exec ExecConfig `json:"exec"` + + // Skills tools + FindSkills ToolConfig `json:"find_skills" env:"PICOCLAW_TOOLS_FIND_SKILLS_ENABLED"` + InstallSkill ToolConfig `json:"install_skill" env:"PICOCLAW_TOOLS_INSTALL_SKILL_ENABLED"` + + // Subagent tools + Spawn ToolConfig `json:"spawn" env:"PICOCLAW_TOOLS_SPAWN_ENABLED"` + + // Message tool + Message ToolConfig `json:"message" env:"PICOCLAW_TOOLS_MESSAGE_ENABLED"` + + // Hardware tools + I2C ToolConfig `json:"i2c" env:"PICOCLAW_TOOLS_I2C_ENABLED"` + SPI ToolConfig `json:"spi" env:"PICOCLAW_TOOLS_SPI_ENABLED"` + + // Skills configuration (registry, cache, etc.) Skills SkillsToolsConfig `json:"skills"` } diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index cf799140d..60431608b 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -293,12 +293,53 @@ func DefaultConfig() *Config { MaxResults: 5, }, }, - Cron: CronToolsConfig{ + Cron: CronToolConfig{ + Enabled: true, ExecTimeoutMinutes: 5, }, + // File tools - each individually configurable + ReadFile: ToolConfig{ + Enabled: true, + }, + WriteFile: ToolConfig{ + Enabled: true, + }, + EditFile: ToolConfig{ + Enabled: false, + }, + AppendFile: ToolConfig{ + Enabled: false, + }, + ListDir: ToolConfig{ + Enabled: false, + }, + // Exec tool Exec: ExecConfig{ + Enabled: true, EnableDenyPatterns: true, }, + // Skills tools + FindSkills: ToolConfig{ + Enabled: true, + }, + InstallSkill: ToolConfig{ + Enabled: true, + }, + // Subagent tools + Spawn: ToolConfig{ + Enabled: true, + }, + // Message tool + Message: ToolConfig{ + Enabled: true, + }, + // Hardware tools + I2C: ToolConfig{ + Enabled: false, + }, + SPI: ToolConfig{ + Enabled: false, + }, Skills: SkillsToolsConfig{ Registries: SkillsRegistriesConfig{ ClawHub: ClawHubRegistryConfig{ diff --git a/pkg/tools/append_file/append_file.go b/pkg/tools/append_file/append_file.go new file mode 100644 index 000000000..855869b8d --- /dev/null +++ b/pkg/tools/append_file/append_file.go @@ -0,0 +1,77 @@ +package append_file + +import ( + "context" + "errors" + "fmt" + "io/fs" + + "github.com/sipeed/picoclaw/pkg/tools/common" +) + +type AppendFileTool struct { + fs common.FileSystem +} + +func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool { + var fs common.FileSystem + if restrict { + fs = &common.SandboxFs{Workspace: workspace} + } else { + fs = &common.HostFs{} + } + return &AppendFileTool{fs: fs} +} + +func (t *AppendFileTool) Name() string { + return "append_file" +} + +func (t *AppendFileTool) Description() string { + return "Append content to the end of a file" +} + +func (t *AppendFileTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "The file path to append to", + }, + "content": map[string]any{ + "type": "string", + "description": "The content to append", + }, + }, + "required": []string{"path", "content"}, + } +} + +func (t *AppendFileTool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { + path, ok := args["path"].(string) + if !ok { + return common.ErrorResult("path is required") + } + + content, ok := args["content"].(string) + if !ok { + return common.ErrorResult("content is required") + } + + if err := appendFile(t.fs, path, content); err != nil { + return common.ErrorResult(err.Error()) + } + return common.SilentResult(fmt.Sprintf("Appended to %s", path)) +} + +// appendFile reads the existing content (if any) via sysFs, appends new content, and writes back. +func appendFile(sysFs common.FileSystem, path, appendContent string) error { + content, err := sysFs.ReadFile(path) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return err + } + + newContent := append(content, []byte(appendContent)...) + return sysFs.WriteFile(path, newContent) +} diff --git a/pkg/tools/append_file/append_file_test.go b/pkg/tools/append_file/append_file_test.go new file mode 100644 index 000000000..4e10aaede --- /dev/null +++ b/pkg/tools/append_file/append_file_test.go @@ -0,0 +1,103 @@ +package append_file + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestAppendFileTool_AppendToExisting verifies appending to an existing file +func TestAppendFileTool_AppendToExisting(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("Hello World"), 0o644) + + tool := NewAppendFileTool(tmpDir, true) + ctx := context.Background() + args := map[string]any{ + "path": testFile, + "content": "\nAppended text", + } + + result := tool.Execute(ctx, args) + + assert.False(t, result.IsError, "Expected success, got error: %s", result.ForLLM) + assert.True(t, result.Silent, "Expected Silent=true for AppendFile") + + content, err := os.ReadFile(testFile) + assert.NoError(t, err) + assert.Contains(t, string(content), "Appended text") + assert.Contains(t, string(content), "Hello World") +} + +// TestAppendFileTool_AppendToNonExistent verifies appending to a non-existent file creates it +func TestAppendFileTool_AppendToNonExistent(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "newfile.txt") + + tool := NewAppendFileTool(tmpDir, true) + ctx := context.Background() + args := map[string]any{ + "path": testFile, + "content": "First content", + } + + result := tool.Execute(ctx, args) + + assert.False(t, result.IsError, "Expected success, got error: %s", result.ForLLM) + + content, err := os.ReadFile(testFile) + assert.NoError(t, err) + assert.Equal(t, "First content", string(content)) +} + +// TestAppendFileTool_MissingPath verifies error handling for missing path +func TestAppendFileTool_MissingPath(t *testing.T) { + tool := NewAppendFileTool("", false) + ctx := context.Background() + args := map[string]any{ + "content": "Some content", + } + + result := tool.Execute(ctx, args) + + assert.True(t, result.IsError) + assert.Contains(t, result.ForLLM, "path is required") +} + +// TestAppendFileTool_MissingContent verifies error handling for missing content +func TestAppendFileTool_MissingContent(t *testing.T) { + tool := NewAppendFileTool("", false) + ctx := context.Background() + args := map[string]any{ + "path": "/tmp/test.txt", + } + + result := tool.Execute(ctx, args) + + assert.True(t, result.IsError) + assert.Contains(t, result.ForLLM, "content is required") +} + +// TestAppendFileTool_RestrictedMode verifies access control +func TestAppendFileTool_RestrictedMode(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("Original"), 0o644) + + tool := NewAppendFileTool(tmpDir, true) + ctx := context.Background() + + // Try to append to a file outside the workspace + args := map[string]any{ + "path": "/etc/passwd", + "content": "Malicious content", + } + + result := tool.Execute(ctx, args) + + assert.True(t, result.IsError) +} diff --git a/pkg/tools/filesystem.go b/pkg/tools/common/filesystem.go similarity index 58% rename from pkg/tools/filesystem.go rename to pkg/tools/common/filesystem.go index 03d461dcc..7590db276 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/common/filesystem.go @@ -1,7 +1,6 @@ -package tools +package common import ( - "context" "fmt" "io/fs" "os" @@ -13,7 +12,7 @@ import ( ) // validatePath ensures the given path is within the workspace if restrict is true. -func validatePath(path, workspace string, restrict bool) (string, error) { +func ValidatePath(path, workspace string, restrict bool) (string, error) { if workspace == "" { return path, fmt.Errorf("workspace is not defined") } @@ -83,183 +82,18 @@ func isWithinWorkspace(candidate, workspace string) bool { return err == nil && filepath.IsLocal(rel) } -type ReadFileTool struct { - fs fileSystem -} - -func NewReadFileTool(workspace string, restrict bool) *ReadFileTool { - var fs fileSystem - if restrict { - fs = &sandboxFs{workspace: workspace} - } else { - fs = &hostFs{} - } - return &ReadFileTool{fs: fs} -} - -func (t *ReadFileTool) Name() string { - return "read_file" -} - -func (t *ReadFileTool) Description() string { - return "Read the contents of a file" -} - -func (t *ReadFileTool) Parameters() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{ - "path": map[string]any{ - "type": "string", - "description": "Path to the file to read", - }, - }, - "required": []string{"path"}, - } -} - -func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult { - path, ok := args["path"].(string) - if !ok { - return ErrorResult("path is required") - } - - content, err := t.fs.ReadFile(path) - if err != nil { - return ErrorResult(err.Error()) - } - return NewToolResult(string(content)) -} - -type WriteFileTool struct { - fs fileSystem -} - -func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool { - var fs fileSystem - if restrict { - fs = &sandboxFs{workspace: workspace} - } else { - fs = &hostFs{} - } - return &WriteFileTool{fs: fs} -} - -func (t *WriteFileTool) Name() string { - return "write_file" -} - -func (t *WriteFileTool) Description() string { - return "Write content to a file" -} - -func (t *WriteFileTool) Parameters() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{ - "path": map[string]any{ - "type": "string", - "description": "Path to the file to write", - }, - "content": map[string]any{ - "type": "string", - "description": "Content to write to the file", - }, - }, - "required": []string{"path", "content"}, - } -} - -func (t *WriteFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult { - path, ok := args["path"].(string) - if !ok { - return ErrorResult("path is required") - } - - content, ok := args["content"].(string) - if !ok { - return ErrorResult("content is required") - } - - if err := t.fs.WriteFile(path, []byte(content)); err != nil { - return ErrorResult(err.Error()) - } - - return SilentResult(fmt.Sprintf("File written: %s", path)) -} - -type ListDirTool struct { - fs fileSystem -} - -func NewListDirTool(workspace string, restrict bool) *ListDirTool { - var fs fileSystem - if restrict { - fs = &sandboxFs{workspace: workspace} - } else { - fs = &hostFs{} - } - return &ListDirTool{fs: fs} -} - -func (t *ListDirTool) Name() string { - return "list_dir" -} - -func (t *ListDirTool) Description() string { - return "List files and directories in a path" -} - -func (t *ListDirTool) Parameters() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{ - "path": map[string]any{ - "type": "string", - "description": "Path to list", - }, - }, - "required": []string{"path"}, - } -} - -func (t *ListDirTool) Execute(ctx context.Context, args map[string]any) *ToolResult { - path, ok := args["path"].(string) - if !ok { - path = "." - } - - entries, err := t.fs.ReadDir(path) - if err != nil { - return ErrorResult(fmt.Sprintf("failed to read directory: %v", err)) - } - return formatDirEntries(entries) -} - -func formatDirEntries(entries []os.DirEntry) *ToolResult { - var result strings.Builder - for _, entry := range entries { - if entry.IsDir() { - result.WriteString("DIR: " + entry.Name() + "\n") - } else { - result.WriteString("FILE: " + entry.Name() + "\n") - } - } - return NewToolResult(result.String()) -} - -// fileSystem abstracts reading, writing, and listing files, allowing both +// FileSystem abstracts reading, writing, and listing files, allowing both // unrestricted (host filesystem) and sandbox (os.Root) implementations to share the same polymorphic interface. -type fileSystem interface { +type FileSystem interface { ReadFile(path string) ([]byte, error) WriteFile(path string, data []byte) error ReadDir(path string) ([]os.DirEntry, error) } -// hostFs is an unrestricted fileReadWriter that operates directly on the host filesystem. -type hostFs struct{} +// HostFs is an unrestricted fileReadWriter that operates directly on the host filesystem. +type HostFs struct{} -func (h *hostFs) ReadFile(path string) ([]byte, error) { +func (h *HostFs) ReadFile(path string) ([]byte, error) { content, err := os.ReadFile(path) if err != nil { if os.IsNotExist(err) { @@ -273,33 +107,33 @@ func (h *hostFs) ReadFile(path string) ([]byte, error) { return content, nil } -func (h *hostFs) ReadDir(path string) ([]os.DirEntry, error) { +func (h *HostFs) ReadDir(path string) ([]os.DirEntry, error) { return os.ReadDir(path) } -func (h *hostFs) WriteFile(path string, data []byte) error { +func (h *HostFs) WriteFile(path string, data []byte) error { // Use unified atomic write utility with explicit sync for flash storage reliability. // Using 0o600 (owner read/write only) for secure default permissions. return fileutil.WriteFileAtomic(path, data, 0o600) } -// sandboxFs is a sandboxed fileSystem that operates within a strictly defined workspace using os.Root. -type sandboxFs struct { - workspace string +// SandboxFs is a sandboxed FileSystem that operates within a strictly defined workspace using os.Root. +type SandboxFs struct { + Workspace string } -func (r *sandboxFs) execute(path string, fn func(root *os.Root, relPath string) error) error { - if r.workspace == "" { +func (r *SandboxFs) execute(path string, fn func(root *os.Root, relPath string) error) error { + if r.Workspace == "" { return fmt.Errorf("workspace is not defined") } - root, err := os.OpenRoot(r.workspace) + root, err := os.OpenRoot(r.Workspace) if err != nil { return fmt.Errorf("failed to open workspace: %w", err) } defer root.Close() - relPath, err := getSafeRelPath(r.workspace, path) + relPath, err := getSafeRelPath(r.Workspace, path) if err != nil { return err } @@ -307,7 +141,7 @@ func (r *sandboxFs) execute(path string, fn func(root *os.Root, relPath string) return fn(root, relPath) } -func (r *sandboxFs) ReadFile(path string) ([]byte, error) { +func (r *SandboxFs) ReadFile(path string) ([]byte, error) { var content []byte err := r.execute(path, func(root *os.Root, relPath string) error { fileContent, err := root.ReadFile(relPath) @@ -328,7 +162,7 @@ func (r *sandboxFs) ReadFile(path string) ([]byte, error) { return content, err } -func (r *sandboxFs) WriteFile(path string, data []byte) error { +func (r *SandboxFs) WriteFile(path string, data []byte) error { return r.execute(path, func(root *os.Root, relPath string) error { dir := filepath.Dir(relPath) if dir != "." && dir != "/" { @@ -381,7 +215,7 @@ func (r *sandboxFs) WriteFile(path string, data []byte) error { }) } -func (r *sandboxFs) ReadDir(path string) ([]os.DirEntry, error) { +func (r *SandboxFs) ReadDir(path string) ([]os.DirEntry, error) { var entries []os.DirEntry err := r.execute(path, func(root *os.Root, relPath string) error { dirEntries, err := fs.ReadDir(root.FS(), relPath) diff --git a/pkg/tools/result.go b/pkg/tools/common/result.go similarity index 99% rename from pkg/tools/result.go rename to pkg/tools/common/result.go index b13055b1c..71d2ff122 100644 --- a/pkg/tools/result.go +++ b/pkg/tools/common/result.go @@ -1,4 +1,4 @@ -package tools +package common import "encoding/json" diff --git a/pkg/tools/base.go b/pkg/tools/common/types.go similarity index 89% rename from pkg/tools/base.go rename to pkg/tools/common/types.go index 770d8cb04..5901f0cb7 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/common/types.go @@ -1,6 +1,8 @@ -package tools +package common -import "context" +import ( + "context" +) // Tool is the interface that all tools must implement. type Tool interface { @@ -68,14 +70,3 @@ type AsyncTool interface { // The callback will be called from a goroutine and should handle thread-safety if needed. SetCallback(cb AsyncCallback) } - -func ToolToSchema(tool Tool) map[string]any { - return map[string]any{ - "type": "function", - "function": map[string]any{ - "name": tool.Name(), - "description": tool.Description(), - "parameters": tool.Parameters(), - }, - } -} diff --git a/pkg/tools/common/web.go b/pkg/tools/common/web.go new file mode 100644 index 000000000..d930d5599 --- /dev/null +++ b/pkg/tools/common/web.go @@ -0,0 +1,64 @@ +package common + +import ( + "fmt" + "net/http" + "net/url" + "regexp" + "strings" + "time" +) + +const ( + UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" +) + +// Pre-compiled regexes for HTML text extraction +var ( + ReScript = regexp.MustCompile(``) + ReStyle = regexp.MustCompile(``) + ReTags = regexp.MustCompile(`<[^>]+>`) + ReWhitespace = regexp.MustCompile(`[^\S\n]+`) + ReBlankLines = regexp.MustCompile(`\n{3,}`) + + // DuckDuckGo result extraction + ReDDGLink = regexp.MustCompile(`]*class="[^"]*result__a[^"]*"[^>]*href="([^"]+)"[^>]*>([\s\S]*?)`) + ReDDGSnippet = regexp.MustCompile(`([\s\S]*?)`) +) + +// createHTTPClient creates an HTTP client with optional proxy support +func CreateHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) { + client := &http.Client{ + Timeout: timeout, + Transport: &http.Transport{ + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + DisableCompression: false, + TLSHandshakeTimeout: 15 * time.Second, + }, + } + + if proxyURL != "" { + proxy, err := url.Parse(proxyURL) + if err != nil { + return nil, fmt.Errorf("invalid proxy URL: %w", err) + } + scheme := strings.ToLower(proxy.Scheme) + switch scheme { + case "http", "https", "socks5", "socks5h": + default: + return nil, fmt.Errorf( + "unsupported proxy scheme %q (supported: http, https, socks5, socks5h)", + proxy.Scheme, + ) + } + if proxy.Host == "" { + return nil, fmt.Errorf("invalid proxy URL: missing host") + } + client.Transport.(*http.Transport).Proxy = http.ProxyURL(proxy) + } else { + client.Transport.(*http.Transport).Proxy = http.ProxyFromEnvironment + } + + return client, nil +} diff --git a/pkg/tools/cron.go b/pkg/tools/cron/cron.go similarity index 85% rename from pkg/tools/cron.go rename to pkg/tools/cron/cron.go index 562fffc84..6abf79f85 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron/cron.go @@ -1,4 +1,4 @@ -package tools +package cron import ( "context" @@ -9,6 +9,8 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/cron" + "github.com/sipeed/picoclaw/pkg/tools/common" + "github.com/sipeed/picoclaw/pkg/tools/exec" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -22,7 +24,7 @@ type CronTool struct { cronService *cron.CronService executor JobExecutor msgBus *bus.MessageBus - execTool *ExecTool + execTool *exec.ExecTool channel string chatID string mu sync.RWMutex @@ -34,7 +36,7 @@ func NewCronTool( cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration, config *config.Config, ) *CronTool { - execTool := NewExecToolWithConfig(workspace, restrict, config) + execTool := exec.NewExecToolWithConfig(workspace, restrict, config) execTool.SetTimeout(execTimeout) return &CronTool{ cronService: cronService, @@ -106,10 +108,10 @@ func (t *CronTool) SetContext(channel, chatID string) { } // Execute runs the tool with the given arguments -func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult { +func (t *CronTool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { action, ok := args["action"].(string) if !ok { - return ErrorResult("action is required") + return common.ErrorResult("action is required") } switch action { @@ -124,23 +126,23 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult case "disable": return t.enableJob(args, false) default: - return ErrorResult(fmt.Sprintf("unknown action: %s", action)) + return common.ErrorResult(fmt.Sprintf("unknown action: %s", action)) } } -func (t *CronTool) addJob(args map[string]any) *ToolResult { +func (t *CronTool) addJob(args map[string]any) *common.ToolResult { t.mu.RLock() channel := t.channel chatID := t.chatID t.mu.RUnlock() if channel == "" || chatID == "" { - return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.") + return common.ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.") } message, ok := args["message"].(string) if !ok || message == "" { - return ErrorResult("message is required for add") + return common.ErrorResult("message is required for add") } var schedule cron.CronSchedule @@ -169,7 +171,7 @@ func (t *CronTool) addJob(args map[string]any) *ToolResult { Expr: cronExpr, } } else { - return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required") + return common.ErrorResult("one of at_seconds, every_seconds, or cron_expr is required") } // Read deliver parameter, default to true @@ -199,7 +201,7 @@ func (t *CronTool) addJob(args map[string]any) *ToolResult { chatID, ) if err != nil { - return ErrorResult(fmt.Sprintf("Error adding job: %v", err)) + return common.ErrorResult(fmt.Sprintf("Error adding job: %v", err)) } if command != "" { @@ -208,14 +210,14 @@ func (t *CronTool) addJob(args map[string]any) *ToolResult { t.cronService.UpdateJob(job) } - return SilentResult(fmt.Sprintf("Cron job added: %s (id: %s)", job.Name, job.ID)) + return common.SilentResult(fmt.Sprintf("Cron job added: %s (id: %s)", job.Name, job.ID)) } -func (t *CronTool) listJobs() *ToolResult { +func (t *CronTool) listJobs() *common.ToolResult { jobs := t.cronService.ListJobs(false) if len(jobs) == 0 { - return SilentResult("No scheduled jobs") + return common.SilentResult("No scheduled jobs") } result := "Scheduled jobs:\n" @@ -233,37 +235,37 @@ func (t *CronTool) listJobs() *ToolResult { result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo) } - return SilentResult(result) + return common.SilentResult(result) } -func (t *CronTool) removeJob(args map[string]any) *ToolResult { +func (t *CronTool) removeJob(args map[string]any) *common.ToolResult { jobID, ok := args["job_id"].(string) if !ok || jobID == "" { - return ErrorResult("job_id is required for remove") + return common.ErrorResult("job_id is required for remove") } if t.cronService.RemoveJob(jobID) { - return SilentResult(fmt.Sprintf("Cron job removed: %s", jobID)) + return common.SilentResult(fmt.Sprintf("Cron job removed: %s", jobID)) } - return ErrorResult(fmt.Sprintf("Job %s not found", jobID)) + return common.ErrorResult(fmt.Sprintf("Job %s not found", jobID)) } -func (t *CronTool) enableJob(args map[string]any, enable bool) *ToolResult { +func (t *CronTool) enableJob(args map[string]any, enable bool) *common.ToolResult { jobID, ok := args["job_id"].(string) if !ok || jobID == "" { - return ErrorResult("job_id is required for enable/disable") + return common.ErrorResult("job_id is required for enable/disable") } job := t.cronService.EnableJob(jobID, enable) if job == nil { - return ErrorResult(fmt.Sprintf("Job %s not found", jobID)) + return common.ErrorResult(fmt.Sprintf("Job %s not found", jobID)) } status := "enabled" if !enable { status = "disabled" } - return SilentResult(fmt.Sprintf("Cron job '%s' %s", job.Name, status)) + return common.SilentResult(fmt.Sprintf("Cron job '%s' %s", job.Name, status)) } // ExecuteJob executes a cron job through the agent diff --git a/pkg/tools/edit.go b/pkg/tools/edit_file/edit_file.go similarity index 52% rename from pkg/tools/edit.go rename to pkg/tools/edit_file/edit_file.go index d3ab267bf..18e2f9cba 100644 --- a/pkg/tools/edit.go +++ b/pkg/tools/edit_file/edit_file.go @@ -1,26 +1,26 @@ -package tools +package edit_file import ( "context" - "errors" "fmt" - "io/fs" "strings" + + "github.com/sipeed/picoclaw/pkg/tools/common" ) // EditFileTool edits a file by replacing old_text with new_text. // The old_text must exist exactly in the file. type EditFileTool struct { - fs fileSystem + fs common.FileSystem } // NewEditFileTool creates a new EditFileTool with optional directory restriction. func NewEditFileTool(workspace string, restrict bool) *EditFileTool { - var fs fileSystem + var fs common.FileSystem if restrict { - fs = &sandboxFs{workspace: workspace} + fs = &common.SandboxFs{Workspace: workspace} } else { - fs = &hostFs{} + fs = &common.HostFs{} } return &EditFileTool{fs: fs} } @@ -54,87 +54,31 @@ func (t *EditFileTool) Parameters() map[string]any { } } -func (t *EditFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult { +func (t *EditFileTool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { path, ok := args["path"].(string) if !ok { - return ErrorResult("path is required") + return common.ErrorResult("path is required") } oldText, ok := args["old_text"].(string) if !ok { - return ErrorResult("old_text is required") + return common.ErrorResult("old_text is required") } newText, ok := args["new_text"].(string) if !ok { - return ErrorResult("new_text is required") + return common.ErrorResult("new_text is required") } if err := editFile(t.fs, path, oldText, newText); err != nil { - return ErrorResult(err.Error()) + return common.ErrorResult(err.Error()) } - return SilentResult(fmt.Sprintf("File edited: %s", path)) -} - -type AppendFileTool struct { - fs fileSystem -} - -func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool { - var fs fileSystem - if restrict { - fs = &sandboxFs{workspace: workspace} - } else { - fs = &hostFs{} - } - return &AppendFileTool{fs: fs} -} - -func (t *AppendFileTool) Name() string { - return "append_file" -} - -func (t *AppendFileTool) Description() string { - return "Append content to the end of a file" -} - -func (t *AppendFileTool) Parameters() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{ - "path": map[string]any{ - "type": "string", - "description": "The file path to append to", - }, - "content": map[string]any{ - "type": "string", - "description": "The content to append", - }, - }, - "required": []string{"path", "content"}, - } -} - -func (t *AppendFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult { - path, ok := args["path"].(string) - if !ok { - return ErrorResult("path is required") - } - - content, ok := args["content"].(string) - if !ok { - return ErrorResult("content is required") - } - - if err := appendFile(t.fs, path, content); err != nil { - return ErrorResult(err.Error()) - } - return SilentResult(fmt.Sprintf("Appended to %s", path)) + return common.SilentResult(fmt.Sprintf("File edited: %s", path)) } // editFile reads the file via sysFs, performs the replacement, and writes back. -// It uses a fileSystem interface, allowing the same logic for both restricted and unrestricted modes. -func editFile(sysFs fileSystem, path, oldText, newText string) error { +// It uses a common.FileSystem interface, allowing the same logic for both restricted and unrestricted modes. +func editFile(sysFs common.FileSystem, path, oldText, newText string) error { content, err := sysFs.ReadFile(path) if err != nil { return err @@ -148,17 +92,6 @@ func editFile(sysFs fileSystem, path, oldText, newText string) error { return sysFs.WriteFile(path, newContent) } -// appendFile reads the existing content (if any) via sysFs, appends new content, and writes back. -func appendFile(sysFs fileSystem, path, appendContent string) error { - content, err := sysFs.ReadFile(path) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return err - } - - newContent := append(content, []byte(appendContent)...) - return sysFs.WriteFile(path, newContent) -} - // replaceEditContent handles the core logic of finding and replacing a single occurrence of oldText. func replaceEditContent(content []byte, oldText, newText string) ([]byte, error) { contentStr := string(content) diff --git a/pkg/tools/edit_test.go b/pkg/tools/edit_file/edit_file_test.go similarity index 69% rename from pkg/tools/edit_test.go rename to pkg/tools/edit_file/edit_file_test.go index 83a7e778c..2bde2f3a5 100644 --- a/pkg/tools/edit_test.go +++ b/pkg/tools/edit_file/edit_file_test.go @@ -1,4 +1,4 @@ -package tools +package edit_file import ( "context" @@ -218,82 +218,6 @@ func TestEditTool_EditFile_MissingNewText(t *testing.T) { } } -// TestEditTool_AppendFile_Success verifies successful file appending -func TestEditTool_AppendFile_Success(t *testing.T) { - tmpDir := t.TempDir() - testFile := filepath.Join(tmpDir, "test.txt") - os.WriteFile(testFile, []byte("Initial content"), 0o644) - - tool := NewAppendFileTool("", false) - ctx := context.Background() - args := map[string]any{ - "path": testFile, - "content": "\nAppended content", - } - - result := tool.Execute(ctx, args) - - // Success should not be an error - if result.IsError { - t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) - } - - // Should return SilentResult - if !result.Silent { - t.Errorf("Expected Silent=true for AppendFile, got false") - } - - // ForUser should be empty (silent result) - if result.ForUser != "" { - t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser) - } - - // Verify content was actually appended - content, err := os.ReadFile(testFile) - if err != nil { - t.Fatalf("Failed to read file: %v", err) - } - contentStr := string(content) - if !strings.Contains(contentStr, "Initial content") { - t.Errorf("Expected original content to remain, got: %s", contentStr) - } - if !strings.Contains(contentStr, "Appended content") { - t.Errorf("Expected appended content, got: %s", contentStr) - } -} - -// TestEditTool_AppendFile_MissingPath verifies error handling for missing path -func TestEditTool_AppendFile_MissingPath(t *testing.T) { - tool := NewAppendFileTool("", false) - ctx := context.Background() - args := map[string]any{ - "content": "test", - } - - result := tool.Execute(ctx, args) - - // Should return error result - if !result.IsError { - t.Errorf("Expected error when path is missing") - } -} - -// TestEditTool_AppendFile_MissingContent verifies error handling for missing content -func TestEditTool_AppendFile_MissingContent(t *testing.T) { - tool := NewAppendFileTool("", false) - ctx := context.Background() - args := map[string]any{ - "path": "/tmp/test.txt", - } - - result := tool.Execute(ctx, args) - - // Should return error result - if !result.IsError { - t.Errorf("Expected error when content is missing") - } -} - // TestReplaceEditContent verifies the helper function replaceEditContent func TestReplaceEditContent(t *testing.T) { tests := []struct { @@ -343,57 +267,6 @@ func TestReplaceEditContent(t *testing.T) { } } -// TestAppendFileTool_AppendToNonExistent_Restricted verifies that AppendFileTool in restricted mode -// can append to a file that does not yet exist — it should silently create the file. -// This exercises the errors.Is(err, fs.ErrNotExist) path in appendFileWithRW + rootRW. -func TestAppendFileTool_AppendToNonExistent_Restricted(t *testing.T) { - workspace := t.TempDir() - tool := NewAppendFileTool(workspace, true) - ctx := context.Background() - - args := map[string]any{ - "path": "brand_new_file.txt", - "content": "first content", - } - - result := tool.Execute(ctx, args) - assert.False( - t, - result.IsError, - "Expected success when appending to non-existent file in restricted mode, got: %s", - result.ForLLM, - ) - - // Verify the file was created with correct content - data, err := os.ReadFile(filepath.Join(workspace, "brand_new_file.txt")) - assert.NoError(t, err) - assert.Equal(t, "first content", string(data)) -} - -// TestAppendFileTool_Restricted_Success verifies that AppendFileTool in restricted mode -// correctly appends to an existing file within the sandbox. -func TestAppendFileTool_Restricted_Success(t *testing.T) { - workspace := t.TempDir() - testFile := "existing.txt" - err := os.WriteFile(filepath.Join(workspace, testFile), []byte("initial"), 0o644) - assert.NoError(t, err) - - tool := NewAppendFileTool(workspace, true) - ctx := context.Background() - args := map[string]any{ - "path": testFile, - "content": " appended", - } - - result := tool.Execute(ctx, args) - assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM) - assert.True(t, result.Silent) - - data, err := os.ReadFile(filepath.Join(workspace, testFile)) - assert.NoError(t, err) - assert.Equal(t, "initial appended", string(data)) -} - // TestEditFileTool_Restricted_InPlaceEdit verifies that EditFileTool in restricted mode // correctly edits a file using the single-open editFileInRoot path. func TestEditFileTool_Restricted_InPlaceEdit(t *testing.T) { diff --git a/pkg/tools/shell.go b/pkg/tools/exec/exec.go similarity index 94% rename from pkg/tools/shell.go rename to pkg/tools/exec/exec.go index ad1664b5b..d1cae555d 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/exec/exec.go @@ -1,4 +1,4 @@ -package tools +package exec import ( "bytes" @@ -14,6 +14,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/tools/common" ) type ExecTool struct { @@ -134,18 +135,18 @@ func (t *ExecTool) Parameters() map[string]any { } } -func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult { +func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { command, ok := args["command"].(string) if !ok { - return ErrorResult("command is required") + return common.ErrorResult("command is required") } cwd := t.workingDir if wd, ok := args["working_dir"].(string); ok && wd != "" { if t.restrictToWorkspace && t.workingDir != "" { - resolvedWD, err := validatePath(wd, t.workingDir, true) + resolvedWD, err := common.ValidatePath(wd, t.workingDir, true) if err != nil { - return ErrorResult("Command blocked by safety guard (" + err.Error() + ")") + return common.ErrorResult("Command blocked by safety guard (" + err.Error() + ")") } cwd = resolvedWD } else { @@ -161,7 +162,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult } if guardError := t.guardCommand(command, cwd); guardError != "" { - return ErrorResult(guardError) + return common.ErrorResult(guardError) } // timeout == 0 means no timeout @@ -191,7 +192,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult cmd.Stderr = &stderr if err := cmd.Start(); err != nil { - return ErrorResult(fmt.Sprintf("failed to start command: %v", err)) + return common.ErrorResult(fmt.Sprintf("failed to start command: %v", err)) } done := make(chan error, 1) @@ -222,7 +223,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult if err != nil { if errors.Is(cmdCtx.Err(), context.DeadlineExceeded) { msg := fmt.Sprintf("Command timed out after %v", t.timeout) - return &ToolResult{ + return &common.ToolResult{ ForLLM: msg, ForUser: msg, IsError: true, @@ -241,14 +242,14 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult } if err != nil { - return &ToolResult{ + return &common.ToolResult{ ForLLM: output, ForUser: output, IsError: true, } } - return &ToolResult{ + return &common.ToolResult{ ForLLM: output, ForUser: output, IsError: false, diff --git a/pkg/tools/shell_process_unix.go b/pkg/tools/exec/exec_process_unix.go similarity index 97% rename from pkg/tools/shell_process_unix.go rename to pkg/tools/exec/exec_process_unix.go index 7b29a81bf..ed7e173e1 100644 --- a/pkg/tools/shell_process_unix.go +++ b/pkg/tools/exec/exec_process_unix.go @@ -1,6 +1,6 @@ //go:build !windows -package tools +package exec import ( "os/exec" diff --git a/pkg/tools/shell_process_windows.go b/pkg/tools/exec/exec_process_windows.go similarity index 96% rename from pkg/tools/shell_process_windows.go rename to pkg/tools/exec/exec_process_windows.go index fe23b5c96..5b1babeba 100644 --- a/pkg/tools/shell_process_windows.go +++ b/pkg/tools/exec/exec_process_windows.go @@ -1,6 +1,6 @@ //go:build windows -package tools +package exec import ( "os/exec" diff --git a/pkg/tools/shell_test.go b/pkg/tools/exec/exec_test.go similarity index 99% rename from pkg/tools/shell_test.go rename to pkg/tools/exec/exec_test.go index 6d35815e8..51fddcd8b 100644 --- a/pkg/tools/shell_test.go +++ b/pkg/tools/exec/exec_test.go @@ -1,4 +1,4 @@ -package tools +package exec import ( "context" diff --git a/pkg/tools/shell_timeout_unix_test.go b/pkg/tools/exec/exec_timeout_unix_test.go similarity index 99% rename from pkg/tools/shell_timeout_unix_test.go rename to pkg/tools/exec/exec_timeout_unix_test.go index 04ef8e441..075caa999 100644 --- a/pkg/tools/shell_timeout_unix_test.go +++ b/pkg/tools/exec/exec_timeout_unix_test.go @@ -1,6 +1,6 @@ //go:build !windows -package tools +package exec import ( "context" diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go index 6f896e22d..55274ed19 100644 --- a/pkg/tools/filesystem_test.go +++ b/pkg/tools/filesystem_test.go @@ -1,311 +1,14 @@ package tools import ( - "context" "io" "os" "path/filepath" - "strings" "testing" "github.com/stretchr/testify/assert" ) -// TestFilesystemTool_ReadFile_Success verifies successful file reading -func TestFilesystemTool_ReadFile_Success(t *testing.T) { - tmpDir := t.TempDir() - testFile := filepath.Join(tmpDir, "test.txt") - os.WriteFile(testFile, []byte("test content"), 0o644) - - tool := NewReadFileTool("", false) - ctx := context.Background() - args := map[string]any{ - "path": testFile, - } - - result := tool.Execute(ctx, args) - - // Success should not be an error - if result.IsError { - t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) - } - - // ForLLM should contain file content - if !strings.Contains(result.ForLLM, "test content") { - t.Errorf("Expected ForLLM to contain 'test content', got: %s", result.ForLLM) - } - - // ReadFile returns NewToolResult which only sets ForLLM, not ForUser - // This is the expected behavior - file content goes to LLM, not directly to user - if result.ForUser != "" { - t.Errorf("Expected ForUser to be empty for NewToolResult, got: %s", result.ForUser) - } -} - -// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file -func TestFilesystemTool_ReadFile_NotFound(t *testing.T) { - tool := NewReadFileTool("", false) - ctx := context.Background() - args := map[string]any{ - "path": "/nonexistent_file_12345.txt", - } - - result := tool.Execute(ctx, args) - - // Failure should be marked as error - if !result.IsError { - t.Errorf("Expected error for missing file, got IsError=false") - } - - // Should contain error message - if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") { - t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) - } -} - -// TestFilesystemTool_ReadFile_MissingPath verifies error handling for missing path -func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) { - tool := &ReadFileTool{} - ctx := context.Background() - args := map[string]any{} - - result := tool.Execute(ctx, args) - - // Should return error result - if !result.IsError { - t.Errorf("Expected error when path is missing") - } - - // Should mention required parameter - if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") { - t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM) - } -} - -// TestFilesystemTool_WriteFile_Success verifies successful file writing -func TestFilesystemTool_WriteFile_Success(t *testing.T) { - tmpDir := t.TempDir() - testFile := filepath.Join(tmpDir, "newfile.txt") - - tool := NewWriteFileTool("", false) - ctx := context.Background() - args := map[string]any{ - "path": testFile, - "content": "hello world", - } - - result := tool.Execute(ctx, args) - - // Success should not be an error - if result.IsError { - t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) - } - - // WriteFile returns SilentResult - if !result.Silent { - t.Errorf("Expected Silent=true for WriteFile, got false") - } - - // ForUser should be empty (silent result) - if result.ForUser != "" { - t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser) - } - - // Verify file was actually written - content, err := os.ReadFile(testFile) - if err != nil { - t.Fatalf("Failed to read written file: %v", err) - } - if string(content) != "hello world" { - t.Errorf("Expected file content 'hello world', got: %s", string(content)) - } -} - -// TestFilesystemTool_WriteFile_CreateDir verifies directory creation -func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) { - tmpDir := t.TempDir() - testFile := filepath.Join(tmpDir, "subdir", "newfile.txt") - - tool := NewWriteFileTool("", false) - ctx := context.Background() - args := map[string]any{ - "path": testFile, - "content": "test", - } - - result := tool.Execute(ctx, args) - - // Success should not be an error - if result.IsError { - t.Errorf("Expected success with directory creation, got IsError=true: %s", result.ForLLM) - } - - // Verify directory was created and file written - content, err := os.ReadFile(testFile) - if err != nil { - t.Fatalf("Failed to read written file: %v", err) - } - if string(content) != "test" { - t.Errorf("Expected file content 'test', got: %s", string(content)) - } -} - -// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path -func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) { - tool := NewWriteFileTool("", false) - ctx := context.Background() - args := map[string]any{ - "content": "test", - } - - result := tool.Execute(ctx, args) - - // Should return error result - if !result.IsError { - t.Errorf("Expected error when path is missing") - } -} - -// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content -func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) { - tool := NewWriteFileTool("", false) - ctx := context.Background() - args := map[string]any{ - "path": "/tmp/test.txt", - } - - result := tool.Execute(ctx, args) - - // Should return error result - if !result.IsError { - t.Errorf("Expected error when content is missing") - } - - // Should mention required parameter - if !strings.Contains(result.ForLLM, "content is required") && - !strings.Contains(result.ForUser, "content is required") { - t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM) - } -} - -// TestFilesystemTool_ListDir_Success verifies successful directory listing -func TestFilesystemTool_ListDir_Success(t *testing.T) { - tmpDir := t.TempDir() - os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0o644) - os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0o644) - os.Mkdir(filepath.Join(tmpDir, "subdir"), 0o755) - - tool := NewListDirTool("", false) - ctx := context.Background() - args := map[string]any{ - "path": tmpDir, - } - - result := tool.Execute(ctx, args) - - // Success should not be an error - if result.IsError { - t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) - } - - // Should list files and directories - if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") { - t.Errorf("Expected files in listing, got: %s", result.ForLLM) - } - if !strings.Contains(result.ForLLM, "subdir") { - t.Errorf("Expected subdir in listing, got: %s", result.ForLLM) - } -} - -// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory -func TestFilesystemTool_ListDir_NotFound(t *testing.T) { - tool := NewListDirTool("", false) - ctx := context.Background() - args := map[string]any{ - "path": "/nonexistent_directory_12345", - } - - result := tool.Execute(ctx, args) - - // Failure should be marked as error - if !result.IsError { - t.Errorf("Expected error for non-existent directory, got IsError=false") - } - - // Should contain error message - if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") { - t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) - } -} - -// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory -func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) { - tool := NewListDirTool("", false) - ctx := context.Background() - args := map[string]any{} - - result := tool.Execute(ctx, args) - - // Should use "." as default path - if result.IsError { - t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM) - } -} - -// Block paths that look inside workspace but point outside via symlink. -func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) { - root := t.TempDir() - workspace := filepath.Join(root, "workspace") - if err := os.MkdirAll(workspace, 0o755); err != nil { - t.Fatalf("failed to create workspace: %v", err) - } - - secret := filepath.Join(root, "secret.txt") - if err := os.WriteFile(secret, []byte("top secret"), 0o644); err != nil { - t.Fatalf("failed to write secret file: %v", err) - } - - link := filepath.Join(workspace, "leak.txt") - if err := os.Symlink(secret, link); err != nil { - t.Skipf("symlink not supported in this environment: %v", err) - } - - tool := NewReadFileTool(workspace, true) - result := tool.Execute(context.Background(), map[string]any{ - "path": link, - }) - - if !result.IsError { - t.Fatalf("expected symlink escape to be blocked") - } - // os.Root might return different errors depending on platform/implementation - // but it definitely should error. - // Our wrapper returns "access denied or file not found" - if !strings.Contains(result.ForLLM, "access denied") && !strings.Contains(result.ForLLM, "file not found") && - !strings.Contains(result.ForLLM, "no such file") { - t.Fatalf("expected symlink escape error, got: %s", result.ForLLM) - } -} - -func TestFilesystemTool_EmptyWorkspace_AccessDenied(t *testing.T) { - tool := NewReadFileTool("", true) // restrict=true but workspace="" - - // Try to read a sensitive file (simulated by a temp file outside workspace) - tmpDir := t.TempDir() - secretFile := filepath.Join(tmpDir, "shadow") - os.WriteFile(secretFile, []byte("secret data"), 0o600) - - result := tool.Execute(context.Background(), map[string]any{ - "path": secretFile, - }) - - // We EXPECT IsError=true (access blocked due to empty workspace) - assert.True(t, result.IsError, "Security Regression: Empty workspace allowed access! content: %s", result.ForLLM) - - // Verify it failed for the right reason - assert.Contains(t, result.ForLLM, "workspace is not defined", "Expected 'workspace is not defined' error") -} - // TestRootMkdirAll verifies that root.MkdirAll (used by atomicWriteFileInRoot) handles all cases: // single dir, deeply nested dirs, already-existing dirs, and a file blocking a directory path. func TestRootMkdirAll(t *testing.T) { @@ -339,28 +42,6 @@ func TestRootMkdirAll(t *testing.T) { assert.Error(t, err, "expected error when a file exists at the directory path") } -func TestFilesystemTool_WriteFile_Restricted_CreateDir(t *testing.T) { - workspace := t.TempDir() - tool := NewWriteFileTool(workspace, true) - ctx := context.Background() - - testFile := "deep/nested/path/to/file.txt" - content := "deep content" - args := map[string]any{ - "path": testFile, - "content": content, - } - - result := tool.Execute(ctx, args) - assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM) - - // Verify file content - actualPath := filepath.Join(workspace, testFile) - data, err := os.ReadFile(actualPath) - assert.NoError(t, err) - assert.Equal(t, content, string(data)) -} - // TestHostRW_Read_PermissionDenied verifies that hostRW.Read surfaces access denied errors. func TestHostRW_Read_PermissionDenied(t *testing.T) { if os.Getuid() == 0 { @@ -372,7 +53,7 @@ func TestHostRW_Read_PermissionDenied(t *testing.T) { assert.NoError(t, err) defer os.Chmod(protected, 0o644) // ensure cleanup - _, err = (&hostFs{}).ReadFile(protected) + _, err = (&HostFs{}).ReadFile(protected) assert.Error(t, err) assert.Contains(t, err.Error(), "access denied") } @@ -381,7 +62,7 @@ func TestHostRW_Read_PermissionDenied(t *testing.T) { func TestHostRW_Read_Directory(t *testing.T) { tmpDir := t.TempDir() - _, err := (&hostFs{}).ReadFile(tmpDir) + _, err := (&HostFs{}).ReadFile(tmpDir) assert.Error(t, err, "expected error when reading a directory as a file") } @@ -396,7 +77,7 @@ func TestRootRW_Read_Directory(t *testing.T) { err = root.Mkdir("subdir", 0o755) assert.NoError(t, err) - _, err = (&sandboxFs{workspace: workspace}).ReadFile("subdir") + _, err = (&SandboxFs{Workspace: workspace}).ReadFile("subdir") assert.Error(t, err, "expected error when reading a directory as a file") } @@ -405,7 +86,7 @@ func TestHostRW_Write_ParentDirMissing(t *testing.T) { tmpDir := t.TempDir() target := filepath.Join(tmpDir, "a", "b", "c", "file.txt") - err := (&hostFs{}).WriteFile(target, []byte("hello")) + err := (&HostFs{}).WriteFile(target, []byte("hello")) assert.NoError(t, err) data, err := os.ReadFile(target) @@ -419,7 +100,7 @@ func TestRootRW_Write_ParentDirMissing(t *testing.T) { workspace := t.TempDir() relPath := "x/y/z/file.txt" - err := (&sandboxFs{workspace: workspace}).WriteFile(relPath, []byte("nested")) + err := (&SandboxFs{Workspace: workspace}).WriteFile(relPath, []byte("nested")) assert.NoError(t, err) data, err := os.ReadFile(filepath.Join(workspace, relPath)) @@ -433,7 +114,7 @@ func TestHostRW_Write(t *testing.T) { testFile := filepath.Join(tmpDir, "atomic_test.txt") testData := []byte("atomic test content") - err := (&hostFs{}).WriteFile(testFile, testData) + err := (&HostFs{}).WriteFile(testFile, testData) assert.NoError(t, err) content, err := os.ReadFile(testFile) @@ -442,7 +123,7 @@ func TestHostRW_Write(t *testing.T) { // Verify it overwrites correctly newData := []byte("new atomic content") - err = (&hostFs{}).WriteFile(testFile, newData) + err = (&HostFs{}).WriteFile(testFile, newData) assert.NoError(t, err) content, err = os.ReadFile(testFile) @@ -457,7 +138,7 @@ func TestRootRW_Write(t *testing.T) { relPath := "atomic_root_test.txt" testData := []byte("atomic root test content") - erw := &sandboxFs{workspace: tmpDir} + erw := &SandboxFs{Workspace: tmpDir} err := erw.WriteFile(relPath, testData) assert.NoError(t, err) diff --git a/pkg/tools/skills_search.go b/pkg/tools/find_skills/find_skills.go similarity index 88% rename from pkg/tools/skills_search.go rename to pkg/tools/find_skills/find_skills.go index 2b6cffd38..2c75a1955 100644 --- a/pkg/tools/skills_search.go +++ b/pkg/tools/find_skills/find_skills.go @@ -1,4 +1,4 @@ -package tools +package find_skills import ( "context" @@ -6,6 +6,7 @@ import ( "strings" "github.com/sipeed/picoclaw/pkg/skills" + "github.com/sipeed/picoclaw/pkg/tools/common" ) // FindSkillsTool allows the LLM agent to search for installable skills from registries. @@ -51,11 +52,11 @@ func (t *FindSkillsTool) Parameters() map[string]any { } } -func (t *FindSkillsTool) Execute(ctx context.Context, args map[string]any) *ToolResult { +func (t *FindSkillsTool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { query, ok := args["query"].(string) query = strings.ToLower(strings.TrimSpace(query)) if !ok || query == "" { - return ErrorResult("query is required and must be a non-empty string") + return common.ErrorResult("query is required and must be a non-empty string") } limit := 5 @@ -69,14 +70,14 @@ func (t *FindSkillsTool) Execute(ctx context.Context, args map[string]any) *Tool // Check cache first. if t.cache != nil { if cached, hit := t.cache.Get(query); hit { - return SilentResult(formatSearchResults(query, cached, true)) + return common.SilentResult(formatSearchResults(query, cached, true)) } } // Search all registries. results, err := t.registryMgr.SearchAll(ctx, query, limit) if err != nil { - return ErrorResult(fmt.Sprintf("skill search failed: %v", err)) + return common.ErrorResult(fmt.Sprintf("skill search failed: %v", err)) } // Cache the results. @@ -84,7 +85,7 @@ func (t *FindSkillsTool) Execute(ctx context.Context, args map[string]any) *Tool t.cache.Put(query, results) } - return SilentResult(formatSearchResults(query, results, false)) + return common.SilentResult(formatSearchResults(query, results, false)) } func formatSearchResults(query string, results []skills.SearchResult, cached bool) string { diff --git a/pkg/tools/skills_search_test.go b/pkg/tools/find_skills/find_skills_test.go similarity index 99% rename from pkg/tools/skills_search_test.go rename to pkg/tools/find_skills/find_skills_test.go index 0e5387cf5..a534deebf 100644 --- a/pkg/tools/skills_search_test.go +++ b/pkg/tools/find_skills/find_skills_test.go @@ -1,4 +1,4 @@ -package tools +package find_skills import ( "context" diff --git a/pkg/tools/i2c.go b/pkg/tools/i2c/i2c.go similarity index 78% rename from pkg/tools/i2c.go rename to pkg/tools/i2c/i2c.go index 779b1d5a7..e02f54ae1 100644 --- a/pkg/tools/i2c.go +++ b/pkg/tools/i2c/i2c.go @@ -1,4 +1,4 @@ -package tools +package i2c import ( "context" @@ -7,6 +7,8 @@ import ( "path/filepath" "regexp" "runtime" + + "github.com/sipeed/picoclaw/pkg/tools/common" ) // I2CTool provides I2C bus interaction for reading sensors and controlling peripherals. @@ -63,14 +65,14 @@ func (t *I2CTool) Parameters() map[string]any { } } -func (t *I2CTool) Execute(ctx context.Context, args map[string]any) *ToolResult { +func (t *I2CTool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { if runtime.GOOS != "linux" { - return ErrorResult("I2C is only supported on Linux. This tool requires /dev/i2c-* device files.") + return common.ErrorResult("I2C is only supported on Linux. This tool requires /dev/i2c-* device files.") } action, ok := args["action"].(string) if !ok { - return ErrorResult("action is required") + return common.ErrorResult("action is required") } switch action { @@ -83,19 +85,19 @@ func (t *I2CTool) Execute(ctx context.Context, args map[string]any) *ToolResult case "write": return t.writeDevice(args) default: - return ErrorResult(fmt.Sprintf("unknown action: %s (valid: detect, scan, read, write)", action)) + return common.ErrorResult(fmt.Sprintf("unknown action: %s (valid: detect, scan, read, write)", action)) } } // detect lists available I2C buses by globbing /dev/i2c-* -func (t *I2CTool) detect() *ToolResult { +func (t *I2CTool) detect() *common.ToolResult { matches, err := filepath.Glob("/dev/i2c-*") if err != nil { - return ErrorResult(fmt.Sprintf("failed to scan for I2C buses: %v", err)) + return common.ErrorResult(fmt.Sprintf("failed to scan for I2C buses: %v", err)) } if len(matches) == 0 { - return SilentResult( + return common.SilentResult( "No I2C buses found. You may need to:\n1. Load the i2c-dev module: modprobe i2c-dev\n2. Check that I2C is enabled in device tree\n3. Configure pinmux for your board (see hardware skill)", ) } @@ -114,7 +116,7 @@ func (t *I2CTool) detect() *ToolResult { } result, _ := json.MarshalIndent(buses, "", " ") - return SilentResult(fmt.Sprintf("Found %d I2C bus(es):\n%s", len(buses), string(result))) + return common.SilentResult(fmt.Sprintf("Found %d I2C bus(es):\n%s", len(buses), string(result))) } // Helper functions for I2C operations (used by platform-specific implementations) @@ -130,14 +132,14 @@ func isValidBusID(id string) bool { // parseI2CAddress extracts and validates an I2C address from args // //nolint:unused // Used by i2c_linux.go -func parseI2CAddress(args map[string]any) (int, *ToolResult) { +func parseI2CAddress(args map[string]any) (int, *common.ToolResult) { addrFloat, ok := args["address"].(float64) if !ok { - return 0, ErrorResult("address is required (e.g. 0x38 for AHT20)") + return 0, common.ErrorResult("address is required (e.g. 0x38 for AHT20)") } addr := int(addrFloat) if addr < 0x03 || addr > 0x77 { - return 0, ErrorResult("address must be in valid 7-bit range (0x03-0x77)") + return 0, common.ErrorResult("address must be in valid 7-bit range (0x03-0x77)") } return addr, nil } @@ -145,13 +147,13 @@ func parseI2CAddress(args map[string]any) (int, *ToolResult) { // parseI2CBus extracts and validates an I2C bus from args // //nolint:unused // Used by i2c_linux.go -func parseI2CBus(args map[string]any) (string, *ToolResult) { +func parseI2CBus(args map[string]any) (string, *common.ToolResult) { bus, ok := args["bus"].(string) if !ok || bus == "" { - return "", ErrorResult("bus is required (e.g. \"1\" for /dev/i2c-1)") + return "", common.ErrorResult("bus is required (e.g. \"1\" for /dev/i2c-1)") } if !isValidBusID(bus) { - return "", ErrorResult("invalid bus identifier: must be a number (e.g. \"1\")") + return "", common.ErrorResult("invalid bus identifier: must be a number (e.g. \"1\")") } return bus, nil } diff --git a/pkg/tools/i2c_linux.go b/pkg/tools/i2c/i2c_linux.go similarity index 76% rename from pkg/tools/i2c_linux.go rename to pkg/tools/i2c/i2c_linux.go index 4eaaf8f09..2d4f63250 100644 --- a/pkg/tools/i2c_linux.go +++ b/pkg/tools/i2c/i2c_linux.go @@ -1,10 +1,12 @@ -package tools +package i2c import ( "encoding/json" "fmt" "syscall" "unsafe" + + "github.com/sipeed/picoclaw/pkg/tools/common" ) // I2C ioctl constants from Linux kernel headers (, ) @@ -74,7 +76,7 @@ func smbusProbe(fd int, addr int, hasQuick bool) bool { // scan probes valid 7-bit addresses on a bus for connected devices. // Uses the same hybrid probe strategy as i2cdetect's MODE_AUTO: // SMBus Quick Write for most addresses, SMBus Read Byte for EEPROM ranges. -func (t *I2CTool) scan(args map[string]any) *ToolResult { +func (t *I2CTool) scan(args map[string]any) *common.ToolResult { bus, errResult := parseI2CBus(args) if errResult != nil { return errResult @@ -83,7 +85,7 @@ func (t *I2CTool) scan(args map[string]any) *ToolResult { devPath := fmt.Sprintf("/dev/i2c-%s", bus) fd, err := syscall.Open(devPath, syscall.O_RDWR, 0) if err != nil { - return ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and i2c-dev module)", devPath, err)) + return common.ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and i2c-dev module)", devPath, err)) } defer syscall.Close(fd) @@ -92,14 +94,14 @@ func (t *I2CTool) scan(args map[string]any) *ToolResult { var funcs uintptr _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cFuncs, uintptr(unsafe.Pointer(&funcs))) if errno != 0 { - return ErrorResult(fmt.Sprintf("failed to query I2C adapter capabilities on %s: %v", devPath, errno)) + return common.ErrorResult(fmt.Sprintf("failed to query I2C adapter capabilities on %s: %v", devPath, errno)) } hasQuick := funcs&i2cFuncSmbusQuick != 0 hasReadByte := funcs&i2cFuncSmbusReadByte != 0 if !hasQuick && !hasReadByte { - return ErrorResult( + return common.ErrorResult( fmt.Sprintf("I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely", devPath), ) } @@ -132,7 +134,7 @@ func (t *I2CTool) scan(args map[string]any) *ToolResult { } if len(found) == 0 { - return SilentResult(fmt.Sprintf("No devices found on %s. Check wiring and pull-up resistors.", devPath)) + return common.SilentResult(fmt.Sprintf("No devices found on %s. Check wiring and pull-up resistors.", devPath)) } result, _ := json.MarshalIndent(map[string]any{ @@ -140,11 +142,11 @@ func (t *I2CTool) scan(args map[string]any) *ToolResult { "devices": found, "count": len(found), }, "", " ") - return SilentResult(fmt.Sprintf("Scan of %s:\n%s", devPath, string(result))) + return common.SilentResult(fmt.Sprintf("Scan of %s:\n%s", devPath, string(result))) } // readDevice reads bytes from an I2C device, optionally at a specific register -func (t *I2CTool) readDevice(args map[string]any) *ToolResult { +func (t *I2CTool) readDevice(args map[string]any) *common.ToolResult { bus, errResult := parseI2CBus(args) if errResult != nil { return errResult @@ -160,31 +162,31 @@ func (t *I2CTool) readDevice(args map[string]any) *ToolResult { length = int(l) } if length < 1 || length > 256 { - return ErrorResult("length must be between 1 and 256") + return common.ErrorResult("length must be between 1 and 256") } devPath := fmt.Sprintf("/dev/i2c-%s", bus) fd, err := syscall.Open(devPath, syscall.O_RDWR, 0) if err != nil { - return ErrorResult(fmt.Sprintf("failed to open %s: %v", devPath, err)) + return common.ErrorResult(fmt.Sprintf("failed to open %s: %v", devPath, err)) } defer syscall.Close(fd) // Set slave address _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSlave, uintptr(addr)) if errno != 0 { - return ErrorResult(fmt.Sprintf("failed to set I2C address 0x%02x: %v", addr, errno)) + return common.ErrorResult(fmt.Sprintf("failed to set I2C address 0x%02x: %v", addr, errno)) } // If register is specified, write it first if regFloat, ok := args["register"].(float64); ok { reg := int(regFloat) if reg < 0 || reg > 255 { - return ErrorResult("register must be between 0x00 and 0xFF") + return common.ErrorResult("register must be between 0x00 and 0xFF") } _, err = syscall.Write(fd, []byte{byte(reg)}) if err != nil { - return ErrorResult(fmt.Sprintf("failed to write register 0x%02x: %v", reg, err)) + return common.ErrorResult(fmt.Sprintf("failed to write register 0x%02x: %v", reg, err)) } } @@ -192,7 +194,7 @@ func (t *I2CTool) readDevice(args map[string]any) *ToolResult { buf := make([]byte, length) n, err := syscall.Read(fd, buf) if err != nil { - return ErrorResult(fmt.Sprintf("failed to read from device 0x%02x: %v", addr, err)) + return common.ErrorResult(fmt.Sprintf("failed to read from device 0x%02x: %v", addr, err)) } // Format as hex bytes @@ -210,14 +212,14 @@ func (t *I2CTool) readDevice(args map[string]any) *ToolResult { "hex": hexBytes, "length": n, }, "", " ") - return SilentResult(string(result)) + return common.SilentResult(string(result)) } // writeDevice writes bytes to an I2C device, optionally at a specific register -func (t *I2CTool) writeDevice(args map[string]any) *ToolResult { +func (t *I2CTool) writeDevice(args map[string]any) *common.ToolResult { confirm, _ := args["confirm"].(bool) if !confirm { - return ErrorResult( + return common.ErrorResult( "write operations require confirm: true. Please confirm with the user before writing to I2C devices, as incorrect writes can misconfigure hardware.", ) } @@ -234,10 +236,10 @@ func (t *I2CTool) writeDevice(args map[string]any) *ToolResult { dataRaw, ok := args["data"].([]any) if !ok || len(dataRaw) == 0 { - return ErrorResult("data is required for write (array of byte values 0-255)") + return common.ErrorResult("data is required for write (array of byte values 0-255)") } if len(dataRaw) > 256 { - return ErrorResult("data too long: maximum 256 bytes per I2C transaction") + return common.ErrorResult("data too long: maximum 256 bytes per I2C transaction") } data := make([]byte, 0, len(dataRaw)+1) @@ -246,7 +248,7 @@ func (t *I2CTool) writeDevice(args map[string]any) *ToolResult { if regFloat, ok := args["register"].(float64); ok { reg := int(regFloat) if reg < 0 || reg > 255 { - return ErrorResult("register must be between 0x00 and 0xFF") + return common.ErrorResult("register must be between 0x00 and 0xFF") } data = append(data, byte(reg)) } @@ -254,11 +256,11 @@ func (t *I2CTool) writeDevice(args map[string]any) *ToolResult { for i, v := range dataRaw { f, ok := v.(float64) if !ok { - return ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i)) + return common.ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i)) } b := int(f) if b < 0 || b > 255 { - return ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b)) + return common.ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b)) } data = append(data, byte(b)) } @@ -266,21 +268,21 @@ func (t *I2CTool) writeDevice(args map[string]any) *ToolResult { devPath := fmt.Sprintf("/dev/i2c-%s", bus) fd, err := syscall.Open(devPath, syscall.O_RDWR, 0) if err != nil { - return ErrorResult(fmt.Sprintf("failed to open %s: %v", devPath, err)) + return common.ErrorResult(fmt.Sprintf("failed to open %s: %v", devPath, err)) } defer syscall.Close(fd) // Set slave address _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSlave, uintptr(addr)) if errno != 0 { - return ErrorResult(fmt.Sprintf("failed to set I2C address 0x%02x: %v", addr, errno)) + return common.ErrorResult(fmt.Sprintf("failed to set I2C address 0x%02x: %v", addr, errno)) } // Write data n, err := syscall.Write(fd, data) if err != nil { - return ErrorResult(fmt.Sprintf("failed to write to device 0x%02x: %v", addr, err)) + return common.ErrorResult(fmt.Sprintf("failed to write to device 0x%02x: %v", addr, err)) } - return SilentResult(fmt.Sprintf("Wrote %d byte(s) to device 0x%02x on %s", n, addr, devPath)) + return common.SilentResult(fmt.Sprintf("Wrote %d byte(s) to device 0x%02x on %s", n, addr, devPath)) } diff --git a/pkg/tools/i2c_other.go b/pkg/tools/i2c/i2c_other.go similarity index 97% rename from pkg/tools/i2c_other.go rename to pkg/tools/i2c/i2c_other.go index 7becf8339..c23d81ee5 100644 --- a/pkg/tools/i2c_other.go +++ b/pkg/tools/i2c/i2c_other.go @@ -1,6 +1,6 @@ //go:build !linux -package tools +package i2c // scan is a stub for non-Linux platforms. func (t *I2CTool) scan(args map[string]any) *ToolResult { diff --git a/pkg/tools/skills_install.go b/pkg/tools/install_skill/install_skill.go similarity index 89% rename from pkg/tools/skills_install.go rename to pkg/tools/install_skill/install_skill.go index 71bfe730b..3a4f062a8 100644 --- a/pkg/tools/skills_install.go +++ b/pkg/tools/install_skill/install_skill.go @@ -1,4 +1,4 @@ -package tools +package install_skill import ( "context" @@ -12,6 +12,7 @@ import ( "github.com/sipeed/picoclaw/pkg/fileutil" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/skills" + "github.com/sipeed/picoclaw/pkg/tools/common" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -68,7 +69,7 @@ func (t *InstallSkillTool) Parameters() map[string]any { } } -func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *ToolResult { +func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { // Install lock to prevent concurrent directory operations. // Ideally this should be done at a `slug` level, currently, its at a `workspace` level. t.mu.Lock() @@ -77,13 +78,13 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To // Validate slug slug, _ := args["slug"].(string) if err := utils.ValidateSkillIdentifier(slug); err != nil { - return ErrorResult(fmt.Sprintf("invalid slug %q: error: %s", slug, err.Error())) + return common.ErrorResult(fmt.Sprintf("invalid slug %q: error: %s", slug, err.Error())) } // Validate registry registryName, _ := args["registry"].(string) if err := utils.ValidateSkillIdentifier(registryName); err != nil { - return ErrorResult(fmt.Sprintf("invalid registry %q: error: %s", registryName, err.Error())) + return common.ErrorResult(fmt.Sprintf("invalid registry %q: error: %s", registryName, err.Error())) } version, _ := args["version"].(string) @@ -95,7 +96,7 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To if !force { if _, err := os.Stat(targetDir); err == nil { - return ErrorResult( + return common.ErrorResult( fmt.Sprintf("skill %q already installed at %s. Use force=true to reinstall.", slug, targetDir), ) } @@ -107,12 +108,12 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To // Resolve which registry to use. registry := t.registryMgr.GetRegistry(registryName) if registry == nil { - return ErrorResult(fmt.Sprintf("registry %q not found", registryName)) + return common.ErrorResult(fmt.Sprintf("registry %q not found", registryName)) } // Ensure skills directory exists. if err := os.MkdirAll(skillsDir, 0o755); err != nil { - return ErrorResult(fmt.Sprintf("failed to create skills directory: %v", err)) + return common.ErrorResult(fmt.Sprintf("failed to create skills directory: %v", err)) } // Download and install (handles metadata, version resolution, extraction). @@ -128,7 +129,7 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To "error": rmErr.Error(), }) } - return ErrorResult(fmt.Sprintf("failed to install %q: %v", slug, err)) + return common.ErrorResult(fmt.Sprintf("failed to install %q: %v", slug, err)) } // Moderation: block malware. @@ -142,7 +143,7 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To "error": rmErr.Error(), }) } - return ErrorResult(fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug)) + return common.ErrorResult(fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug)) } // Write origin metadata. @@ -172,7 +173,7 @@ func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *To } output += "\nThe skill is now available and can be loaded in the current session." - return SilentResult(output) + return common.SilentResult(output) } // originMeta tracks which registry a skill was installed from. diff --git a/pkg/tools/skills_install_test.go b/pkg/tools/install_skill/install_skill_test.go similarity index 99% rename from pkg/tools/skills_install_test.go rename to pkg/tools/install_skill/install_skill_test.go index 676fcecc0..d86392108 100644 --- a/pkg/tools/skills_install_test.go +++ b/pkg/tools/install_skill/install_skill_test.go @@ -1,4 +1,4 @@ -package tools +package install_skill import ( "context" diff --git a/pkg/tools/list_dir/list_dir.go b/pkg/tools/list_dir/list_dir.go new file mode 100644 index 000000000..21969275a --- /dev/null +++ b/pkg/tools/list_dir/list_dir.go @@ -0,0 +1,70 @@ +package list_dir + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/sipeed/picoclaw/pkg/tools/common" +) + +type ListDirTool struct { + fs common.FileSystem +} + +func NewListDirTool(workspace string, restrict bool) *ListDirTool { + var fs common.FileSystem + if restrict { + fs = &common.SandboxFs{Workspace: workspace} + } else { + fs = &common.HostFs{} + } + return &ListDirTool{fs: fs} +} + +func (t *ListDirTool) Name() string { + return "list_dir" +} + +func (t *ListDirTool) Description() string { + return "List files and directories in a path" +} + +func (t *ListDirTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Path to list", + }, + }, + "required": []string{"path"}, + } +} + +func (t *ListDirTool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { + path, ok := args["path"].(string) + if !ok { + path = "." + } + + entries, err := t.fs.ReadDir(path) + if err != nil { + return common.ErrorResult(fmt.Sprintf("failed to read directory: %v", err)) + } + return formatDirEntries(entries) +} + +func formatDirEntries(entries []os.DirEntry) *common.ToolResult { + var result strings.Builder + for _, entry := range entries { + if entry.IsDir() { + result.WriteString("DIR: " + entry.Name() + "\n") + } else { + result.WriteString("FILE: " + entry.Name() + "\n") + } + } + return common.NewToolResult(result.String()) +} diff --git a/pkg/tools/list_dir/list_dir_test.go b/pkg/tools/list_dir/list_dir_test.go new file mode 100644 index 000000000..cd48b883f --- /dev/null +++ b/pkg/tools/list_dir/list_dir_test.go @@ -0,0 +1,73 @@ +package list_dir + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +// TestFilesystemTool_ListDir_Success verifies successful directory listing +func TestFilesystemTool_ListDir_Success(t *testing.T) { + tmpDir := t.TempDir() + os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0o644) + os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0o644) + os.Mkdir(filepath.Join(tmpDir, "subdir"), 0o755) + + tool := NewListDirTool("", false) + ctx := context.Background() + args := map[string]any{ + "path": tmpDir, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // Should list files and directories + if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") { + t.Errorf("Expected files in listing, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "subdir") { + t.Errorf("Expected subdir in listing, got: %s", result.ForLLM) + } +} + +// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory +func TestFilesystemTool_ListDir_NotFound(t *testing.T) { + tool := NewListDirTool("", false) + ctx := context.Background() + args := map[string]any{ + "path": "/nonexistent_directory_12345", + } + + result := tool.Execute(ctx, args) + + // Failure should be marked as error + if !result.IsError { + t.Errorf("Expected error for non-existent directory, got IsError=false") + } + + // Should contain error message + if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") { + t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory +func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) { + tool := NewListDirTool("", false) + ctx := context.Background() + args := map[string]any{} + + result := tool.Execute(ctx, args) + + // Should use "." as default path + if result.IsError { + t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM) + } +} diff --git a/pkg/tools/message.go b/pkg/tools/message/message.go similarity index 85% rename from pkg/tools/message.go rename to pkg/tools/message/message.go index 15ef4ff73..a5079756f 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message/message.go @@ -1,8 +1,10 @@ -package tools +package message import ( "context" "fmt" + + "github.com/sipeed/picoclaw/pkg/tools/common" ) type SendCallback func(channel, chatID, content string) error @@ -62,10 +64,10 @@ func (t *MessageTool) SetSendCallback(callback SendCallback) { t.sendCallback = callback } -func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolResult { +func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { content, ok := args["content"].(string) if !ok { - return &ToolResult{ForLLM: "content is required", IsError: true} + return &common.ToolResult{ForLLM: "content is required", IsError: true} } channel, _ := args["channel"].(string) @@ -79,15 +81,15 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes } if channel == "" || chatID == "" { - return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true} + return &common.ToolResult{ForLLM: "No target channel/chat specified", IsError: true} } if t.sendCallback == nil { - return &ToolResult{ForLLM: "Message sending not configured", IsError: true} + return &common.ToolResult{ForLLM: "Message sending not configured", IsError: true} } if err := t.sendCallback(channel, chatID, content); err != nil { - return &ToolResult{ + return &common.ToolResult{ ForLLM: fmt.Sprintf("sending message: %v", err), IsError: true, Err: err, @@ -96,7 +98,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes t.sentInRound = true // Silent: user already received the message directly - return &ToolResult{ + return &common.ToolResult{ ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID), Silent: true, } diff --git a/pkg/tools/message_test.go b/pkg/tools/message/message_test.go similarity index 98% rename from pkg/tools/message_test.go rename to pkg/tools/message/message_test.go index 717c1117b..1c7599e9c 100644 --- a/pkg/tools/message_test.go +++ b/pkg/tools/message/message_test.go @@ -1,4 +1,4 @@ -package tools +package message import ( "context" @@ -36,7 +36,7 @@ func TestMessageTool_Execute_Success(t *testing.T) { t.Errorf("Expected content 'Hello, world!', got '%s'", sentContent) } - // Verify ToolResult meets US-011 criteria: + // Verify tools.ToolResult meets US-011 criteria: // - Send success returns SilentResult (Silent=true) if !result.Silent { t.Error("Expected Silent=true for successful send") @@ -110,7 +110,7 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) { result := tool.Execute(ctx, args) - // Verify ToolResult for send failure: + // Verify tools.ToolResult for send failure: // - Send failure returns ErrorResult (IsError=true) if !result.IsError { t.Error("Expected IsError=true for failed send") diff --git a/pkg/tools/read_file/read_file.go b/pkg/tools/read_file/read_file.go new file mode 100644 index 000000000..f2cccecb2 --- /dev/null +++ b/pkg/tools/read_file/read_file.go @@ -0,0 +1,55 @@ +package read_file + +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/tools/common" +) + +type ReadFileTool struct { + fs common.FileSystem +} + +func NewReadFileTool(workspace string, restrict bool) *ReadFileTool { + var fs common.FileSystem + if restrict { + fs = &common.SandboxFs{Workspace: workspace} + } else { + fs = &common.HostFs{} + } + return &ReadFileTool{fs: fs} +} + +func (t *ReadFileTool) Name() string { + return "read_file" +} + +func (t *ReadFileTool) Description() string { + return "Read the contents of a file" +} + +func (t *ReadFileTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Path to the file to read", + }, + }, + "required": []string{"path"}, + } +} + +func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { + path, ok := args["path"].(string) + if !ok { + return common.ErrorResult("path is required") + } + + content, err := t.fs.ReadFile(path) + if err != nil { + return common.ErrorResult(err.Error()) + } + return common.NewToolResult(string(content)) +} diff --git a/pkg/tools/read_file/read_file_test.go b/pkg/tools/read_file/read_file_test.go new file mode 100644 index 000000000..ba13ce2da --- /dev/null +++ b/pkg/tools/read_file/read_file_test.go @@ -0,0 +1,136 @@ +package read_file + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestFilesystemTool_ReadFile_Success verifies successful file reading +func TestFilesystemTool_ReadFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("test content"), 0o644) + + tool := NewReadFileTool("", false) + ctx := context.Background() + args := map[string]any{ + "path": testFile, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForLLM should contain file content + if !strings.Contains(result.ForLLM, "test content") { + t.Errorf("Expected ForLLM to contain 'test content', got: %s", result.ForLLM) + } + + // ReadFile returns NewToolResult which only sets ForLLM, not ForUser + // This is the expected behavior - file content goes to LLM, not directly to user + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty for NewToolResult, got: %s", result.ForUser) + } +} + +// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file +func TestFilesystemTool_ReadFile_NotFound(t *testing.T) { + tool := NewReadFileTool("", false) + ctx := context.Background() + args := map[string]any{ + "path": "/nonexistent_file_12345.txt", + } + + result := tool.Execute(ctx, args) + + // Failure should be marked as error + if !result.IsError { + t.Errorf("Expected error for missing file, got IsError=false") + } + + // Should contain error message + if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") { + t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestFilesystemTool_ReadFile_MissingPath verifies error handling for missing path +func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) { + tool := &ReadFileTool{} + ctx := context.Background() + args := map[string]any{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } + + // Should mention required parameter + if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") { + t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM) + } +} + +// Block paths that look inside workspace but point outside via symlink. +func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) { + root := t.TempDir() + workspace := filepath.Join(root, "workspace") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatalf("failed to create workspace: %v", err) + } + + secret := filepath.Join(root, "secret.txt") + if err := os.WriteFile(secret, []byte("top secret"), 0o644); err != nil { + t.Fatalf("failed to write secret file: %v", err) + } + + link := filepath.Join(workspace, "leak.txt") + if err := os.Symlink(secret, link); err != nil { + t.Skipf("symlink not supported in this environment: %v", err) + } + + tool := NewReadFileTool(workspace, true) + result := tool.Execute(context.Background(), map[string]any{ + "path": link, + }) + + if !result.IsError { + t.Fatalf("expected symlink escape to be blocked") + } + // os.Root might return different errors depending on platform/implementation + // but it definitely should error. + // Our wrapper returns "access denied or file not found" + if !strings.Contains(result.ForLLM, "access denied") && !strings.Contains(result.ForLLM, "file not found") && + !strings.Contains(result.ForLLM, "no such file") { + t.Fatalf("expected symlink escape error, got: %s", result.ForLLM) + } +} + +func TestFilesystemTool_EmptyWorkspace_AccessDenied(t *testing.T) { + tool := NewReadFileTool("", true) // restrict=true but workspace="" + + // Try to read a sensitive file (simulated by a temp file outside workspace) + tmpDir := t.TempDir() + secretFile := filepath.Join(tmpDir, "shadow") + os.WriteFile(secretFile, []byte("secret data"), 0o600) + + result := tool.Execute(context.Background(), map[string]any{ + "path": secretFile, + }) + + // We EXPECT IsError=true (access blocked due to empty workspace) + assert.True(t, result.IsError, "Security Regression: Empty workspace allowed access! content: %s", result.ForLLM) + + // Verify it failed for the right reason + assert.Contains(t, result.ForLLM, "workspace is not defined", "Expected 'workspace is not defined' error") +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index d37a093a8..a3519ae5f 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -7,8 +7,23 @@ import ( "sync" "time" + "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/skills" + "github.com/sipeed/picoclaw/pkg/tools/append_file" + "github.com/sipeed/picoclaw/pkg/tools/edit_file" + "github.com/sipeed/picoclaw/pkg/tools/exec" + "github.com/sipeed/picoclaw/pkg/tools/find_skills" + "github.com/sipeed/picoclaw/pkg/tools/i2c" + "github.com/sipeed/picoclaw/pkg/tools/install_skill" + "github.com/sipeed/picoclaw/pkg/tools/list_dir" + "github.com/sipeed/picoclaw/pkg/tools/message" + "github.com/sipeed/picoclaw/pkg/tools/read_file" + "github.com/sipeed/picoclaw/pkg/tools/spi" + "github.com/sipeed/picoclaw/pkg/tools/web_fetch" + "github.com/sipeed/picoclaw/pkg/tools/web_search" + "github.com/sipeed/picoclaw/pkg/tools/write_file" ) type ToolRegistry struct { @@ -16,7 +31,98 @@ type ToolRegistry struct { mu sync.RWMutex } -func NewToolRegistry() *ToolRegistry { +func NewToolRegistry(cfg *config.Config, workspace string, restrict bool) *ToolRegistry { + toolsRegistry := &ToolRegistry{ + tools: make(map[string]Tool), + } + + // Handle nil config (for testing) + if cfg == nil { + cfg = config.DefaultConfig() + } + + // File tools - each with individual configuration + if cfg.Tools.ReadFile.Enabled { + toolsRegistry.Register(read_file.NewReadFileTool(workspace, restrict)) + } + if cfg.Tools.WriteFile.Enabled { + toolsRegistry.Register(write_file.NewWriteFileTool(workspace, restrict)) + } + if cfg.Tools.EditFile.Enabled { + toolsRegistry.Register(edit_file.NewEditFileTool(workspace, restrict)) + } + if cfg.Tools.AppendFile.Enabled { + toolsRegistry.Register(append_file.NewAppendFileTool(workspace, restrict)) + } + if cfg.Tools.ListDir.Enabled { + toolsRegistry.Register(list_dir.NewListDirTool(workspace, restrict)) + } + + // Exec tool + if cfg.Tools.Exec.Enabled { + toolsRegistry.Register(exec.NewExecToolWithConfig(workspace, restrict, cfg)) + } + + // Web tools + if searchTool := web_search.NewWebSearchTool(web_search.WebSearchToolOptions{ + BraveAPIKey: cfg.Tools.Web.Brave.APIKey, + BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, + BraveEnabled: cfg.Tools.Web.Brave.Enabled, + TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey, + TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL, + TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults, + TavilyEnabled: cfg.Tools.Web.Tavily.Enabled, + DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, + DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, + PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, + PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, + PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, + Proxy: cfg.Tools.Web.Proxy, + }); searchTool != nil { + toolsRegistry.Register(searchTool) + } + toolsRegistry.Register(web_fetch.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy)) + + // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms + if cfg.Tools.I2C.Enabled { + toolsRegistry.Register(i2c.NewI2CTool()) + } + if cfg.Tools.SPI.Enabled { + toolsRegistry.Register(spi.NewSPITool()) + } + + // Skill discovery and installation tools + registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{ + MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches, + ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub), + }) + searchCache := skills.NewSearchCache( + cfg.Tools.Skills.SearchCache.MaxSize, + time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second, + ) + if cfg.Tools.FindSkills.Enabled { + toolsRegistry.Register(find_skills.NewFindSkillsTool(registryMgr, searchCache)) + } + if cfg.Tools.InstallSkill.Enabled { + toolsRegistry.Register(install_skill.NewInstallSkillTool(registryMgr, workspace)) + } + + // Message tool + if cfg.Tools.Message.Enabled { + toolsRegistry.Register(message.NewMessageTool()) + } + + // // Spawn tool + // if cfg.Tools.Spawn.Enabled { + // // Note: Spawn tool is registered separately in agent loop + // } + + return toolsRegistry +} + +// NewEmptyToolRegistry creates a tool registry without pre-registered tools. +// This is useful for testing. +func NewEmptyToolRegistry() *ToolRegistry { return &ToolRegistry{ tools: make(map[string]Tool), } @@ -121,6 +227,17 @@ func (r *ToolRegistry) sortedToolNames() []string { return names } +func ToolToSchema(tool Tool) map[string]any { + return map[string]any{ + "type": "function", + "function": map[string]any{ + "name": tool.Name(), + "description": tool.Description(), + "parameters": tool.Parameters(), + }, + } +} + func (r *ToolRegistry) GetDefinitions() []map[string]any { r.mu.RLock() defer r.mu.RUnlock() diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go index 8ae13b20c..0c60c0790 100644 --- a/pkg/tools/registry_test.go +++ b/pkg/tools/registry_test.go @@ -59,7 +59,7 @@ func newMockTool(name, desc string) *mockRegistryTool { // --- tests --- func TestNewToolRegistry(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() if r.Count() != 0 { t.Errorf("expected empty registry, got count %d", r.Count()) } @@ -69,7 +69,7 @@ func TestNewToolRegistry(t *testing.T) { } func TestToolRegistry_RegisterAndGet(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() tool := newMockTool("echo", "echoes input") r.Register(tool) @@ -83,7 +83,7 @@ func TestToolRegistry_RegisterAndGet(t *testing.T) { } func TestToolRegistry_Get_NotFound(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() _, ok := r.Get("nonexistent") if ok { t.Error("expected ok=false for unregistered tool") @@ -91,7 +91,7 @@ func TestToolRegistry_Get_NotFound(t *testing.T) { } func TestToolRegistry_RegisterOverwrite(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() r.Register(newMockTool("dup", "first")) r.Register(newMockTool("dup", "second")) @@ -105,7 +105,7 @@ func TestToolRegistry_RegisterOverwrite(t *testing.T) { } func TestToolRegistry_Execute_Success(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() r.Register(&mockRegistryTool{ name: "greet", desc: "says hello", @@ -123,7 +123,7 @@ func TestToolRegistry_Execute_Success(t *testing.T) { } func TestToolRegistry_Execute_NotFound(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() result := r.Execute(context.Background(), "missing", nil) if !result.IsError { t.Error("expected error for missing tool") @@ -137,7 +137,7 @@ func TestToolRegistry_Execute_NotFound(t *testing.T) { } func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() ct := &mockCtxTool{ mockRegistryTool: *newMockTool("ctx_tool", "needs context"), } @@ -154,7 +154,7 @@ func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) { } func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() ct := &mockCtxTool{ mockRegistryTool: *newMockTool("ctx_tool", "needs context"), } @@ -168,7 +168,7 @@ func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) { } func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() at := &mockAsyncRegistryTool{ mockRegistryTool: *newMockTool("async_tool", "async work"), } @@ -193,7 +193,7 @@ func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) { } func TestToolRegistry_GetDefinitions(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() r.Register(newMockTool("alpha", "tool A")) defs := r.GetDefinitions() @@ -216,7 +216,7 @@ func TestToolRegistry_GetDefinitions(t *testing.T) { } func TestToolRegistry_ToProviderDefs(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() params := map[string]any{"type": "object", "properties": map[string]any{}} r.Register(&mockRegistryTool{ name: "beta", @@ -251,7 +251,7 @@ func TestToolRegistry_ToProviderDefs(t *testing.T) { } func TestToolRegistry_List(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() r.Register(newMockTool("x", "")) r.Register(newMockTool("y", "")) @@ -270,7 +270,7 @@ func TestToolRegistry_List(t *testing.T) { } func TestToolRegistry_Count(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() if r.Count() != 0 { t.Errorf("expected 0, got %d", r.Count()) } @@ -288,7 +288,7 @@ func TestToolRegistry_Count(t *testing.T) { } func TestToolRegistry_GetSummaries(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() r.Register(newMockTool("read_file", "Reads a file")) summaries := r.GetSummaries() @@ -326,7 +326,7 @@ func TestToolToSchema(t *testing.T) { } func TestToolRegistry_ConcurrentAccess(t *testing.T) { - r := NewToolRegistry() + r := NewEmptyToolRegistry() var wg sync.WaitGroup for i := 0; i < 50; i++ { diff --git a/pkg/tools/spi.go b/pkg/tools/spi/spi.go similarity index 87% rename from pkg/tools/spi.go rename to pkg/tools/spi/spi.go index 0ca17e84f..6e5bdd561 100644 --- a/pkg/tools/spi.go +++ b/pkg/tools/spi/spi.go @@ -1,4 +1,4 @@ -package tools +package spi import ( "context" @@ -7,6 +7,8 @@ import ( "path/filepath" "regexp" "runtime" + + "github.com/sipeed/picoclaw/pkg/tools/common" ) // SPITool provides SPI bus interaction for high-speed peripheral communication. @@ -67,14 +69,14 @@ func (t *SPITool) Parameters() map[string]any { } } -func (t *SPITool) Execute(ctx context.Context, args map[string]any) *ToolResult { +func (t *SPITool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { if runtime.GOOS != "linux" { - return ErrorResult("SPI is only supported on Linux. This tool requires /dev/spidev* device files.") + return common.ErrorResult("SPI is only supported on Linux. This tool requires /dev/spidev* device files.") } action, ok := args["action"].(string) if !ok { - return ErrorResult("action is required") + return common.ErrorResult("action is required") } switch action { @@ -85,19 +87,19 @@ func (t *SPITool) Execute(ctx context.Context, args map[string]any) *ToolResult case "read": return t.readDevice(args) default: - return ErrorResult(fmt.Sprintf("unknown action: %s (valid: list, transfer, read)", action)) + return common.ErrorResult(fmt.Sprintf("unknown action: %s (valid: list, transfer, read)", action)) } } // list finds available SPI devices by globbing /dev/spidev* -func (t *SPITool) list() *ToolResult { +func (t *SPITool) list() *common.ToolResult { matches, err := filepath.Glob("/dev/spidev*") if err != nil { - return ErrorResult(fmt.Sprintf("failed to scan for SPI devices: %v", err)) + return common.ErrorResult(fmt.Sprintf("failed to scan for SPI devices: %v", err)) } if len(matches) == 0 { - return SilentResult( + return common.SilentResult( "No SPI devices found. You may need to:\n1. Enable SPI in device tree\n2. Configure pinmux for your board (see hardware skill)\n3. Check that spidev module is loaded", ) } @@ -116,7 +118,7 @@ func (t *SPITool) list() *ToolResult { } result, _ := json.MarshalIndent(devices, "", " ") - return SilentResult(fmt.Sprintf("Found %d SPI device(s):\n%s", len(devices), string(result))) + return common.SilentResult(fmt.Sprintf("Found %d SPI device(s):\n%s", len(devices), string(result))) } // Helper function for SPI operations (used by platform-specific implementations) diff --git a/pkg/tools/spi_linux.go b/pkg/tools/spi/spi_linux.go similarity index 76% rename from pkg/tools/spi_linux.go rename to pkg/tools/spi/spi_linux.go index 9def73662..b1c25c600 100644 --- a/pkg/tools/spi_linux.go +++ b/pkg/tools/spi/spi_linux.go @@ -1,4 +1,4 @@ -package tools +package spi import ( "encoding/json" @@ -6,6 +6,8 @@ import ( "runtime" "syscall" "unsafe" + + "github.com/sipeed/picoclaw/pkg/tools/common" ) // SPI ioctl constants from Linux kernel headers. @@ -35,67 +37,67 @@ type spiTransfer struct { } // configureSPI opens an SPI device and sets mode, bits per word, and speed -func configureSPI(devPath string, mode uint8, bits uint8, speed uint32) (int, *ToolResult) { +func configureSPI(devPath string, mode uint8, bits uint8, speed uint32) (int, *common.ToolResult) { fd, err := syscall.Open(devPath, syscall.O_RDWR, 0) if err != nil { - return -1, ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and spidev module)", devPath, err)) + return -1, common.ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and spidev module)", devPath, err)) } // Set SPI mode _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrMode, uintptr(unsafe.Pointer(&mode))) if errno != 0 { syscall.Close(fd) - return -1, ErrorResult(fmt.Sprintf("failed to set SPI mode %d: %v", mode, errno)) + return -1, common.ErrorResult(fmt.Sprintf("failed to set SPI mode %d: %v", mode, errno)) } // Set bits per word _, _, errno = syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrBitsPerWord, uintptr(unsafe.Pointer(&bits))) if errno != 0 { syscall.Close(fd) - return -1, ErrorResult(fmt.Sprintf("failed to set bits per word %d: %v", bits, errno)) + return -1, common.ErrorResult(fmt.Sprintf("failed to set bits per word %d: %v", bits, errno)) } // Set max speed _, _, errno = syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), spiIocWrMaxSpeedHz, uintptr(unsafe.Pointer(&speed))) if errno != 0 { syscall.Close(fd) - return -1, ErrorResult(fmt.Sprintf("failed to set SPI speed %d Hz: %v", speed, errno)) + return -1, common.ErrorResult(fmt.Sprintf("failed to set SPI speed %d Hz: %v", speed, errno)) } return fd, nil } // transfer performs a full-duplex SPI transfer -func (t *SPITool) transfer(args map[string]any) *ToolResult { +func (t *SPITool) transfer(args map[string]any) *common.ToolResult { confirm, _ := args["confirm"].(bool) if !confirm { - return ErrorResult( + return common.ErrorResult( "transfer operations require confirm: true. Please confirm with the user before sending data to SPI devices.", ) } dev, speed, mode, bits, errMsg := parseSPIArgs(args) if errMsg != "" { - return ErrorResult(errMsg) + return common.ErrorResult(errMsg) } dataRaw, ok := args["data"].([]any) if !ok || len(dataRaw) == 0 { - return ErrorResult("data is required for transfer (array of byte values 0-255)") + return common.ErrorResult("data is required for transfer (array of byte values 0-255)") } if len(dataRaw) > 4096 { - return ErrorResult("data too long: maximum 4096 bytes per SPI transfer") + return common.ErrorResult("data too long: maximum 4096 bytes per SPI transfer") } txBuf := make([]byte, len(dataRaw)) for i, v := range dataRaw { f, ok := v.(float64) if !ok { - return ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i)) + return common.ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i)) } b := int(f) if b < 0 || b > 255 { - return ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b)) + return common.ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b)) } txBuf[i] = byte(b) } @@ -121,7 +123,7 @@ func (t *SPITool) transfer(args map[string]any) *ToolResult { runtime.KeepAlive(txBuf) runtime.KeepAlive(rxBuf) if errno != 0 { - return ErrorResult(fmt.Sprintf("SPI transfer failed: %v", errno)) + return common.ErrorResult(fmt.Sprintf("SPI transfer failed: %v", errno)) } // Format received bytes @@ -138,14 +140,14 @@ func (t *SPITool) transfer(args map[string]any) *ToolResult { "received": intBytes, "hex": hexBytes, }, "", " ") - return SilentResult(string(result)) + return common.SilentResult(string(result)) } // readDevice reads bytes from SPI by sending zeros (read-only, no confirm needed) -func (t *SPITool) readDevice(args map[string]any) *ToolResult { +func (t *SPITool) readDevice(args map[string]any) *common.ToolResult { dev, speed, mode, bits, errMsg := parseSPIArgs(args) if errMsg != "" { - return ErrorResult(errMsg) + return common.ErrorResult(errMsg) } length := 0 @@ -153,7 +155,7 @@ func (t *SPITool) readDevice(args map[string]any) *ToolResult { length = int(l) } if length < 1 || length > 4096 { - return ErrorResult("length is required for read (1-4096)") + return common.ErrorResult("length is required for read (1-4096)") } devPath := fmt.Sprintf("/dev/spidev%s", dev) @@ -178,7 +180,7 @@ func (t *SPITool) readDevice(args map[string]any) *ToolResult { runtime.KeepAlive(txBuf) runtime.KeepAlive(rxBuf) if errno != 0 { - return ErrorResult(fmt.Sprintf("SPI read failed: %v", errno)) + return common.ErrorResult(fmt.Sprintf("SPI read failed: %v", errno)) } hexBytes := make([]string, len(rxBuf)) @@ -194,5 +196,5 @@ func (t *SPITool) readDevice(args map[string]any) *ToolResult { "hex": hexBytes, "length": len(rxBuf), }, "", " ") - return SilentResult(string(result)) + return common.SilentResult(string(result)) } diff --git a/pkg/tools/spi/spi_other.go b/pkg/tools/spi/spi_other.go new file mode 100644 index 000000000..cae05cc1c --- /dev/null +++ b/pkg/tools/spi/spi_other.go @@ -0,0 +1,15 @@ +//go:build !linux + +package spi + +import "github.com/sipeed/picoclaw/pkg/tools/common" + +// transfer is a stub for non-Linux platforms. +func (t *SPITool) transfer(args map[string]any) *common.ToolResult { + return common.ErrorResult("SPI is only supported on Linux") +} + +// readDevice is a stub for non-Linux platforms. +func (t *SPITool) readDevice(args map[string]any) *common.ToolResult { + return common.ErrorResult("SPI is only supported on Linux") +} diff --git a/pkg/tools/spi_other.go b/pkg/tools/spi_other.go deleted file mode 100644 index 5d078ac3f..000000000 --- a/pkg/tools/spi_other.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build !linux - -package tools - -// transfer is a stub for non-Linux platforms. -func (t *SPITool) transfer(args map[string]any) *ToolResult { - return ErrorResult("SPI is only supported on Linux") -} - -// readDevice is a stub for non-Linux platforms. -func (t *SPITool) readDevice(args map[string]any) *ToolResult { - return ErrorResult("SPI is only supported on Linux") -} diff --git a/pkg/tools/spawn.go b/pkg/tools/subagent/spawn.go similarity index 79% rename from pkg/tools/spawn.go rename to pkg/tools/subagent/spawn.go index 8b166b41f..f97504527 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/subagent/spawn.go @@ -1,9 +1,11 @@ -package tools +package subagent import ( "context" "fmt" "strings" + + "github.com/sipeed/picoclaw/pkg/tools" ) type SpawnTool struct { @@ -11,7 +13,7 @@ type SpawnTool struct { originChannel string originChatID string allowlistCheck func(targetAgentID string) bool - callback AsyncCallback // For async completion notification + callback tools.AsyncCallback // For async completion notification } func NewSpawnTool(manager *SubagentManager) *SpawnTool { @@ -23,7 +25,7 @@ func NewSpawnTool(manager *SubagentManager) *SpawnTool { } // SetCallback implements AsyncTool interface for async completion notification -func (t *SpawnTool) SetCallback(cb AsyncCallback) { +func (t *SpawnTool) SetCallback(cb tools.AsyncCallback) { t.callback = cb } @@ -65,10 +67,10 @@ func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) { t.allowlistCheck = check } -func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResult { +func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { task, ok := args["task"].(string) if !ok || strings.TrimSpace(task) == "" { - return ErrorResult("task is required and must be a non-empty string") + return tools.ErrorResult("task is required and must be a non-empty string") } label, _ := args["label"].(string) @@ -77,20 +79,20 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResul // Check allowlist if targeting a specific agent if agentID != "" && t.allowlistCheck != nil { if !t.allowlistCheck(agentID) { - return ErrorResult(fmt.Sprintf("not allowed to spawn agent '%s'", agentID)) + return tools.ErrorResult(fmt.Sprintf("not allowed to spawn agent '%s'", agentID)) } } if t.manager == nil { - return ErrorResult("Subagent manager not configured") + return tools.ErrorResult("Subagent manager not configured") } // Pass callback to manager for async completion notification result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, t.callback) if err != nil { - return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err)) + return tools.ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err)) } // Return AsyncResult since the task runs in background - return AsyncResult(result) + return tools.AsyncResult(result) } diff --git a/pkg/tools/spawn_test.go b/pkg/tools/subagent/spawn_test.go similarity index 99% rename from pkg/tools/spawn_test.go rename to pkg/tools/subagent/spawn_test.go index 0646c82a9..b36ac3f23 100644 --- a/pkg/tools/spawn_test.go +++ b/pkg/tools/subagent/spawn_test.go @@ -1,4 +1,4 @@ -package tools +package subagent import ( "context" diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent/subagent.go similarity index 87% rename from pkg/tools/subagent.go rename to pkg/tools/subagent/subagent.go index ad371a649..59aa0308a 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent/subagent.go @@ -1,4 +1,4 @@ -package tools +package subagent import ( "context" @@ -8,6 +8,8 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" + "github.com/sipeed/picoclaw/pkg/tools/common" ) type SubagentTask struct { @@ -29,7 +31,7 @@ type SubagentManager struct { defaultModel string bus *bus.MessageBus workspace string - tools *ToolRegistry + tools *tools.ToolRegistry maxIterations int maxTokens int temperature float64 @@ -49,7 +51,7 @@ func NewSubagentManager( defaultModel: defaultModel, bus: bus, workspace: workspace, - tools: NewToolRegistry(), + tools: tools.NewToolRegistry(nil, "", false), maxIterations: 10, nextID: 1, } @@ -67,14 +69,14 @@ func (sm *SubagentManager) SetLLMOptions(maxTokens int, temperature float64) { // SetTools sets the tool registry for subagent execution. // If not set, subagent will have access to the provided tools. -func (sm *SubagentManager) SetTools(tools *ToolRegistry) { +func (sm *SubagentManager) SetTools(tools *tools.ToolRegistry) { sm.mu.Lock() defer sm.mu.Unlock() sm.tools = tools } // RegisterTool registers a tool for subagent execution. -func (sm *SubagentManager) RegisterTool(tool Tool) { +func (sm *SubagentManager) RegisterTool(tool tools.Tool) { sm.mu.Lock() defer sm.mu.Unlock() sm.tools.Register(tool) @@ -83,7 +85,7 @@ func (sm *SubagentManager) RegisterTool(tool Tool) { func (sm *SubagentManager) Spawn( ctx context.Context, task, label, agentID, originChannel, originChatID string, - callback AsyncCallback, + callback common.AsyncCallback, ) (string, error) { sm.mu.Lock() defer sm.mu.Unlock() @@ -112,7 +114,7 @@ func (sm *SubagentManager) Spawn( return fmt.Sprintf("Spawned subagent for task: %s", task), nil } -func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) { +func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback common.AsyncCallback) { task.Status = "running" task.Created = time.Now().UnixMilli() @@ -145,7 +147,7 @@ After completing the task, provide a clear summary of what was done.` // Run tool loop with access to tools sm.mu.RLock() - tools := sm.tools + sm_tools := sm.tools maxIter := sm.maxIterations maxTokens := sm.maxTokens temperature := sm.temperature @@ -164,16 +166,16 @@ After completing the task, provide a clear summary of what was done.` } } - loopResult, err := RunToolLoop(ctx, ToolLoopConfig{ + loopResult, err := tools.RunToolLoop(ctx, tools.ToolLoopConfig{ Provider: sm.provider, Model: sm.defaultModel, - Tools: tools, + Tools: sm_tools, MaxIterations: maxIter, LLMOptions: llmOptions, }, messages, task.OriginChannel, task.OriginChatID) sm.mu.Lock() - var result *ToolResult + var result *common.ToolResult defer func() { sm.mu.Unlock() // Call callback if provided and result is set @@ -190,7 +192,7 @@ After completing the task, provide a clear summary of what was done.` task.Status = "canceled" task.Result = "Task canceled during execution" } - result = &ToolResult{ + result = &common.ToolResult{ ForLLM: task.Result, ForUser: "", Silent: false, @@ -201,7 +203,7 @@ After completing the task, provide a clear summary of what was done.` } else { task.Status = "completed" task.Result = loopResult.Content - result = &ToolResult{ + result = &common.ToolResult{ ForLLM: fmt.Sprintf( "Subagent '%s' completed (iterations: %d): %s", task.Label, @@ -248,7 +250,7 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask { // SubagentTool executes a subagent task synchronously and returns the result. // Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion -// and returns the result directly in the ToolResult. +// and returns the result directly in the common.ToolResult. type SubagentTool struct { manager *SubagentManager originChannel string @@ -293,16 +295,16 @@ func (t *SubagentTool) SetContext(channel, chatID string) { t.originChatID = chatID } -func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolResult { +func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { task, ok := args["task"].(string) if !ok { - return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required")) + return common.ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required")) } label, _ := args["label"].(string) if t.manager == nil { - return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil")) + return common.ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil")) } // Build messages for subagent @@ -317,10 +319,10 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolRe }, } - // Use RunToolLoop to execute with tools (same as async SpawnTool) + // Use common.RunToolLoop to execute with tools (same as async SpawnTool) sm := t.manager sm.mu.RLock() - tools := sm.tools + sm_tools := sm.tools maxIter := sm.maxIterations maxTokens := sm.maxTokens temperature := sm.temperature @@ -339,15 +341,15 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolRe } } - loopResult, err := RunToolLoop(ctx, ToolLoopConfig{ + loopResult, err := tools.RunToolLoop(ctx, tools.ToolLoopConfig{ Provider: sm.provider, Model: sm.defaultModel, - Tools: tools, + Tools: sm_tools, MaxIterations: maxIter, LLMOptions: llmOptions, }, messages, t.originChannel, t.originChatID) if err != nil { - return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err) + return common.ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err) } // ForUser: Brief summary for user (truncated if too long) @@ -365,7 +367,7 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolRe llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nIterations: %d\nResult: %s", labelStr, loopResult.Iterations, loopResult.Content) - return &ToolResult{ + return &common.ToolResult{ ForLLM: llmContent, ForUser: userContent, Silent: false, diff --git a/pkg/tools/subagent_tool_test.go b/pkg/tools/subagent/subagent_tool_test.go similarity index 99% rename from pkg/tools/subagent_tool_test.go rename to pkg/tools/subagent/subagent_tool_test.go index 59bfdffae..1d8644890 100644 --- a/pkg/tools/subagent_tool_test.go +++ b/pkg/tools/subagent/subagent_tool_test.go @@ -1,4 +1,4 @@ -package tools +package subagent import ( "context" diff --git a/pkg/tools/types.go b/pkg/tools/types.go deleted file mode 100644 index a6015cde3..000000000 --- a/pkg/tools/types.go +++ /dev/null @@ -1,58 +0,0 @@ -package tools - -import "context" - -type Message struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` -} - -type ToolCall struct { - ID string `json:"id"` - Type string `json:"type"` - Function *FunctionCall `json:"function,omitempty"` - Name string `json:"name,omitempty"` - Arguments map[string]any `json:"arguments,omitempty"` -} - -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -type LLMResponse struct { - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - FinishReason string `json:"finish_reason"` - Usage *UsageInfo `json:"usage,omitempty"` -} - -type UsageInfo struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type LLMProvider interface { - Chat( - ctx context.Context, - messages []Message, - tools []ToolDefinition, - model string, - options map[string]any, - ) (*LLMResponse, error) - GetDefaultModel() string -} - -type ToolDefinition struct { - Type string `json:"type"` - Function ToolFunctionDefinition `json:"function"` -} - -type ToolFunctionDefinition struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters map[string]any `json:"parameters"` -} diff --git a/pkg/tools/types_export.go b/pkg/tools/types_export.go new file mode 100644 index 000000000..1ca1a04a5 --- /dev/null +++ b/pkg/tools/types_export.go @@ -0,0 +1,38 @@ +package tools + +import ( + "github.com/sipeed/picoclaw/pkg/tools/common" +) + +type Tool = common.Tool +type ToolResult = common.ToolResult +type ContextualTool = common.ContextualTool +type AsyncTool = common.AsyncTool +type AsyncCallback = common.AsyncCallback +type FileSystem = common.FileSystem +type HostFs = common.HostFs +type SandboxFs = common.SandboxFs + +func NewToolResult(forLLM string) *ToolResult { + return common.NewToolResult(forLLM) +} + +func SilentResult(forLLM string) *ToolResult { + return common.SilentResult(forLLM) +} + +func AsyncResult(forLLM string) *ToolResult { + return common.AsyncResult(forLLM) +} + +func ErrorResult(message string) *ToolResult { + return common.ErrorResult(message) +} + +func UserResult(content string) *ToolResult { + return common.UserResult(content) +} + +func ValidatePath(path, workspace string, restrict bool) (string, error) { + return common.ValidatePath(path, workspace, restrict) +} diff --git a/pkg/tools/web_fetch/web_fetch.go b/pkg/tools/web_fetch/web_fetch.go new file mode 100644 index 000000000..4f8b9215f --- /dev/null +++ b/pkg/tools/web_fetch/web_fetch.go @@ -0,0 +1,194 @@ +package web_fetch + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/tools/common" +) + +type WebFetchTool struct { + maxChars int + proxy string +} + +func NewWebFetchTool(maxChars int) *WebFetchTool { + if maxChars <= 0 { + maxChars = 50000 + } + return &WebFetchTool{ + maxChars: maxChars, + } +} + +func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool { + if maxChars <= 0 { + maxChars = 50000 + } + return &WebFetchTool{ + maxChars: maxChars, + proxy: proxy, + } +} + +func (t *WebFetchTool) Name() string { + return "web_fetch" +} + +func (t *WebFetchTool) Description() string { + return "Fetch a URL and extract readable content (HTML to text). Use this to get weather info, news, articles, or any web content." +} + +func (t *WebFetchTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "url": map[string]any{ + "type": "string", + "description": "URL to fetch", + }, + "maxChars": map[string]any{ + "type": "integer", + "description": "Maximum characters to extract", + "minimum": 100.0, + }, + }, + "required": []string{"url"}, + } +} + +func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { + urlStr, ok := args["url"].(string) + if !ok { + return common.ErrorResult("url is required") + } + + parsedURL, err := url.Parse(urlStr) + if err != nil { + return common.ErrorResult(fmt.Sprintf("invalid URL: %v", err)) + } + + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return common.ErrorResult("only http/https URLs are allowed") + } + + if parsedURL.Host == "" { + return common.ErrorResult("missing domain in URL") + } + + maxChars := t.maxChars + if mc, ok := args["maxChars"].(float64); ok { + if int(mc) > 100 { + maxChars = int(mc) + } + } + + req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) + if err != nil { + return common.ErrorResult(fmt.Sprintf("failed to create request: %v", err)) + } + + req.Header.Set("User-Agent", common.UserAgent) + + client, err := common.CreateHTTPClient(t.proxy, 60*time.Second) + if err != nil { + return common.ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err)) + } + + // Configure redirect handling + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if len(via) >= 5 { + return fmt.Errorf("stopped after 5 redirects") + } + return nil + } + + resp, err := client.Do(req) + if err != nil { + return common.ErrorResult(fmt.Sprintf("request failed: %v", err)) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return common.ErrorResult(fmt.Sprintf("failed to read response: %v", err)) + } + + contentType := resp.Header.Get("Content-Type") + + var text, extractor string + + if strings.Contains(contentType, "application/json") { + var jsonData any + if err := json.Unmarshal(body, &jsonData); err == nil { + formatted, _ := json.MarshalIndent(jsonData, "", " ") + text = string(formatted) + extractor = "json" + } else { + text = string(body) + extractor = "raw" + } + } else if strings.Contains(contentType, "text/html") || len(body) > 0 && + (strings.HasPrefix(string(body), " maxChars + if truncated { + text = text[:maxChars] + } + + result := map[string]any{ + "url": urlStr, + "status": resp.StatusCode, + "extractor": extractor, + "truncated": truncated, + "length": len(text), + "text": text, + } + + resultJSON, _ := json.MarshalIndent(result, "", " ") + + return &common.ToolResult{ + ForLLM: fmt.Sprintf( + "Fetched %d bytes from %s (extractor: %s, truncated: %v)", + len(text), + urlStr, + extractor, + truncated, + ), + ForUser: string(resultJSON), + } +} + +func (t *WebFetchTool) extractText(htmlContent string) string { + result := common.ReScript.ReplaceAllLiteralString(htmlContent, "") + result = common.ReStyle.ReplaceAllLiteralString(result, "") + result = common.ReTags.ReplaceAllLiteralString(result, "") + + result = strings.TrimSpace(result) + + result = common.ReWhitespace.ReplaceAllString(result, " ") + result = common.ReBlankLines.ReplaceAllString(result, "\n\n") + + lines := strings.Split(result, "\n") + var cleanLines []string + for _, line := range lines { + line = strings.TrimSpace(line) + if line != "" { + cleanLines = append(cleanLines, line) + } + } + + return strings.Join(cleanLines, "\n") +} diff --git a/pkg/tools/web_test.go b/pkg/tools/web_fetch/web_fetch_test.go similarity index 70% rename from pkg/tools/web_test.go rename to pkg/tools/web_fetch/web_fetch_test.go index 2cd79eb24..78c58df27 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_fetch/web_fetch_test.go @@ -1,8 +1,9 @@ -package tools +package web_fetch import ( "context" "encoding/json" + "github.com/sipeed/picoclaw/pkg/tools/common" "net/http" "net/http/httptest" "strings" @@ -174,34 +175,6 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) { } } -// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing -func TestWebTool_WebSearch_NoApiKey(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""}) - if tool != nil { - t.Errorf("Expected nil tool when Brave API key is empty") - } - - // Also nil when nothing is enabled - tool = NewWebSearchTool(WebSearchToolOptions{}) - if tool != nil { - t.Errorf("Expected nil tool when no provider is enabled") - } -} - -// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query -func TestWebTool_WebSearch_MissingQuery(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5}) - ctx := context.Background() - args := map[string]any{} - - result := tool.Execute(ctx, args) - - // Should return error result - if !result.IsError { - t.Errorf("Expected error when query is missing") - } -} - // TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -336,9 +309,9 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) { } func TestCreateHTTPClient_ProxyConfigured(t *testing.T) { - client, err := createHTTPClient("http://127.0.0.1:7890", 12*time.Second) + client, err := common.CreateHTTPClient("http://127.0.0.1:7890", 12*time.Second) if err != nil { - t.Fatalf("createHTTPClient() error: %v", err) + t.Fatalf("common.CreateHTTPClient() error: %v", err) } if client.Timeout != 12*time.Second { t.Fatalf("client.Timeout = %v, want %v", client.Timeout, 12*time.Second) @@ -366,16 +339,16 @@ func TestCreateHTTPClient_ProxyConfigured(t *testing.T) { } func TestCreateHTTPClient_InvalidProxy(t *testing.T) { - _, err := createHTTPClient("://bad-proxy", 10*time.Second) + _, err := common.CreateHTTPClient("://bad-proxy", 10*time.Second) if err == nil { - t.Fatal("createHTTPClient() expected error for invalid proxy URL, got nil") + t.Fatal("common.CreateHTTPClient() expected error for invalid proxy URL, got nil") } } func TestCreateHTTPClient_Socks5ProxyConfigured(t *testing.T) { - client, err := createHTTPClient("socks5://127.0.0.1:1080", 8*time.Second) + client, err := common.CreateHTTPClient("socks5://127.0.0.1:1080", 8*time.Second) if err != nil { - t.Fatalf("createHTTPClient() error: %v", err) + t.Fatalf("common.CreateHTTPClient() error: %v", err) } tr, ok := client.Transport.(*http.Transport) @@ -396,9 +369,9 @@ func TestCreateHTTPClient_Socks5ProxyConfigured(t *testing.T) { } func TestCreateHTTPClient_UnsupportedProxyScheme(t *testing.T) { - _, err := createHTTPClient("ftp://127.0.0.1:21", 10*time.Second) + _, err := common.CreateHTTPClient("ftp://127.0.0.1:21", 10*time.Second) if err == nil { - t.Fatal("createHTTPClient() expected error for unsupported scheme, got nil") + t.Fatal("common.CreateHTTPClient() expected error for unsupported scheme, got nil") } if !strings.Contains(err.Error(), "unsupported proxy scheme") { t.Fatalf("error = %q, want to contain %q", err.Error(), "unsupported proxy scheme") @@ -415,9 +388,9 @@ func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) { t.Setenv("NO_PROXY", "") t.Setenv("no_proxy", "") - client, err := createHTTPClient("", 10*time.Second) + client, err := common.CreateHTTPClient("", 10*time.Second) if err != nil { - t.Fatalf("createHTTPClient() error: %v", err) + t.Fatalf("common.CreateHTTPClient() error: %v", err) } tr, ok := client.Transport.(*http.Transport) @@ -451,124 +424,3 @@ func TestNewWebFetchToolWithProxy(t *testing.T) { t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000) } } - -func TestNewWebSearchTool_PropagatesProxy(t *testing.T) { - t.Run("perplexity", func(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{ - PerplexityEnabled: true, - PerplexityAPIKey: "k", - PerplexityMaxResults: 3, - Proxy: "http://127.0.0.1:7890", - }) - p, ok := tool.provider.(*PerplexitySearchProvider) - if !ok { - t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider) - } - if p.proxy != "http://127.0.0.1:7890" { - t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890") - } - }) - - t.Run("brave", func(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{ - BraveEnabled: true, - BraveAPIKey: "k", - BraveMaxResults: 3, - Proxy: "http://127.0.0.1:7890", - }) - p, ok := tool.provider.(*BraveSearchProvider) - if !ok { - t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider) - } - if p.proxy != "http://127.0.0.1:7890" { - t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890") - } - }) - - t.Run("duckduckgo", func(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{ - DuckDuckGoEnabled: true, - DuckDuckGoMaxResults: 3, - Proxy: "http://127.0.0.1:7890", - }) - p, ok := tool.provider.(*DuckDuckGoSearchProvider) - if !ok { - t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider) - } - if p.proxy != "http://127.0.0.1:7890" { - t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890") - } - }) -} - -// TestWebTool_TavilySearch_Success verifies successful Tavily search -func TestWebTool_TavilySearch_Success(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "POST" { - t.Errorf("Expected POST request, got %s", r.Method) - } - if r.Header.Get("Content-Type") != "application/json" { - t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) - } - - // Verify payload - var payload map[string]any - json.NewDecoder(r.Body).Decode(&payload) - if payload["api_key"] != "test-key" { - t.Errorf("Expected api_key test-key, got %v", payload["api_key"]) - } - if payload["query"] != "test query" { - t.Errorf("Expected query 'test query', got %v", payload["query"]) - } - - // Return mock response - response := map[string]any{ - "results": []map[string]any{ - { - "title": "Test Result 1", - "url": "https://example.com/1", - "content": "Content for result 1", - }, - { - "title": "Test Result 2", - "url": "https://example.com/2", - "content": "Content for result 2", - }, - }, - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(response) - })) - defer server.Close() - - tool := NewWebSearchTool(WebSearchToolOptions{ - TavilyEnabled: true, - TavilyAPIKey: "test-key", - TavilyBaseURL: server.URL, - TavilyMaxResults: 5, - }) - - ctx := context.Background() - args := map[string]any{ - "query": "test query", - } - - result := tool.Execute(ctx, args) - - // Success should not be an error - if result.IsError { - t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) - } - - // ForUser should contain result titles and URLs - if !strings.Contains(result.ForUser, "Test Result 1") || - !strings.Contains(result.ForUser, "https://example.com/1") { - t.Errorf("Expected results in output, got: %s", result.ForUser) - } - - // Should mention via Tavily - if !strings.Contains(result.ForUser, "via Tavily") { - t.Errorf("Expected 'via Tavily' in output, got: %s", result.ForUser) - } -} diff --git a/pkg/tools/web.go b/pkg/tools/web_search/web_search.go similarity index 63% rename from pkg/tools/web.go rename to pkg/tools/web_search/web_search.go index 8ba2a723a..ae984960b 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web_search/web_search.go @@ -1,4 +1,4 @@ -package tools +package web_search import ( "bytes" @@ -8,65 +8,12 @@ import ( "io" "net/http" "net/url" - "regexp" "strings" "time" -) - -const ( - userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" -) -// Pre-compiled regexes for HTML text extraction -var ( - reScript = regexp.MustCompile(``) - reStyle = regexp.MustCompile(``) - reTags = regexp.MustCompile(`<[^>]+>`) - reWhitespace = regexp.MustCompile(`[^\S\n]+`) - reBlankLines = regexp.MustCompile(`\n{3,}`) - - // DuckDuckGo result extraction - reDDGLink = regexp.MustCompile(`]*class="[^"]*result__a[^"]*"[^>]*href="([^"]+)"[^>]*>([\s\S]*?)`) - reDDGSnippet = regexp.MustCompile(`([\s\S]*?)`) + "github.com/sipeed/picoclaw/pkg/tools/common" ) -// createHTTPClient creates an HTTP client with optional proxy support -func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) { - client := &http.Client{ - Timeout: timeout, - Transport: &http.Transport{ - MaxIdleConns: 10, - IdleConnTimeout: 30 * time.Second, - DisableCompression: false, - TLSHandshakeTimeout: 15 * time.Second, - }, - } - - if proxyURL != "" { - proxy, err := url.Parse(proxyURL) - if err != nil { - return nil, fmt.Errorf("invalid proxy URL: %w", err) - } - scheme := strings.ToLower(proxy.Scheme) - switch scheme { - case "http", "https", "socks5", "socks5h": - default: - return nil, fmt.Errorf( - "unsupported proxy scheme %q (supported: http, https, socks5, socks5h)", - proxy.Scheme, - ) - } - if proxy.Host == "" { - return nil, fmt.Errorf("invalid proxy URL: missing host") - } - client.Transport.(*http.Transport).Proxy = http.ProxyURL(proxy) - } else { - client.Transport.(*http.Transport).Proxy = http.ProxyFromEnvironment - } - - return client, nil -} - type SearchProvider interface { Search(ctx context.Context, query string, count int) (string, error) } @@ -88,7 +35,7 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in req.Header.Set("Accept", "application/json") req.Header.Set("X-Subscription-Token", p.apiKey) - client, err := createHTTPClient(p.proxy, 10*time.Second) + client, err := common.CreateHTTPClient(p.proxy, 10*time.Second) if err != nil { return "", fmt.Errorf("failed to create HTTP client: %w", err) } @@ -172,9 +119,9 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i } req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", userAgent) + req.Header.Set("User-Agent", common.UserAgent) - client, err := createHTTPClient(p.proxy, 10*time.Second) + client, err := common.CreateHTTPClient(p.proxy, 10*time.Second) if err != nil { return "", fmt.Errorf("failed to create HTTP client: %w", err) } @@ -237,9 +184,9 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou return "", fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("User-Agent", userAgent) + req.Header.Set("User-Agent", common.UserAgent) - client, err := createHTTPClient(p.proxy, 10*time.Second) + client, err := common.CreateHTTPClient(p.proxy, 10*time.Second) if err != nil { return "", fmt.Errorf("failed to create HTTP client: %w", err) } @@ -264,7 +211,7 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query // Try finding the result links directly first, as they are the most critical // Pattern: Title // The previous regex was a bit strict. Let's make it more flexible for attributes order/content - matches := reDDGLink.FindAllStringSubmatch(html, count+5) + matches := common.ReDDGLink.FindAllStringSubmatch(html, count+5) if len(matches) == 0 { return fmt.Sprintf("No results found or extraction failed. Query: %s", query), nil @@ -281,7 +228,7 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query // A better regex approach: iterate through text and find matches in order // But for now, let's grab all snippets too - snippetMatches := reDDGSnippet.FindAllStringSubmatch(html, count+5) + snippetMatches := common.ReDDGSnippet.FindAllStringSubmatch(html, count+5) maxItems := min(len(matches), count) @@ -316,7 +263,7 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query } func stripTags(content string) string { - return reTags.ReplaceAllString(content, "") + return common.ReTags.ReplaceAllString(content, "") } type PerplexitySearchProvider struct { @@ -354,9 +301,9 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+p.apiKey) - req.Header.Set("User-Agent", userAgent) + req.Header.Set("User-Agent", common.UserAgent) - client, err := createHTTPClient(p.proxy, 30*time.Second) + client, err := common.CreateHTTPClient(p.proxy, 30*time.Second) if err != nil { return "", fmt.Errorf("failed to create HTTP client: %w", err) } @@ -481,10 +428,10 @@ func (t *WebSearchTool) Parameters() map[string]any { } } -func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult { +func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { query, ok := args["query"].(string) if !ok { - return ErrorResult("query is required") + return common.ErrorResult("query is required") } count := t.maxResults @@ -496,191 +443,11 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolR result, err := t.provider.Search(ctx, query, count) if err != nil { - return ErrorResult(fmt.Sprintf("search failed: %v", err)) + return common.ErrorResult(fmt.Sprintf("search failed: %v", err)) } - return &ToolResult{ + return &common.ToolResult{ ForLLM: result, ForUser: result, } } - -type WebFetchTool struct { - maxChars int - proxy string -} - -func NewWebFetchTool(maxChars int) *WebFetchTool { - if maxChars <= 0 { - maxChars = 50000 - } - return &WebFetchTool{ - maxChars: maxChars, - } -} - -func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool { - if maxChars <= 0 { - maxChars = 50000 - } - return &WebFetchTool{ - maxChars: maxChars, - proxy: proxy, - } -} - -func (t *WebFetchTool) Name() string { - return "web_fetch" -} - -func (t *WebFetchTool) Description() string { - return "Fetch a URL and extract readable content (HTML to text). Use this to get weather info, news, articles, or any web content." -} - -func (t *WebFetchTool) Parameters() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{ - "url": map[string]any{ - "type": "string", - "description": "URL to fetch", - }, - "maxChars": map[string]any{ - "type": "integer", - "description": "Maximum characters to extract", - "minimum": 100.0, - }, - }, - "required": []string{"url"}, - } -} - -func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolResult { - urlStr, ok := args["url"].(string) - if !ok { - return ErrorResult("url is required") - } - - parsedURL, err := url.Parse(urlStr) - if err != nil { - return ErrorResult(fmt.Sprintf("invalid URL: %v", err)) - } - - if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { - return ErrorResult("only http/https URLs are allowed") - } - - if parsedURL.Host == "" { - return ErrorResult("missing domain in URL") - } - - maxChars := t.maxChars - if mc, ok := args["maxChars"].(float64); ok { - if int(mc) > 100 { - maxChars = int(mc) - } - } - - req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) - if err != nil { - return ErrorResult(fmt.Sprintf("failed to create request: %v", err)) - } - - req.Header.Set("User-Agent", userAgent) - - client, err := createHTTPClient(t.proxy, 60*time.Second) - if err != nil { - return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err)) - } - - // Configure redirect handling - client.CheckRedirect = func(req *http.Request, via []*http.Request) error { - if len(via) >= 5 { - return fmt.Errorf("stopped after 5 redirects") - } - return nil - } - - resp, err := client.Do(req) - if err != nil { - return ErrorResult(fmt.Sprintf("request failed: %v", err)) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return ErrorResult(fmt.Sprintf("failed to read response: %v", err)) - } - - contentType := resp.Header.Get("Content-Type") - - var text, extractor string - - if strings.Contains(contentType, "application/json") { - var jsonData any - if err := json.Unmarshal(body, &jsonData); err == nil { - formatted, _ := json.MarshalIndent(jsonData, "", " ") - text = string(formatted) - extractor = "json" - } else { - text = string(body) - extractor = "raw" - } - } else if strings.Contains(contentType, "text/html") || len(body) > 0 && - (strings.HasPrefix(string(body), " maxChars - if truncated { - text = text[:maxChars] - } - - result := map[string]any{ - "url": urlStr, - "status": resp.StatusCode, - "extractor": extractor, - "truncated": truncated, - "length": len(text), - "text": text, - } - - resultJSON, _ := json.MarshalIndent(result, "", " ") - - return &ToolResult{ - ForLLM: fmt.Sprintf( - "Fetched %d bytes from %s (extractor: %s, truncated: %v)", - len(text), - urlStr, - extractor, - truncated, - ), - ForUser: string(resultJSON), - } -} - -func (t *WebFetchTool) extractText(htmlContent string) string { - result := reScript.ReplaceAllLiteralString(htmlContent, "") - result = reStyle.ReplaceAllLiteralString(result, "") - result = reTags.ReplaceAllLiteralString(result, "") - - result = strings.TrimSpace(result) - - result = reWhitespace.ReplaceAllString(result, " ") - result = reBlankLines.ReplaceAllString(result, "\n\n") - - lines := strings.Split(result, "\n") - var cleanLines []string - for _, line := range lines { - line = strings.TrimSpace(line) - if line != "" { - cleanLines = append(cleanLines, line) - } - } - - return strings.Join(cleanLines, "\n") -} diff --git a/pkg/tools/web_search/web_search_test.go b/pkg/tools/web_search/web_search_test.go new file mode 100644 index 000000000..b190d66bf --- /dev/null +++ b/pkg/tools/web_search/web_search_test.go @@ -0,0 +1,159 @@ +package web_search + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing +func TestWebTool_WebSearch_NoApiKey(t *testing.T) { + tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""}) + if tool != nil { + t.Errorf("Expected nil tool when Brave API key is empty") + } + + // Also nil when nothing is enabled + tool = NewWebSearchTool(WebSearchToolOptions{}) + if tool != nil { + t.Errorf("Expected nil tool when no provider is enabled") + } +} + +// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query +func TestWebTool_WebSearch_MissingQuery(t *testing.T) { + tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5}) + ctx := context.Background() + args := map[string]any{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when query is missing") + } +} + +func TestNewWebSearchTool_PropagatesProxy(t *testing.T) { + t.Run("perplexity", func(t *testing.T) { + tool := NewWebSearchTool(WebSearchToolOptions{ + PerplexityEnabled: true, + PerplexityAPIKey: "k", + PerplexityMaxResults: 3, + Proxy: "http://127.0.0.1:7890", + }) + p, ok := tool.provider.(*PerplexitySearchProvider) + if !ok { + t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider) + } + if p.proxy != "http://127.0.0.1:7890" { + t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890") + } + }) + + t.Run("brave", func(t *testing.T) { + tool := NewWebSearchTool(WebSearchToolOptions{ + BraveEnabled: true, + BraveAPIKey: "k", + BraveMaxResults: 3, + Proxy: "http://127.0.0.1:7890", + }) + p, ok := tool.provider.(*BraveSearchProvider) + if !ok { + t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider) + } + if p.proxy != "http://127.0.0.1:7890" { + t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890") + } + }) + + t.Run("duckduckgo", func(t *testing.T) { + tool := NewWebSearchTool(WebSearchToolOptions{ + DuckDuckGoEnabled: true, + DuckDuckGoMaxResults: 3, + Proxy: "http://127.0.0.1:7890", + }) + p, ok := tool.provider.(*DuckDuckGoSearchProvider) + if !ok { + t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider) + } + if p.proxy != "http://127.0.0.1:7890" { + t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890") + } + }) +} + +// TestWebTool_TavilySearch_Success verifies successful Tavily search +func TestWebTool_TavilySearch_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST request, got %s", r.Method) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) + } + + // Verify payload + var payload map[string]any + json.NewDecoder(r.Body).Decode(&payload) + if payload["api_key"] != "test-key" { + t.Errorf("Expected api_key test-key, got %v", payload["api_key"]) + } + if payload["query"] != "test query" { + t.Errorf("Expected query 'test query', got %v", payload["query"]) + } + + // Return mock response + response := map[string]any{ + "results": []map[string]any{ + { + "title": "Test Result 1", + "url": "https://example.com/1", + "content": "Content for result 1", + }, + { + "title": "Test Result 2", + "url": "https://example.com/2", + "content": "Content for result 2", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + tool := NewWebSearchTool(WebSearchToolOptions{ + TavilyEnabled: true, + TavilyAPIKey: "test-key", + TavilyBaseURL: server.URL, + TavilyMaxResults: 5, + }) + + ctx := context.Background() + args := map[string]any{ + "query": "test query", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain result titles and URLs + if !strings.Contains(result.ForUser, "Test Result 1") || + !strings.Contains(result.ForUser, "https://example.com/1") { + t.Errorf("Expected results in output, got: %s", result.ForUser) + } + + // Should mention via Tavily + if !strings.Contains(result.ForUser, "via Tavily") { + t.Errorf("Expected 'via Tavily' in output, got: %s", result.ForUser) + } +} diff --git a/pkg/tools/write_file/write_file.go b/pkg/tools/write_file/write_file.go new file mode 100644 index 000000000..4d14f3931 --- /dev/null +++ b/pkg/tools/write_file/write_file.go @@ -0,0 +1,65 @@ +package write_file + +import ( + "context" + "fmt" + + "github.com/sipeed/picoclaw/pkg/tools/common" +) + +type WriteFileTool struct { + fs common.FileSystem +} + +func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool { + var fs common.FileSystem + if restrict { + fs = &common.SandboxFs{Workspace: workspace} + } else { + fs = &common.HostFs{} + } + return &WriteFileTool{fs: fs} +} + +func (t *WriteFileTool) Name() string { + return "write_file" +} + +func (t *WriteFileTool) Description() string { + return "Write content to a file" +} + +func (t *WriteFileTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Path to the file to write", + }, + "content": map[string]any{ + "type": "string", + "description": "Content to write to the file", + }, + }, + "required": []string{"path", "content"}, + } +} + +func (t *WriteFileTool) Execute(ctx context.Context, args map[string]any) *common.ToolResult { + path, ok := args["path"].(string) + if !ok { + return common.ErrorResult("path is required") + } + + content, ok := args["content"].(string) + if !ok { + return common.ErrorResult("content is required") + } + + if err := t.fs.WriteFile(path, []byte(content)); err != nil { + return common.ErrorResult(err.Error()) + } + + return common.SilentResult(fmt.Sprintf("File written: %s", path)) +} diff --git a/pkg/tools/write_file/write_file_test.go b/pkg/tools/write_file/write_file_test.go new file mode 100644 index 000000000..ae1457faa --- /dev/null +++ b/pkg/tools/write_file/write_file_test.go @@ -0,0 +1,491 @@ +package write_file + +import ( + "context" + "io" + "github.com/sipeed/picoclaw/pkg/tools/list_dir" + "github.com/sipeed/picoclaw/pkg/tools/read_file" + "github.com/sipeed/picoclaw/pkg/tools/common" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestFilesystemTool_ReadFile_Success verifies successful file reading +func TestFilesystemTool_ReadFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("test content"), 0o644) + + tool := read_file.NewReadFileTool("", false) + ctx := context.Background() + args := map[string]any{ + "path": testFile, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForLLM should contain file content + if !strings.Contains(result.ForLLM, "test content") { + t.Errorf("Expected ForLLM to contain 'test content', got: %s", result.ForLLM) + } + + // ReadFile returns NewToolResult which only sets ForLLM, not ForUser + // This is the expected behavior - file content goes to LLM, not directly to user + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty for NewToolResult, got: %s", result.ForUser) + } +} + +// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file +func TestFilesystemTool_ReadFile_NotFound(t *testing.T) { + tool := read_file.NewReadFileTool("", false) + ctx := context.Background() + args := map[string]any{ + "path": "/nonexistent_file_12345.txt", + } + + result := tool.Execute(ctx, args) + + // Failure should be marked as error + if !result.IsError { + t.Errorf("Expected error for missing file, got IsError=false") + } + + // Should contain error message + if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") { + t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestFilesystemTool_ReadFile_MissingPath verifies error handling for missing path +func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) { + tool := &read_file.ReadFileTool{} + ctx := context.Background() + args := map[string]any{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } + + // Should mention required parameter + if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") { + t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestFilesystemTool_WriteFile_Success verifies successful file writing +func TestFilesystemTool_WriteFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "newfile.txt") + + tool := NewWriteFileTool("", false) + ctx := context.Background() + args := map[string]any{ + "path": testFile, + "content": "hello world", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // WriteFile returns SilentResult + if !result.Silent { + t.Errorf("Expected Silent=true for WriteFile, got false") + } + + // ForUser should be empty (silent result) + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser) + } + + // Verify file was actually written + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read written file: %v", err) + } + if string(content) != "hello world" { + t.Errorf("Expected file content 'hello world', got: %s", string(content)) + } +} + +// TestFilesystemTool_WriteFile_CreateDir verifies directory creation +func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "subdir", "newfile.txt") + + tool := NewWriteFileTool("", false) + ctx := context.Background() + args := map[string]any{ + "path": testFile, + "content": "test", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success with directory creation, got IsError=true: %s", result.ForLLM) + } + + // Verify directory was created and file written + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read written file: %v", err) + } + if string(content) != "test" { + t.Errorf("Expected file content 'test', got: %s", string(content)) + } +} + +// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path +func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) { + tool := NewWriteFileTool("", false) + ctx := context.Background() + args := map[string]any{ + "content": "test", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } +} + +// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content +func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) { + tool := NewWriteFileTool("", false) + ctx := context.Background() + args := map[string]any{ + "path": "/tmp/test.txt", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when content is missing") + } + + // Should mention required parameter + if !strings.Contains(result.ForLLM, "content is required") && + !strings.Contains(result.ForUser, "content is required") { + t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestFilesystemTool_ListDir_Success verifies successful directory listing +func TestFilesystemTool_ListDir_Success(t *testing.T) { + tmpDir := t.TempDir() + os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0o644) + os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0o644) + os.Mkdir(filepath.Join(tmpDir, "subdir"), 0o755) + + tool := list_dir.NewListDirTool("", false) + ctx := context.Background() + args := map[string]any{ + "path": tmpDir, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // Should list files and directories + if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") { + t.Errorf("Expected files in listing, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "subdir") { + t.Errorf("Expected subdir in listing, got: %s", result.ForLLM) + } +} + +// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory +func TestFilesystemTool_ListDir_NotFound(t *testing.T) { + tool := list_dir.NewListDirTool("", false) + ctx := context.Background() + args := map[string]any{ + "path": "/nonexistent_directory_12345", + } + + result := tool.Execute(ctx, args) + + // Failure should be marked as error + if !result.IsError { + t.Errorf("Expected error for non-existent directory, got IsError=false") + } + + // Should contain error message + if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") { + t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory +func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) { + tool := list_dir.NewListDirTool("", false) + ctx := context.Background() + args := map[string]any{} + + result := tool.Execute(ctx, args) + + // Should use "." as default path + if result.IsError { + t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM) + } +} + +// Block paths that look inside workspace but point outside via symlink. +func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) { + root := t.TempDir() + workspace := filepath.Join(root, "workspace") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatalf("failed to create workspace: %v", err) + } + + secret := filepath.Join(root, "secret.txt") + if err := os.WriteFile(secret, []byte("top secret"), 0o644); err != nil { + t.Fatalf("failed to write secret file: %v", err) + } + + link := filepath.Join(workspace, "leak.txt") + if err := os.Symlink(secret, link); err != nil { + t.Skipf("symlink not supported in this environment: %v", err) + } + + tool := read_file.NewReadFileTool(workspace, true) + result := tool.Execute(context.Background(), map[string]any{ + "path": link, + }) + + if !result.IsError { + t.Fatalf("expected symlink escape to be blocked") + } + // os.Root might return different errors depending on platform/implementation + // but it definitely should error. + // Our wrapper returns "access denied or file not found" + if !strings.Contains(result.ForLLM, "access denied") && !strings.Contains(result.ForLLM, "file not found") && + !strings.Contains(result.ForLLM, "no such file") { + t.Fatalf("expected symlink escape error, got: %s", result.ForLLM) + } +} + +func TestFilesystemTool_EmptyWorkspace_AccessDenied(t *testing.T) { + tool := read_file.NewReadFileTool("", true) // restrict=true but workspace="" + + // Try to read a sensitive file (simulated by a temp file outside workspace) + tmpDir := t.TempDir() + secretFile := filepath.Join(tmpDir, "shadow") + os.WriteFile(secretFile, []byte("secret data"), 0o600) + + result := tool.Execute(context.Background(), map[string]any{ + "path": secretFile, + }) + + // We EXPECT IsError=true (access blocked due to empty workspace) + assert.True(t, result.IsError, "Security Regression: Empty workspace allowed access! content: %s", result.ForLLM) + + // Verify it failed for the right reason + assert.Contains(t, result.ForLLM, "workspace is not defined", "Expected 'workspace is not defined' error") +} + +// TestRootMkdirAll verifies that root.MkdirAll (used by atomicWriteFileInRoot) handles all cases: +// single dir, deeply nested dirs, already-existing dirs, and a file blocking a directory path. +func TestRootMkdirAll(t *testing.T) { + workspace := t.TempDir() + root, err := os.OpenRoot(workspace) + if err != nil { + t.Fatalf("failed to open root: %v", err) + } + defer root.Close() + + // Case 1: Single directory + err = root.MkdirAll("dir1", 0o755) + assert.NoError(t, err) + _, err = os.Stat(filepath.Join(workspace, "dir1")) + assert.NoError(t, err) + + // Case 2: Deeply nested directory + err = root.MkdirAll("a/b/c/d", 0o755) + assert.NoError(t, err) + _, err = os.Stat(filepath.Join(workspace, "a/b/c/d")) + assert.NoError(t, err) + + // Case 3: Already exists — must be idempotent + err = root.MkdirAll("a/b/c/d", 0o755) + assert.NoError(t, err) + + // Case 4: A regular file blocks directory creation — must error + err = os.WriteFile(filepath.Join(workspace, "file_exists"), []byte("data"), 0o644) + assert.NoError(t, err) + err = root.MkdirAll("file_exists", 0o755) + assert.Error(t, err, "expected error when a file exists at the directory path") +} + +func TestFilesystemTool_WriteFile_Restricted_CreateDir(t *testing.T) { + workspace := t.TempDir() + tool := NewWriteFileTool(workspace, true) + ctx := context.Background() + + testFile := "deep/nested/path/to/file.txt" + content := "deep content" + args := map[string]any{ + "path": testFile, + "content": content, + } + + result := tool.Execute(ctx, args) + assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM) + + // Verify file content + actualPath := filepath.Join(workspace, testFile) + data, err := os.ReadFile(actualPath) + assert.NoError(t, err) + assert.Equal(t, content, string(data)) +} + +// TestHostRW_Read_PermissionDenied verifies that hostRW.Read surfaces access denied errors. +func TestHostRW_Read_PermissionDenied(t *testing.T) { + if os.Getuid() == 0 { + t.Skip("skipping permission test: running as root") + } + tmpDir := t.TempDir() + protected := filepath.Join(tmpDir, "protected.txt") + err := os.WriteFile(protected, []byte("secret"), 0o000) + assert.NoError(t, err) + defer os.Chmod(protected, 0o644) // ensure cleanup + + _, err = (&common.HostFs{}).ReadFile(protected) + assert.Error(t, err) + assert.Contains(t, err.Error(), "access denied") +} + +// TestHostRW_Read_Directory verifies that hostRW.Read returns an error when given a directory path. +func TestHostRW_Read_Directory(t *testing.T) { + tmpDir := t.TempDir() + + _, err := (&common.HostFs{}).ReadFile(tmpDir) + assert.Error(t, err, "expected error when reading a directory as a file") +} + +// TestRootRW_Read_Directory verifies that rootRW.Read returns an error when given a directory. +func TestRootRW_Read_Directory(t *testing.T) { + workspace := t.TempDir() + root, err := os.OpenRoot(workspace) + assert.NoError(t, err) + defer root.Close() + + // Create a subdirectory + err = root.Mkdir("subdir", 0o755) + assert.NoError(t, err) + + _, err = (&common.SandboxFs{Workspace: workspace}).ReadFile("subdir") + assert.Error(t, err, "expected error when reading a directory as a file") +} + +// TestHostRW_Write_ParentDirMissing verifies that hostRW.Write creates parent dirs automatically. +func TestHostRW_Write_ParentDirMissing(t *testing.T) { + tmpDir := t.TempDir() + target := filepath.Join(tmpDir, "a", "b", "c", "file.txt") + + err := (&common.HostFs{}).WriteFile(target, []byte("hello")) + assert.NoError(t, err) + + data, err := os.ReadFile(target) + assert.NoError(t, err) + assert.Equal(t, "hello", string(data)) +} + +// TestRootRW_Write_ParentDirMissing verifies that rootRW.Write creates +// nested parent directories automatically within the sandbox. +func TestRootRW_Write_ParentDirMissing(t *testing.T) { + workspace := t.TempDir() + + relPath := "x/y/z/file.txt" + err := (&common.SandboxFs{Workspace: workspace}).WriteFile(relPath, []byte("nested")) + assert.NoError(t, err) + + data, err := os.ReadFile(filepath.Join(workspace, relPath)) + assert.NoError(t, err) + assert.Equal(t, "nested", string(data)) +} + +// TestHostRW_Write verifies the hostRW.Write helper function +func TestHostRW_Write(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "atomic_test.txt") + testData := []byte("atomic test content") + + err := (&common.HostFs{}).WriteFile(testFile, testData) + assert.NoError(t, err) + + content, err := os.ReadFile(testFile) + assert.NoError(t, err) + assert.Equal(t, testData, content) + + // Verify it overwrites correctly + newData := []byte("new atomic content") + err = (&common.HostFs{}).WriteFile(testFile, newData) + assert.NoError(t, err) + + content, err = os.ReadFile(testFile) + assert.NoError(t, err) + assert.Equal(t, newData, content) +} + +// TestRootRW_Write verifies the rootRW.Write helper function +func TestRootRW_Write(t *testing.T) { + tmpDir := t.TempDir() + + relPath := "atomic_root_test.txt" + testData := []byte("atomic root test content") + + erw := &common.SandboxFs{Workspace: tmpDir} + err := erw.WriteFile(relPath, testData) + assert.NoError(t, err) + + root, err := os.OpenRoot(tmpDir) + assert.NoError(t, err) + defer root.Close() + + f, err := root.Open(relPath) + assert.NoError(t, err) + defer f.Close() + + content, err := io.ReadAll(f) + assert.NoError(t, err) + assert.Equal(t, testData, content) + + // Verify it overwrites correctly + newData := []byte("new root atomic content") + err = erw.WriteFile(relPath, newData) + assert.NoError(t, err) + + f2, err := root.Open(relPath) + assert.NoError(t, err) + defer f2.Close() + + content, err = io.ReadAll(f2) + assert.NoError(t, err) + assert.Equal(t, newData, content) +} From cadf39c9ef6cadfe43e3ec6762858b0998ba40b2 Mon Sep 17 00:00:00 2001 From: lxowalle Date: Fri, 27 Feb 2026 03:31:09 +0800 Subject: [PATCH 2/2] * Fix fmt --- pkg/config/config.go | 16 ++++++++-------- pkg/tools/cron/cron.go | 4 +++- pkg/tools/i2c/i2c_linux.go | 4 +++- pkg/tools/spi/spi_linux.go | 4 +++- pkg/tools/types_export.go | 18 ++++++++++-------- pkg/tools/web_fetch/web_fetch_test.go | 3 ++- pkg/tools/write_file/write_file_test.go | 7 ++++--- 7 files changed, 33 insertions(+), 23 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index 79a4dcae2..9306e6978 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -452,7 +452,7 @@ type PerplexityConfig struct { } type WebToolsConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_ENABLED"` + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_ENABLED"` Brave BraveConfig `json:"brave"` Tavily TavilyConfig `json:"tavily"` DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"` @@ -463,7 +463,7 @@ type WebToolsConfig struct { } type CronToolConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_CRON_ENABLED"` + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_CRON_ENABLED"` ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout } @@ -472,7 +472,7 @@ type ToolConfig struct { } type ExecConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_EXEC_ENABLED"` + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_EXEC_ENABLED"` EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"` CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"` } @@ -485,17 +485,17 @@ type ToolsConfig struct { Cron CronToolConfig `json:"cron"` // File tools - ReadFile ToolConfig `json:"read_file" env:"PICOCLAW_TOOLS_READ_FILE_ENABLED"` - WriteFile ToolConfig `json:"write_file" env:"PICOCLAW_TOOLS_WRITE_FILE_ENABLED"` - EditFile ToolConfig `json:"edit_file" env:"PICOCLAW_TOOLS_EDIT_FILE_ENABLED"` + ReadFile ToolConfig `json:"read_file" env:"PICOCLAW_TOOLS_READ_FILE_ENABLED"` + WriteFile ToolConfig `json:"write_file" env:"PICOCLAW_TOOLS_WRITE_FILE_ENABLED"` + EditFile ToolConfig `json:"edit_file" env:"PICOCLAW_TOOLS_EDIT_FILE_ENABLED"` AppendFile ToolConfig `json:"append_file" env:"PICOCLAW_TOOLS_APPEND_FILE_ENABLED"` - ListDir ToolConfig `json:"list_dir" env:"PICOCLAW_TOOLS_LIST_DIR_ENABLED"` + ListDir ToolConfig `json:"list_dir" env:"PICOCLAW_TOOLS_LIST_DIR_ENABLED"` // Exec tool Exec ExecConfig `json:"exec"` // Skills tools - FindSkills ToolConfig `json:"find_skills" env:"PICOCLAW_TOOLS_FIND_SKILLS_ENABLED"` + FindSkills ToolConfig `json:"find_skills" env:"PICOCLAW_TOOLS_FIND_SKILLS_ENABLED"` InstallSkill ToolConfig `json:"install_skill" env:"PICOCLAW_TOOLS_INSTALL_SKILL_ENABLED"` // Subagent tools diff --git a/pkg/tools/cron/cron.go b/pkg/tools/cron/cron.go index 6abf79f85..dd2ea5a97 100644 --- a/pkg/tools/cron/cron.go +++ b/pkg/tools/cron/cron.go @@ -137,7 +137,9 @@ func (t *CronTool) addJob(args map[string]any) *common.ToolResult { t.mu.RUnlock() if channel == "" || chatID == "" { - return common.ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.") + return common.ErrorResult( + "no session context (channel/chat_id not set). Use this tool in an active conversation.", + ) } message, ok := args["message"].(string) diff --git a/pkg/tools/i2c/i2c_linux.go b/pkg/tools/i2c/i2c_linux.go index 2d4f63250..7d2661388 100644 --- a/pkg/tools/i2c/i2c_linux.go +++ b/pkg/tools/i2c/i2c_linux.go @@ -85,7 +85,9 @@ func (t *I2CTool) scan(args map[string]any) *common.ToolResult { devPath := fmt.Sprintf("/dev/i2c-%s", bus) fd, err := syscall.Open(devPath, syscall.O_RDWR, 0) if err != nil { - return common.ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and i2c-dev module)", devPath, err)) + return common.ErrorResult( + fmt.Sprintf("failed to open %s: %v (check permissions and i2c-dev module)", devPath, err), + ) } defer syscall.Close(fd) diff --git a/pkg/tools/spi/spi_linux.go b/pkg/tools/spi/spi_linux.go index b1c25c600..91e5ab528 100644 --- a/pkg/tools/spi/spi_linux.go +++ b/pkg/tools/spi/spi_linux.go @@ -40,7 +40,9 @@ type spiTransfer struct { func configureSPI(devPath string, mode uint8, bits uint8, speed uint32) (int, *common.ToolResult) { fd, err := syscall.Open(devPath, syscall.O_RDWR, 0) if err != nil { - return -1, common.ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and spidev module)", devPath, err)) + return -1, common.ErrorResult( + fmt.Sprintf("failed to open %s: %v (check permissions and spidev module)", devPath, err), + ) } // Set SPI mode diff --git a/pkg/tools/types_export.go b/pkg/tools/types_export.go index 1ca1a04a5..a3eefca4f 100644 --- a/pkg/tools/types_export.go +++ b/pkg/tools/types_export.go @@ -4,14 +4,16 @@ import ( "github.com/sipeed/picoclaw/pkg/tools/common" ) -type Tool = common.Tool -type ToolResult = common.ToolResult -type ContextualTool = common.ContextualTool -type AsyncTool = common.AsyncTool -type AsyncCallback = common.AsyncCallback -type FileSystem = common.FileSystem -type HostFs = common.HostFs -type SandboxFs = common.SandboxFs +type ( + Tool = common.Tool + ToolResult = common.ToolResult + ContextualTool = common.ContextualTool + AsyncTool = common.AsyncTool + AsyncCallback = common.AsyncCallback + FileSystem = common.FileSystem + HostFs = common.HostFs + SandboxFs = common.SandboxFs +) func NewToolResult(forLLM string) *ToolResult { return common.NewToolResult(forLLM) diff --git a/pkg/tools/web_fetch/web_fetch_test.go b/pkg/tools/web_fetch/web_fetch_test.go index 78c58df27..80a6689b6 100644 --- a/pkg/tools/web_fetch/web_fetch_test.go +++ b/pkg/tools/web_fetch/web_fetch_test.go @@ -3,12 +3,13 @@ package web_fetch import ( "context" "encoding/json" - "github.com/sipeed/picoclaw/pkg/tools/common" "net/http" "net/http/httptest" "strings" "testing" "time" + + "github.com/sipeed/picoclaw/pkg/tools/common" ) // TestWebTool_WebFetch_Success verifies successful URL fetching diff --git a/pkg/tools/write_file/write_file_test.go b/pkg/tools/write_file/write_file_test.go index ae1457faa..8c5b20a70 100644 --- a/pkg/tools/write_file/write_file_test.go +++ b/pkg/tools/write_file/write_file_test.go @@ -3,15 +3,16 @@ package write_file import ( "context" "io" - "github.com/sipeed/picoclaw/pkg/tools/list_dir" - "github.com/sipeed/picoclaw/pkg/tools/read_file" - "github.com/sipeed/picoclaw/pkg/tools/common" "os" "path/filepath" "strings" "testing" "github.com/stretchr/testify/assert" + + "github.com/sipeed/picoclaw/pkg/tools/common" + "github.com/sipeed/picoclaw/pkg/tools/list_dir" + "github.com/sipeed/picoclaw/pkg/tools/read_file" ) // TestFilesystemTool_ReadFile_Success verifies successful file reading