diff --git a/.gitignore b/.gitignore index 3ff195fbf..c045ba265 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ -# Binaries -# Go build artifacts +# Binaries & build artifacts bin/ build/ +dist/ *.exe *.dll *.so @@ -12,35 +12,32 @@ build/ /picoclaw-test cmd/**/workspace -# Picoclaw specific - -# PicoClaw +# PicoClaw workspace & config .picoclaw/ config.json sessions/ -build/ - -# Coverage +cmd/picoclaw/workspace -# Secrets & Config (keep templates, ignore actual secrets) +# Secrets .env config/config.json -# Test +# Coverage coverage.txt coverage.html # OS .DS_Store -# Ralph workspace -ralph/ -.ralph/ -tasks/ - -# Editors +# Editors & tools .vscode/ .idea/ +.claude/ -# Added by goreleaser init: -dist/ +# Task tracking +TASKS.md + +# Legacy +ralph/ +.ralph/ +tasks/ diff --git a/pkg/agent/cmd_helpers_test.go b/pkg/agent/cmd_helpers_test.go new file mode 100644 index 000000000..f4ca967e6 --- /dev/null +++ b/pkg/agent/cmd_helpers_test.go @@ -0,0 +1,104 @@ +package agent + +import ( + "os" + "path/filepath" + "testing" +) + +func TestResolveEditPath_AbsoluteBlocked(t *testing.T) { + workspace := t.TempDir() + _, err := resolveEditPath("/etc/passwd", workspace, workspace) + if err == nil { + t.Fatal("Expected absolute path outside workspace to be blocked") + } +} + +func TestResolveEditPath_TildeIsWorkspace(t *testing.T) { + workspace := t.TempDir() + // Create a file so the path resolves + os.WriteFile(filepath.Join(workspace, "test.txt"), []byte("hi"), 0o644) + + path, err := resolveEditPath("~/test.txt", workspace, workspace) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + expected := filepath.Join(workspace, "test.txt") + if path != expected { + t.Errorf("Expected %s, got %s", expected, path) + } +} + +func TestResolveEditPath_BareTildeIsWorkspace(t *testing.T) { + workspace := t.TempDir() + path, err := resolveEditPath("~", workspace, workspace) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if path != workspace { + 
t.Errorf("Expected %s, got %s", workspace, path) + } +} + +func TestResolveEditPath_TraversalBlocked(t *testing.T) { + workspace := t.TempDir() + _, err := resolveEditPath("../../etc/passwd", workspace, workspace) + if err == nil { + t.Fatal("Expected path traversal to be blocked") + } +} + +func TestResolveEditPath_SymlinkBlocked(t *testing.T) { + root := t.TempDir() + workspace := filepath.Join(root, "workspace") + os.MkdirAll(workspace, 0o755) + secret := filepath.Join(root, "secret.txt") + os.WriteFile(secret, []byte("secret"), 0o644) + link := filepath.Join(workspace, "link.txt") + if err := os.Symlink(secret, link); err != nil { + t.Skip("symlinks not supported") + } + _, err := resolveEditPath("link.txt", workspace, workspace) + if err == nil { + t.Fatal("Expected symlink escape to be blocked") + } +} + +func TestResolveEditPath_ValidRelative(t *testing.T) { + workspace := t.TempDir() + subdir := filepath.Join(workspace, "subdir") + os.MkdirAll(subdir, 0o755) + testFile := filepath.Join(subdir, "test.txt") + os.WriteFile(testFile, []byte("content"), 0o644) + + path, err := resolveEditPath("subdir/test.txt", workspace, workspace) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if path != testFile { + t.Errorf("Expected %s, got %s", testFile, path) + } +} + +func TestShortenHomePath(t *testing.T) { + home, err := os.UserHomeDir() + if err != nil { + t.Skip("cannot get home dir") + } + + tests := []struct { + input string + expected string + }{ + {home, "~"}, + {filepath.Join(home, "projects"), "~/projects"}, + {"/tmp/other", "/tmp/other"}, + } + + for _, tt := range tests { + result := shortenHomePath(tt.input) + if result != tt.expected { + t.Errorf("shortenHomePath(%q) = %q, want %q", tt.input, result, tt.expected) + } + } +} diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index a6fd365c7..71c1ea0be 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -4,6 +4,7 @@ import ( "os" "path/filepath" "strings" + "sync/atomic" 
"github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" @@ -31,6 +32,21 @@ type AgentInstance struct { Subagents *config.SubagentsConfig SkillsFilter []string Candidates []providers.FallbackCandidate + + // Accumulated token usage counters (atomic for concurrent safety) + TotalPromptTokens atomic.Int64 + TotalCompletionTokens atomic.Int64 + TotalRequests atomic.Int64 +} + +// AddUsage accumulates token usage from a single LLM response. +func (a *AgentInstance) AddUsage(usage *providers.UsageInfo) { + if usage == nil { + return + } + a.TotalPromptTokens.Add(int64(usage.PromptTokens)) + a.TotalCompletionTokens.Add(int64(usage.CompletionTokens)) + a.TotalRequests.Add(1) } // NewAgentInstance creates an agent instance from config. diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 693f2227b..d525dfebc 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -10,6 +10,9 @@ import ( "context" "encoding/json" "fmt" + "os" + "path/filepath" + "strconv" "strings" "sync" "sync/atomic" @@ -29,15 +32,25 @@ import ( "github.com/sipeed/picoclaw/pkg/utils" ) +// Session mode constants +type sessionMode int + +const ( + modePico sessionMode = iota // Default: messages → LLM + modeCmd // Command mode: messages → shell +) + type AgentLoop struct { - bus *bus.MessageBus - cfg *config.Config - registry *AgentRegistry - state *state.Manager - running atomic.Bool - summarizing sync.Map - fallback *providers.FallbackChain - channelManager *channels.Manager + bus *bus.MessageBus + cfg *config.Config + registry *AgentRegistry + state *state.Manager + running atomic.Bool + summarizing sync.Map + fallback *providers.FallbackChain + channelManager *channels.Manager + sessionModes sync.Map // per-session mode: sessionKey -> sessionMode + sessionWorkDirs sync.Map // per-session working dir: sessionKey -> string } // processOptions configures how a message is processed @@ -50,6 +63,7 @@ type processOptions struct { EnableSummary bool // Whether to trigger 
summarization SendResponse bool // Whether to send response via bus NoHistory bool // If true, don't load session history (for heartbeat) + WorkingDir string // Current working directory override (for hipico from cmd mode) } func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { @@ -79,6 +93,28 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers } } +func (al *AgentLoop) getSessionMode(sessionKey string) sessionMode { + if v, ok := al.sessionModes.Load(sessionKey); ok { + return v.(sessionMode) + } + return modePico +} + +func (al *AgentLoop) setSessionMode(sessionKey string, mode sessionMode) { + al.sessionModes.Store(sessionKey, mode) +} + +func (al *AgentLoop) getSessionWorkDir(sessionKey string) string { + if v, ok := al.sessionWorkDirs.Load(sessionKey); ok { + return v.(string) + } + return "" +} + +func (al *AgentLoop) setSessionWorkDir(sessionKey string, dir string) { + al.sessionWorkDirs.Store(sessionKey, dir) +} + // registerSharedTools registers tools that are shared across all agents (web, message, spawn). func registerSharedTools( cfg *config.Config, @@ -236,6 +272,23 @@ func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey stri return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct") } +// ProcessDirectWithWorkDir processes a message with an explicit working directory context. +// The workDir is injected into the system prompt so the AI resolves file paths relative to it. 
+func (al *AgentLoop) ProcessDirectWithWorkDir( + ctx context.Context, + content, sessionKey, workDir string, +) (string, error) { + msg := bus.InboundMessage{ + Channel: "cli", + SenderID: "cron", + ChatID: "direct", + Content: content, + SessionKey: sessionKey, + Metadata: map[string]string{"work_dir": workDir}, + } + return al.processMessage(ctx, msg) +} + func (al *AgentLoop) ProcessDirectWithChannel( ctx context.Context, content, sessionKey, channel, chatID string, @@ -321,15 +374,50 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) "matched_by": route.MatchedBy, }) - return al.runAgentLoop(ctx, agent, processOptions{ - SessionKey: sessionKey, - Channel: msg.Channel, - ChatID: msg.ChatID, - UserMessage: msg.Content, - DefaultResponse: "I've completed processing but have no response to give.", - EnableSummary: true, - SendResponse: false, - }) + // Handle mode-switching commands (:cmd, :pico, :hipico) + content := strings.TrimSpace(msg.Content) + if strings.HasPrefix(content, ":") { + if response, handled := al.handleModeCommand(content, sessionKey, agent); handled { + return response, nil + } + // :hipico falls through here — one-shot LLM call, stays in modeCmd + if strings.HasPrefix(content, ":hipico") { + userMessage := strings.TrimSpace(strings.TrimPrefix(content, ":hipico")) + workDir := al.getSessionWorkDir(sessionKey) + if workDir == "" { + workDir = agent.Workspace + } + hipicoSessionKey := sessionKey + ":hipico" + return al.runAgentLoop(ctx, agent, processOptions{ + SessionKey: hipicoSessionKey, + Channel: msg.Channel, + ChatID: msg.ChatID, + UserMessage: userMessage, + DefaultResponse: "I've completed processing but have no response to give.", + EnableSummary: false, + SendResponse: false, + WorkingDir: workDir, + }) + } + } + + // Dispatch based on current session mode + switch al.getSessionMode(sessionKey) { + case modeCmd: + return al.executeCmdMode(ctx, agent, content, sessionKey, msg.Channel, msg.ChatID) + + 
default: // modePico + return al.runAgentLoop(ctx, agent, processOptions{ + SessionKey: sessionKey, + Channel: msg.Channel, + ChatID: msg.ChatID, + UserMessage: msg.Content, + DefaultResponse: "I've completed processing but have no response to give.", + EnableSummary: true, + SendResponse: false, + WorkingDir: msg.Metadata["work_dir"], + }) + } } func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { @@ -420,6 +508,15 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt opts.ChatID, ) + // 2b. Inject current working directory into system prompt if set + if opts.WorkingDir != "" && len(messages) > 0 && messages[0].Role == "system" { + messages[0].Content += fmt.Sprintf( + "\n\n## Current Working Directory\nThe user is currently working in: %s\n"+ + "When the user refers to files or directories, resolve them relative to this path, not the workspace root.", + opts.WorkingDir, + ) + } + // 3. Save user message to session agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) @@ -594,6 +691,9 @@ func (al *AgentLoop) runLLMIteration( return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) } + // Accumulate token usage + agent.AddUsage(response.Usage) + // Check if no tool calls - we're done if len(response.ToolCalls) == 0 { finalContent = response.Content @@ -847,6 +947,25 @@ func (al *AgentLoop) GetStartupInfo() map[string]any { return info } +// GetUsageInfo returns accumulated token usage for the default agent. 
+func (al *AgentLoop) GetUsageInfo() map[string]any { + agent := al.registry.GetDefaultAgent() + if agent == nil { + return nil + } + promptTokens := agent.TotalPromptTokens.Load() + completionTokens := agent.TotalCompletionTokens.Load() + return map[string]any{ + "model": agent.Model, + "max_tokens": agent.MaxTokens, + "temperature": agent.Temperature, + "prompt_tokens": promptTokens, + "completion_tokens": completionTokens, + "total_tokens": promptTokens + completionTokens, + "requests": agent.TotalRequests.Load(), + } +} + // formatMessagesForLog formats messages for logging func formatMessagesForLog(messages []providers.Message) string { if len(messages) == 0 { @@ -1031,6 +1150,12 @@ func (al *AgentLoop) estimateTokens(messages []providers.Message) int { func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) { content := strings.TrimSpace(msg.Content) + + // Handle : prefixed extension commands (work across all channels) + if strings.HasPrefix(content, ":") { + return al.handleExtensionCommand(content) + } + if !strings.HasPrefix(content, "/") { return "", false } @@ -1119,6 +1244,589 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) return "", false } +// handleExtensionCommand handles : prefixed commands that work across all channels. 
+func (al *AgentLoop) handleExtensionCommand(content string) (string, bool) { + parts := strings.Fields(content) + if len(parts) == 0 { + return "", false + } + + cmd := parts[0] + + switch cmd { + case ":cmd", ":pico", ":hipico", ":edit": + // Pass through to processMessage for mode handling (needs sessionKey from routing) + return "", false + + case ":help": + return `:help - Show this help message +:usage - Show model info and token usage +:cmd - Switch to command mode (execute shell commands) +:pico - Switch to chat mode (default, AI conversation) +:hipico - Ask AI for help (from command mode, one-shot) +:edit - View/edit files (cmd mode)`, true + + case ":usage": + agent := al.registry.GetDefaultAgent() + if agent == nil { + return "No agent available.", true + } + promptTokens := agent.TotalPromptTokens.Load() + completionTokens := agent.TotalCompletionTokens.Load() + return fmt.Sprintf(`Model: %s +Max tokens: %d +Temperature: %.1f + +Token usage (this session): + Prompt tokens: %d + Completion tokens: %d + Total tokens: %d + Requests: %d`, + agent.Model, + agent.MaxTokens, + agent.Temperature, + promptTokens, + completionTokens, + promptTokens+completionTokens, + agent.TotalRequests.Load(), + ), true + + default: + // Don't intercept unrecognized : prefixed messages (e.g. :) :D :thinking:) + // Let them pass through as normal chat messages + return "", false + } +} + +// handleModeCommand processes mode-switching commands (:cmd, :pico, :hipico). +// Returns (response, handled). If handled is true, the caller should return the response directly. +// For :hipico with a message, it returns ("", false) so processMessage continues with a one-shot LLM call. 
+func (al *AgentLoop) handleModeCommand(content, sessionKey string, agent *AgentInstance) (string, bool) { + parts := strings.Fields(content) + if len(parts) == 0 { + return "", false + } + + cmd := parts[0] + + switch cmd { + case ":cmd": + al.setSessionMode(sessionKey, modeCmd) + workDir := al.getSessionWorkDir(sessionKey) + if workDir == "" { + workDir = agent.Workspace + al.setSessionWorkDir(sessionKey, workDir) + } + displayDir := shortenHomePath(workDir) + return fmt.Sprintf("```\n%s$\n```\nType `:pico` to return to chat mode.", displayDir), true + + case ":pico": + al.setSessionMode(sessionKey, modePico) + return "Switched to chat mode. Type :cmd to enter command mode.", true + + case ":hipico": + msg := strings.TrimSpace(strings.TrimPrefix(content, ":hipico")) + if msg == "" { + return "Usage: :hipico \nExample: :hipico check the log files for errors", true + } + // Stay in modeCmd, just flag for one-shot LLM call — processMessage handles it + return "", false + } + + return "", false +} + +// executeCmdMode executes a shell command in command mode via ExecTool. +// Output is formatted as a console code block for channel display. 
+func (al *AgentLoop) executeCmdMode( + ctx context.Context, + agent *AgentInstance, + content, sessionKey, channel, chatID string, +) (string, error) { + content = strings.TrimSpace(content) + if content == "" { + return "", nil + } + + // Handle cd command specially + if content == "cd" || strings.HasPrefix(content, "cd ") { + return al.handleCdCommand(content, sessionKey, agent), nil + } + + // Handle :edit command + if content == ":edit" || strings.HasPrefix(content, ":edit ") { + workDir := al.getSessionWorkDir(sessionKey) + if workDir == "" { + workDir = agent.Workspace + } + return al.handleEditCommand(content, workDir, agent.Workspace), nil + } + + // Intercept interactive editors + if msg := interceptEditor(content); msg != "" { + return msg, nil + } + + // Get working directory + workDir := al.getSessionWorkDir(sessionKey) + if workDir == "" { + workDir = agent.Workspace + } + + // Execute via ExecTool + result := agent.Tools.ExecuteWithContext(ctx, "exec", map[string]any{ + "command": content, + "working_dir": workDir, + }, channel, chatID, nil) + + displayDir := shortenHomePath(workDir) + output := result.ForLLM + if output == "" { + output = "(no output)" + } + + // Colorize ls output with emoji type indicators (only when user explicitly used ls -l) + if isLsCommand(content) && hasLongFlag(content) { + output = formatLsOutput(output) + } + + // Format as console code block: prompt line + output (show original command, not modified) + return fmt.Sprintf("```\n%s$ %s\n%s\n```", displayDir, content, output), nil +} + +// handleCdCommand handles the cd command in command mode, updating per-session working directory. +// Special paths (cd, cd ~, cd /, cd /xxx) are redirected to the workspace directory for safety. 
+func (al *AgentLoop) handleCdCommand(content, sessionKey string, agent *AgentInstance) string { + parts := strings.Fields(content) + workspace := agent.Workspace + var target string + + if len(parts) < 2 || parts[1] == "~" || parts[1] == "/" { + // cd, cd ~, cd / → always go to workspace + target = workspace + } else { + target = parts[1] + // Expand ~ prefix: treat ~ as workspace root (not $HOME) + if strings.HasPrefix(target, "~/") { + target = workspace + target[1:] + } + // Absolute paths (e.g. cd /etc) → redirect to workspace + if filepath.IsAbs(target) { + target = workspace + } + // Resolve relative paths + if !filepath.IsAbs(target) { + currentDir := al.getSessionWorkDir(sessionKey) + if currentDir == "" { + currentDir = workspace + } + target = filepath.Join(currentDir, target) + } + } + + target = filepath.Clean(target) + + // Prevent traversal outside workspace via ../ + if !strings.HasPrefix(target, workspace) { + target = workspace + } + + info, err := os.Stat(target) + if err != nil { + return fmt.Sprintf("cd: %s: No such file or directory", target) + } + if !info.IsDir() { + return fmt.Sprintf("cd: %s: Not a directory", target) + } + + al.setSessionWorkDir(sessionKey, target) + return fmt.Sprintf("```\n%s$\n```", shortenHomePath(target)) +} + +// shortenHomePath replaces the user's home directory prefix with ~ for display. +func shortenHomePath(path string) string { + home, err := os.UserHomeDir() + if err != nil || home == "" { + return path + } + if path == home { + return "~" + } + if strings.HasPrefix(path, home+"/") { + return "~" + path[len(home):] + } + return path +} + +// handleEditCommand processes :edit commands for file viewing and editing in cmd mode. 
+// Syntax: +// +// :edit → show usage +// :edit → show file with line numbers +// :edit → replace line N +// :edit + → insert after line N +// :edit - → delete line N +// :edit -m """""" → write full content (create if needed) +func (al *AgentLoop) handleEditCommand(content, workDir, workspace string) string { + raw := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(content), ":edit")) + if raw == "" { + return editUsage() + } + + // Split on first newline to get the command line + firstLine := raw + if idx := strings.Index(raw, "\n"); idx != -1 { + firstLine = raw[:idx] + } + + parts := strings.Fields(firstLine) + if len(parts) == 0 { + return editUsage() + } + + filename, err := resolveEditPath(parts[0], workDir, workspace) + if err != nil { + return fmt.Sprintf("Access denied: %s", err) + } + + // :edit — show file content + if len(parts) == 1 && !strings.Contains(raw, "\n") { + return editShowFile(filename) + } + + // :edit -m """...""" + if len(parts) >= 2 && parts[1] == "-m" { + return editMultiline(filename, raw) + } + + // Line operations: N text, +N text, -N + if len(parts) >= 2 { + // Get raw text after the line-op token (preserves original spacing) + afterFile := strings.TrimSpace(firstLine[len(parts[0]):]) + return editLineOp(filename, afterFile) + } + + return editUsage() +} + +func resolveEditPath(name, workDir, workspace string) (string, error) { + // Treat ~ as workspace root (not $HOME) + if name == "~" { + name = "." 
+ } else if strings.HasPrefix(name, "~/") { + name = name[2:] + } + // Resolve relative paths against workDir + if !filepath.IsAbs(name) { + name = filepath.Join(workDir, name) + } + // Validate against workspace (blocks absolute paths outside workspace, symlink escape, traversal) + return tools.ValidatePath(name, workspace, true) +} + +func editUsage() string { + return "Usage:\n" + + " :edit — view file\n" + + " :edit — replace line N\n" + + " :edit + — insert after line N\n" + + " :edit - — delete line N\n" + + " :edit -m \"\"\" — write content\n" + + " \n" + + " \"\"\"" +} + +func editShowFile(path string) string { + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return fmt.Sprintf( + "File not found: %s\nUse :edit %s -m \"\"\" to create it.", + shortenHomePath(path), + filepath.Base(path), + ) + } + return fmt.Sprintf("Error reading file: %v", err) + } + + lines := strings.Split(string(data), "\n") + // Remove trailing empty line that Split produces + if len(lines) > 0 && lines[len(lines)-1] == "" { + lines = lines[:len(lines)-1] + } + + const maxLines = 50 + var b strings.Builder + b.WriteString(fmt.Sprintf("``` %s (%d lines)\n", filepath.Base(path), len(lines))) + if len(lines) <= maxLines { + for i, line := range lines { + b.WriteString(fmt.Sprintf("%4d│ %s\n", i+1, line)) + } + } else { + for i := 0; i < maxLines; i++ { + b.WriteString(fmt.Sprintf("%4d│ %s\n", i+1, lines[i])) + } + b.WriteString(fmt.Sprintf(" ...│ (%d more lines)\n", len(lines)-maxLines)) + } + b.WriteString("```") + return b.String() +} + +func editMultiline(filename, raw string) string { + // raw = ` -m """..."""` + start := strings.Index(raw, `"""`) + if start == -1 { + return editUsage() + } + rest := raw[start+3:] + // Trim leading newline after opening """ + rest = strings.TrimPrefix(rest, "\n") + + // Find closing """ + end := strings.LastIndex(rest, `"""`) + if end == -1 || end == 0 { + // No closing triple-quote — use entire rest as content + end = 
len(rest) + } + content := rest[:end] + + // Ensure trailing newline + if content != "" && !strings.HasSuffix(content, "\n") { + content += "\n" + } + + // Create parent dirs if needed + dir := filepath.Dir(filename) + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Sprintf("Error creating directory: %v", err) + } + + if err := os.WriteFile(filename, []byte(content), 0o644); err != nil { + return fmt.Sprintf("Error writing file: %v", err) + } + + lineCount := strings.Count(content, "\n") + return fmt.Sprintf("```\n✓ Wrote %d lines → %s\n```", lineCount, shortenHomePath(filename)) +} + +func editLineOp(filename, rawArgs string) string { + rawArgs = strings.TrimSpace(rawArgs) + // Split into op token and text + spaceIdx := strings.IndexByte(rawArgs, ' ') + var op, text string + if spaceIdx == -1 { + op = rawArgs + } else { + op = rawArgs[:spaceIdx] + text = rawArgs[spaceIdx+1:] + } + + var lineNum int + var action string // "replace", "insert", "delete" + var err error + + if strings.HasPrefix(op, "+") { + action = "insert" + lineNum, err = strconv.Atoi(op[1:]) + } else if strings.HasPrefix(op, "-") { + action = "delete" + lineNum, err = strconv.Atoi(op[1:]) + } else { + action = "replace" + lineNum, err = strconv.Atoi(op) + } + if err != nil || lineNum < 1 { + return "Invalid line number. Use a positive integer." + } + + // Read existing file + data, err := os.ReadFile(filename) + if err != nil { + if os.IsNotExist(err) { + return fmt.Sprintf("File not found: %s", shortenHomePath(filename)) + } + return fmt.Sprintf("Error reading file: %v", err) + } + + lines := strings.Split(string(data), "\n") + if len(lines) > 0 && lines[len(lines)-1] == "" { + lines = lines[:len(lines)-1] + } + + switch action { + case "delete": + if lineNum > len(lines) { + return fmt.Sprintf("Line %d out of range (file has %d lines).", lineNum, len(lines)) + } + deleted := lines[lineNum-1] + lines = append(lines[:lineNum-1], lines[lineNum:]...) 
+ if err := os.WriteFile(filename, []byte(strings.Join(lines, "\n")+"\n"), 0o644); err != nil { + return fmt.Sprintf("Error writing file: %v", err) + } + return fmt.Sprintf("```\n✓ Deleted line %d: %s\n(%d lines remaining)\n```", lineNum, deleted, len(lines)) + + case "replace": + if text == "" { + return "Usage: :edit " + } + if lineNum > len(lines) { + return fmt.Sprintf("Line %d out of range (file has %d lines).", lineNum, len(lines)) + } + old := lines[lineNum-1] + lines[lineNum-1] = text + if err := os.WriteFile(filename, []byte(strings.Join(lines, "\n")+"\n"), 0o644); err != nil { + return fmt.Sprintf("Error writing file: %v", err) + } + return fmt.Sprintf("```\n✓ Line %d replaced\n was: %s\n now: %s\n```", lineNum, old, text) + + case "insert": + if text == "" { + return "Usage: :edit + " + } + if lineNum > len(lines) { + lineNum = len(lines) // insert at end + } + newLines := make([]string, 0, len(lines)+1) + newLines = append(newLines, lines[:lineNum]...) + newLines = append(newLines, text) + newLines = append(newLines, lines[lineNum:]...) + if err := os.WriteFile(filename, []byte(strings.Join(newLines, "\n")+"\n"), 0o644); err != nil { + return fmt.Sprintf("Error writing file: %v", err) + } + return fmt.Sprintf("```\n✓ Inserted after line %d: %s\n(%d lines total)\n```", lineNum, text, len(newLines)) + } + + return editUsage() +} + +// interceptEditor detects interactive editor commands and returns a helpful redirect message. +func interceptEditor(cmd string) string { + parts := strings.Fields(cmd) + if len(parts) == 0 { + return "" + } + name := parts[0] + switch name { + case "vim", "vi", "nvim", "nano", "emacs", "pico", "joe", "mcedit": + return fmt.Sprintf("⚠ %s requires a terminal and cannot run here.\nUse :edit instead:\n\n"+ + ":edit — view file\n"+ + ":edit -m \"\"\" — write content\n"+ + "\n"+ + "\"\"\"\n\n"+ + "Type :help for all commands.", name) + } + return "" +} + +// isLsCommand checks if a shell command is an ls invocation. 
+func isLsCommand(cmd string) bool { + cmd = strings.TrimSpace(cmd) + return cmd == "ls" || strings.HasPrefix(cmd, "ls ") +} + +// hasLongFlag checks if an ls command already includes the -l flag. +func hasLongFlag(cmd string) bool { + for _, p := range strings.Fields(cmd)[1:] { + if strings.HasPrefix(p, "-") && !strings.HasPrefix(p, "--") && strings.ContainsRune(p, 'l') { + return true + } + } + return false +} + +// formatLsOutput adds emoji type indicators to ls -l style output lines. +func formatLsOutput(output string) string { + lines := strings.Split(output, "\n") + for i, line := range lines { + lines[i] = formatLsLine(line) + } + return strings.Join(lines, "\n") +} + +// formatLsLine adds an emoji prefix to a single ls -l output line based on file type. +func formatLsLine(line string) string { + // Skip empty lines, "total" line, and lines too short to be ls -l + if line == "" || strings.HasPrefix(line, "total ") || len(line) < 10 { + return line + } + + // Check if line starts with a permission string (e.g. drwxr-xr-x) + perms := line[:10] + if !isPermString(perms) { + return line + } + + fileType := perms[0] + var emoji string + switch fileType { + case 'd': + emoji = "\U0001F4C1" // 📁 + case 'l': + emoji = "\U0001F517" // 🔗 + case 'b', 'c': + emoji = "\U0001F4BE" // 💾 + case 'p', 's': + emoji = "\U0001F50C" // 🔌 + default: + // Regular file: check executable bit (owner/group/other x positions) + if perms[3] == 'x' || perms[6] == 'x' || perms[9] == 'x' { + emoji = "\u26A1" // ⚡ + } else { + emoji = fileEmojiByExt(line) + } + } + + return emoji + " " + line +} + +// isPermString checks if a 10-char string looks like a Unix permission string. 
+func isPermString(s string) bool { + if len(s) != 10 { + return false + } + // First char: file type + switch s[0] { + case '-', 'd', 'l', 'b', 'c', 'p', 's': + default: + return false + } + // Remaining 9 chars: rwx or - (plus s/S/t/T for setuid/setgid/sticky) + for _, c := range s[1:] { + switch c { + case 'r', 'w', 'x', '-', 's', 'S', 't', 'T': + default: + return false + } + } + return true +} + +// fileEmojiByExt returns an emoji based on the file extension found in an ls -l line. +func fileEmojiByExt(line string) string { + // Extract filename: last whitespace-delimited field (for symlinks, take before " -> ") + name := line + if idx := strings.LastIndex(line, " -> "); idx != -1 { + name = line[:idx] + } + if idx := strings.LastIndex(name, " "); idx != -1 { + name = name[idx+1:] + } + name = strings.ToLower(name) + + ext := filepath.Ext(name) + switch ext { + case ".jpg", ".jpeg", ".png", ".gif", ".svg", ".webp", ".bmp", ".ico", ".tiff": + return "\U0001F5BC" // 🖼 + case ".mp3", ".wav", ".flac", ".aac", ".ogg", ".wma", ".m4a": + return "\U0001F3B5" // 🎵 + case ".mp4", ".avi", ".mkv", ".mov", ".webm", ".flv", ".wmv": + return "\U0001F3AC" // 🎬 + case ".zip", ".tar", ".gz", ".bz2", ".xz", ".7z", ".rar", ".zst", ".tgz": + return "\U0001F4E6" // 📦 + default: + return "\U0001F4C4" // 📄 + } +} + // extractPeer extracts the routing peer from inbound message metadata. func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { peerKind := msg.Metadata["peer_kind"] diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 4414398b1..6dc1ccc7f 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "testing" "time" @@ -631,3 +632,73 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory)) } } + +// TestHandleCdCommand_TraversalBlocked verifies that cd ../../.. cannot escape workspace. 
+func TestHandleCdCommand_TraversalBlocked(t *testing.T) { + workspace := t.TempDir() + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: workspace, + Model: "test-model", + MaxTokens: 4096, + }, + }, + } + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + agent := al.registry.GetDefaultAgent() + + // Set working dir to a subdir inside workspace + subdir := filepath.Join(workspace, "a", "b") + os.MkdirAll(subdir, 0o755) + al.setSessionWorkDir("test", subdir) + + // Try to escape via ../../../.. + result := al.handleCdCommand("cd ../../../..", "test", agent) + workDir := al.getSessionWorkDir("test") + + if !strings.HasPrefix(workDir, workspace) { + t.Errorf("cd traversal escaped workspace: workDir=%s, workspace=%s", workDir, workspace) + } + // Should land in workspace, not outside + if workDir != workspace { + t.Errorf("Expected workDir=%s, got %s", workspace, workDir) + } + _ = result +} + +// TestHandleExtensionCommand_EmojiPassthrough verifies that emoji-like +// messages starting with : are not intercepted as commands. 
+func TestHandleExtensionCommand_EmojiPassthrough(t *testing.T) {
+	cfg := &config.Config{
+		Agents: config.AgentsConfig{
+			Defaults: config.AgentDefaults{
+				Workspace: t.TempDir(),
+				Model:     "test-model",
+				MaxTokens: 4096,
+			},
+		},
+	}
+	msgBus := bus.NewMessageBus()
+	provider := &mockProvider{}
+	al := NewAgentLoop(cfg, msgBus, provider)
+
+	emojiInputs := []string{":)", ":D", ":heart:", ":thinking:", ":-)", ":100:"}
+	for _, input := range emojiInputs {
+		_, handled := al.handleExtensionCommand(input)
+		if handled {
+			t.Errorf("Expected %q to pass through (not handled), but it was handled", input)
+		}
+	}
+
+	// Known commands should still be handled
+	knownCommands := []string{":help", ":usage"}
+	for _, cmd := range knownCommands {
+		_, handled := al.handleExtensionCommand(cmd)
+		if !handled {
+			t.Errorf("Expected %q to be handled, but it was not", cmd)
+		}
+	}
+}
diff --git a/pkg/agent/ls_helpers_test.go b/pkg/agent/ls_helpers_test.go
new file mode 100644
index 000000000..652f1d6de
--- /dev/null
+++ b/pkg/agent/ls_helpers_test.go
@@ -0,0 +1,70 @@
+package agent
+
+import "testing"
+
+// TestIsLsCommand verifies that isLsCommand accepts ls with or without
+// arguments and rejects lsof, ls used as an argument, and the empty string.
+func TestIsLsCommand(t *testing.T) {
+	tests := []struct {
+		cmd  string
+		want bool
+	}{
+		{"ls", true},
+		{"ls -la", true},
+		{"ls /tmp", true},
+		{"lsof", false},
+		{"echo ls", false},
+		{"", false},
+	}
+	for _, tt := range tests {
+		if got := isLsCommand(tt.cmd); got != tt.want {
+			t.Errorf("isLsCommand(%q) = %v, want %v", tt.cmd, got, tt.want)
+		}
+	}
+}
+
+// TestHasLongFlag verifies that hasLongFlag reports true when a single-dash
+// flag group contains l, and false for no flags, -a alone, or --color alone.
+func TestHasLongFlag(t *testing.T) {
+	tests := []struct {
+		cmd  string
+		want bool
+	}{
+		{"ls", false},
+		{"ls -l", true},
+		{"ls -la", true},
+		{"ls -al", true},
+		{"ls --color", false},
+		{"ls -a /tmp", false},
+		{"ls -l --color /tmp", true},
+	}
+	for _, tt := range tests {
+		if got := hasLongFlag(tt.cmd); got != tt.want {
+			t.Errorf("hasLongFlag(%q) = %v, want %v", tt.cmd, got, tt.want)
+		}
+	}
+}
+
+// TestIsPermString verifies that isPermString accepts ls-style mode strings
+// (including setuid and sticky forms) and rejects other or empty input.
+func TestIsPermString(t *testing.T) {
+	tests := []struct {
+		s    string
+		want bool
+	}{
+		{"drwxr-xr-x", true},
+		{"-rw-r--r--", true},
+		{"lrwxrwxrwx", true},
+		{"-rwsr-xr-x", true},   // setuid
+		{"drwxrwxrwt", true},   // sticky
+		{"hello world", false}, // wrong length/chars
+		{"----------", true},
+		{"xrwxrwxrwx", false}, // invalid first char
+		{"", false},
+	}
+	for _, tt := range tests {
+		if got := isPermString(tt.s); got != tt.want {
+			t.Errorf("isPermString(%q) = %v, want %v", tt.s, got, tt.want)
+		}
+	}
+}
diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go
index 37db8b4ae..76d6f91e9 100644
--- a/pkg/tools/filesystem.go
+++ b/pkg/tools/filesystem.go
@@ -10,8 +10,8 @@ import (
 	"time"
 )
 
-// validatePath ensures the given path is within the workspace if restrict is true.
-func validatePath(path, workspace string, restrict bool) (string, error) {
+// ValidatePath ensures the given path is within the workspace if restrict is true.
+func ValidatePath(path, workspace string, restrict bool) (string, error) {
 	if workspace == "" {
 		return path, fmt.Errorf("workspace is not defined")
 	}
diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go
index ad1664b5b..e4e1bef9b 100644
--- a/pkg/tools/shell.go
+++ b/pkg/tools/shell.go
@@ -143,7 +143,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
 	cwd := t.workingDir
 	if wd, ok := args["working_dir"].(string); ok && wd != "" {
 		if t.restrictToWorkspace && t.workingDir != "" {
-			resolvedWD, err := validatePath(wd, t.workingDir, true)
+			resolvedWD, err := ValidatePath(wd, t.workingDir, true)
 			if err != nil {
 				return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
 			}
@@ -289,9 +289,19 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
 		}
 
 		pathPattern := regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`)
-		matches := pathPattern.FindAllString(cmd, -1)
+		matchIndices := pathPattern.FindAllStringIndex(cmd, -1)
+
+		for _, loc := range matchIndices {
+			raw := cmd[loc[0]:loc[1]]
+			// Skip relative paths like ./executable — the regex extracts
+			// "/executable" from "./executable" but it's not an absolute path.
+			// NOTE(review): this also skips matches that follow "..", e.g. the
+			// "/etc/passwd" in "../etc/passwd"; confirm traversal is rejected
+			// before this loop runs.
+			if loc[0] > 0 && cmd[loc[0]-1] == '.' {
+				continue
+			}
 
-		for _, raw := range matches {
 			p, err := filepath.Abs(raw)
 			if err != nil {
 				continue
diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go
index 6d35815e8..a8bd603cc 100644
--- a/pkg/tools/shell_test.go
+++ b/pkg/tools/shell_test.go
@@ -272,3 +272,29 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) {
 		)
 	}
 }
+
+// TestGuardCommand_DotSlashExecutable verifies that ./executable style
+// commands are NOT blocked by the path extraction regex in guardCommand.
+func TestGuardCommand_DotSlashExecutable(t *testing.T) {
+	tmpDir := t.TempDir()
+	tool := NewExecTool(tmpDir, true)
+
+	// Create a test script in the workspace; stop here if it cannot be
+	// written, otherwise the failure would look like a guard rejection.
+	scriptPath := filepath.Join(tmpDir, "test.sh")
+	if err := os.WriteFile(scriptPath, []byte("#!/bin/sh\necho ok"), 0o755); err != nil {
+		t.Fatalf("writing %s: %v", scriptPath, err)
+	}
+
+	ctx := context.Background()
+	result := tool.Execute(ctx, map[string]any{
+		"command":     "./test.sh",
+		"working_dir": tmpDir,
+	})
+
+	if result.IsError {
+		t.Errorf("Expected ./test.sh to be allowed, got error: %s", result.ForLLM)
+	}
+	if !strings.Contains(result.ForLLM, "ok") {
+		t.Errorf("Expected output 'ok', got: %s", result.ForLLM)
+	}
+}