diff --git a/.golangci.yml b/.golangci.yml index f52fd17..7875193 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -381,7 +381,7 @@ linters: - contextcheck # contextcheck: flow.go uses InvocationContext which wraps context - - path: internal/agent/chat/flow\.go + - path: internal/chat/flow\.go linters: - contextcheck diff --git a/cmd/addr.go b/cmd/addr.go new file mode 100644 index 0000000..5893230 --- /dev/null +++ b/cmd/addr.go @@ -0,0 +1,74 @@ +package cmd + +import ( + "flag" + "fmt" + "net" + "os" + "strconv" + "strings" +) + +// parseServeAddr parses and validates the server address from command line arguments. +// Uses flag.FlagSet for standard Go flag parsing, supporting: +// - koopa serve :8080 (positional) +// - koopa serve --addr :8080 (flag) +// - koopa serve -addr :8080 (single dash) +func parseServeAddr() (string, error) { + const defaultAddr = "127.0.0.1:3400" + + serveFlags := flag.NewFlagSet("serve", flag.ContinueOnError) + serveFlags.SetOutput(os.Stderr) + + addr := serveFlags.String("addr", defaultAddr, "Server address (host:port)") + + args := []string{} + if len(os.Args) > 2 { + args = os.Args[2:] + } + + // Check for positional argument first (koopa serve :8080) + if len(args) > 0 && !strings.HasPrefix(args[0], "-") { + *addr = args[0] + args = args[1:] + } + + if err := serveFlags.Parse(args); err != nil { + return "", fmt.Errorf("parsing serve flags: %w", err) + } + + if err := validateAddr(*addr); err != nil { + return "", fmt.Errorf("invalid address %q: %w", *addr, err) + } + + return *addr, nil +} + +// validateAddr validates the server address format. +func validateAddr(addr string) error { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return fmt.Errorf("must be in host:port format: %w", err) + } + + if host != "" && host != "localhost" { + if ip := net.ParseIP(host); ip == nil { + if strings.ContainsAny(host, " \t\n") { + return fmt.Errorf("invalid host: %s", host) + } + } + } + + if port == "" { + return fmt.Errorf("port is required") + } + portNum, err := strconv.Atoi(port) + if err != nil { + return fmt.Errorf("port must be numeric: %w", err) + } + if portNum < 0 || portNum > 65535 { + return fmt.Errorf("port must be 0-65535 (0 = auto-assign), got %d", portNum) + } + + return nil +} diff --git a/cmd/addr_test.go b/cmd/addr_test.go new file mode 100644 index 0000000..fe12772 --- /dev/null +++ b/cmd/addr_test.go @@ -0,0 +1,68 @@ +package cmd + +import "testing" + +func TestValidateAddr(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + addr string + wantErr bool + }{ + // Valid addresses + {name: "port only", addr: ":8080", wantErr: false}, + {name: "localhost", addr: "localhost:3400", wantErr: false}, + {name: "loopback", addr: "127.0.0.1:3400", wantErr: false}, + {name: "all interfaces", addr: "0.0.0.0:80", wantErr: false}, + {name: "ipv6 loopback", addr: "[::1]:8080", wantErr: false}, + {name: "port zero", addr: ":0", wantErr: false}, + {name: "port max", addr: ":65535", wantErr: false}, + {name: "hostname", addr: "myhost:9090", wantErr: false}, + + // Invalid: bad format + {name: "no port", addr: "localhost", wantErr: true}, + {name: "port alone", addr: "8080", wantErr: true}, + {name: "empty string", addr: "", wantErr: true}, + + // Invalid: bad port + {name: "port non-numeric", addr: ":abc", wantErr: true}, + {name: "port negative", addr: ":-1", wantErr: true}, + {name: "port too high", addr: ":65536", wantErr: true}, + {name: "port empty after colon", addr: "localhost:", wantErr: true}, + + // Invalid: bad host + {name: "host with space", addr: "my host:8080", wantErr: true}, + {name: "host with tab", addr: "my\thost:8080", wantErr: true}, + {name: "host with newline", addr: "my\nhost:8080", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateAddr(tt.addr) + if tt.wantErr && err == nil { + t.Errorf("validateAddr(%q) = nil, want error", tt.addr) + } + if !tt.wantErr && err != nil { + t.Errorf("validateAddr(%q) = %v, want nil", tt.addr, err) + } + }) + } +} + +func FuzzValidateAddr(f *testing.F) { + f.Add(":8080") + f.Add("localhost:3400") + f.Add("127.0.0.1:80") + f.Add("") + f.Add("abc") + f.Add(":0") + f.Add(":99999") + f.Add("[::1]:8080") + f.Add("host with space:80") + + f.Fuzz(func(t *testing.T, addr string) { + _ = validateAddr(addr) // must not panic + }) +} diff --git a/cmd/cli.go b/cmd/cli.go index fcfede2..6356ebc 100644 --- a/cmd/cli.go +++ b/cmd/cli.go @@ -2,7 +2,6 @@ package cmd import ( "context" - "errors" "fmt" "log/slog" "os/signal" @@ -11,8 +10,8 @@ import ( tea "charm.land/bubbletea/v2" "github.com/koopa0/koopa/internal/app" + "github.com/koopa0/koopa/internal/chat" "github.com/koopa0/koopa/internal/config" - "github.com/koopa0/koopa/internal/session" "github.com/koopa0/koopa/internal/tui" ) @@ -20,63 +19,42 @@ import ( func runCLI() error { cfg, err := config.Load() if err != nil { - return err + return fmt.Errorf("loading config: %w", err) } ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() - runtime, err := app.NewChatRuntime(ctx, cfg) + a, err := app.Setup(ctx, cfg) if err != nil { - return fmt.Errorf("failed to initialize runtime: %w", err) + return fmt.Errorf("initializing application: %w", err) } defer func() { - if closeErr := runtime.Close(); closeErr != nil { - slog.Warn("runtime close error", "error", closeErr) + if closeErr := a.Close(); closeErr != nil { + slog.Warn("shutdown error", "error", closeErr) } }() - sessionID, err := getOrCreateSessionID(ctx, runtime.App.SessionStore, cfg) + agent, err := a.CreateAgent() if err != nil { - return fmt.Errorf("failed to get session: %w", err) + return fmt.Errorf("creating agent: %w", err) } - model, err := tui.New(ctx, runtime.Flow, sessionID) - if err != nil { - return fmt.Errorf("failed to create TUI: %w", err) - } - program := tea.NewProgram(model, tea.WithContext(ctx)) - - if _, err = program.Run(); err != nil { - return fmt.Errorf("TUI exited: %w", err) - } - return nil -} + flow := chat.NewFlow(a.Genkit, agent) -// getOrCreateSessionID returns a valid session ID, creating a new session if needed. -func getOrCreateSessionID(ctx context.Context, store *session.Store, cfg *config.Config) (string, error) { - currentID, err := session.LoadCurrentSessionID() + sessionID, err := a.SessionStore.ResolveCurrentSession(ctx) if err != nil { - return "", fmt.Errorf("failed to load session: %w", err) + return fmt.Errorf("resolving session: %w", err) } - if currentID != nil { - if _, err = store.Session(ctx, *currentID); err == nil { - return currentID.String(), nil - } - if !errors.Is(err, session.ErrSessionNotFound) { - return "", fmt.Errorf("failed to validate session: %w", err) - } - } - - newSess, err := store.CreateSession(ctx, "New Session", cfg.ModelName, "You are a helpful assistant.") + model, err := tui.New(ctx, flow, sessionID) if err != nil { - return "", fmt.Errorf("failed to create session: %w", err) + return fmt.Errorf("creating TUI: %w", err) } + program := tea.NewProgram(model, tea.WithContext(ctx)) - if err := session.SaveCurrentSessionID(newSess.ID); err != nil { - slog.Warn("failed to save session state", "error", err) + if _, err = program.Run(); err != nil { + return fmt.Errorf("running TUI: %w", err) } - - return newSess.ID.String(), nil + return nil } diff --git a/cmd/cmd.go b/cmd/cmd.go new file mode 100644 index 0000000..3dfb5f5 --- /dev/null +++ b/cmd/cmd.go @@ -0,0 +1,76 @@ +// Package cmd provides CLI commands for Koopa. +// +// Commands: +// - cli: Interactive terminal chat with Bubble Tea TUI +// - serve: HTTP API server with SSE streaming +// - mcp: Model Context Protocol server for IDE integration +// +// Signal handling and graceful shutdown are implemented +// for all commands via context cancellation. +package cmd + +import ( + "fmt" + "log/slog" + "os" +) + +// Execute is the main entry point for the Koopa CLI application. +func Execute() error { + // Initialize logger once at entry point + level := slog.LevelInfo + if os.Getenv("DEBUG") != "" { + level = slog.LevelDebug + } + slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level}))) + + if len(os.Args) < 2 { + runHelp() + return nil + } + + switch os.Args[1] { + case "cli": + return runCLI() + case "serve": + return runServe() + case "mcp": + return runMCP() + case "version", "--version", "-v": + runVersion() + return nil + case "help", "--help", "-h": + runHelp() + return nil + default: + return fmt.Errorf("unknown command: %s", os.Args[1]) + } +} + +// runHelp displays the help message. +func runHelp() { + fmt.Println("Koopa - Your terminal AI personal assistant") + fmt.Println() + fmt.Println("Usage:") + fmt.Println(" koopa cli Start interactive chat mode") + fmt.Println(" koopa serve [addr] Start HTTP API server (default: 127.0.0.1:3400)") + fmt.Println(" koopa mcp Start MCP server (for Claude Desktop/Cursor)") + fmt.Println(" koopa --version Show version information") + fmt.Println(" koopa --help Show this help") + fmt.Println() + fmt.Println("CLI Commands (in interactive mode):") + fmt.Println(" /help Show available commands") + fmt.Println(" /version Show version") + fmt.Println(" /clear Clear conversation history") + fmt.Println(" /exit, /quit Exit Koopa") + fmt.Println() + fmt.Println("Shortcuts:") + fmt.Println(" Ctrl+D Exit Koopa") + fmt.Println(" Ctrl+C Cancel current input") + fmt.Println() + fmt.Println("Environment Variables:") + fmt.Println(" GEMINI_API_KEY Required: Gemini API key") + fmt.Println(" DEBUG Optional: Enable debug logging") + fmt.Println() + fmt.Println("Learn more: https://github.com/koopa0/koopa") +} diff --git a/cmd/e2e_test.go b/cmd/e2e_test.go index 2e3dd93..b3bf767 100644 --- a/cmd/e2e_test.go +++ b/cmd/e2e_test.go @@ -9,7 +9,7 @@ import ( "os" "os/exec" "path/filepath" - goruntime "runtime" // Alias to avoid confusion with app.ChatRuntime + goruntime "runtime" // Alias to avoid conflict with runtime package name "strings" "testing" "time" @@ -101,7 +101,7 @@ func findOrBuildKoopa(t *testing.T) string { cmd := exec.Command("go", "build", "-o", binName, ".") cmd.Dir = projectRoot if output, err := cmd.CombinedOutput(); err != nil { - t.Fatalf("Failed to build koopa: %v\nOutput: %s", err, output) + t.Fatalf("go build error: %v\nOutput: %s", err, output) } return koopaBin @@ -141,14 +141,14 @@ func TestE2E_VersionCommand(t *testing.T) { output, err := ctx.runKoopaCommand(shortTimeout, "version") if err != nil { - t.Fatalf("version command unexpected error: %v", err) + t.Fatalf("running version command: %v", err) } if !strings.Contains(output, "Koopa") { t.Errorf("version command output = %q, want to contain %q", output, "Koopa") } - if !strings.Contains(output, "v0.") { - t.Errorf("version command output = %q, want to contain %q", output, "v0.") + if !strings.Contains(output, "v") { + t.Errorf("version command output = %q, want to contain %q", output, "v") } } @@ -159,7 +159,7 @@ func TestE2E_ErrorRecovery(t *testing.T) { t.Run("help command works", func(t *testing.T) { output, err := ctx.runKoopaCommand(shortTimeout, "help") if err != nil { - t.Errorf("help command unexpected error: %v", err) + t.Errorf("running help command: %v", err) } if !strings.Contains(strings.ToLower(output), "koopa") { t.Errorf("help command output = %q, want to contain %q", output, "koopa") @@ -178,38 +178,10 @@ func TestE2E_ErrorRecovery(t *testing.T) { // Version command should still work without API key if err != nil { - t.Errorf("version command without API key unexpected error: %v", err) + t.Errorf("running version command without API key: %v", err) } if !strings.Contains(output, "Koopa") { t.Errorf("version command output = %q, want to contain %q", output, "Koopa") } }) } - -// TestE2E_IntegrationTestHelper verifies E2E test infrastructure -func TestE2E_IntegrationTestHelper(t *testing.T) { - ctx := setupE2ETest(t) - - // Verify binary exists - if _, err := os.Stat(ctx.koopaBin); err != nil { - t.Errorf("koopa binary should exist at %q, but got error: %v", ctx.koopaBin, err) - } - - // Verify working directory - if info, err := os.Stat(ctx.workDir); err != nil || !info.IsDir() { - t.Errorf("working directory should exist at %q, but got error: %v", ctx.workDir, err) - } - - // Verify environment - if ctx.databaseURL == "" { - t.Error("DATABASE_URL should be set") - } - if ctx.apiKey == "" { - t.Error("GEMINI_API_KEY should be set") - } - - t.Logf("E2E test infrastructure:") - t.Logf(" Binary: %s", ctx.koopaBin) - t.Logf(" WorkDir: %s", ctx.workDir) - t.Logf(" Database: %s", ctx.databaseURL) -} diff --git a/cmd/execute.go b/cmd/execute.go deleted file mode 100644 index 9bf77eb..0000000 --- a/cmd/execute.go +++ /dev/null @@ -1,58 +0,0 @@ -// Package cmd provides CLI commands for Koopa. -// -// Commands: -// - cli: Interactive terminal chat with Bubble Tea TUI -// - serve: HTTP API server with SSE streaming -// - mcp: Model Context Protocol server for IDE integration -// -// Signal handling and graceful shutdown are implemented -// for all commands via context cancellation. -package cmd - -import ( - "fmt" - "log/slog" - "os" - - "github.com/koopa0/koopa/internal/log" -) - -// Version is set at build time via ldflags: -// -// go build -ldflags "-X github.com/koopa0/koopa/cmd.Version=1.0.0" -// -// Default value "dev" indicates a development build. -var Version = "dev" - -// Execute is the main entry point for the Koopa CLI application. -func Execute() error { - // Initialize logger once at entry point - level := slog.LevelInfo - if os.Getenv("DEBUG") != "" { - level = slog.LevelDebug - } - slog.SetDefault(log.New(log.Config{Level: level})) - - if len(os.Args) < 2 { - runHelp() - return nil - } - - switch os.Args[1] { - case "cli": - return runCLI() - case "serve": - return runServe() - case "mcp": - return runMCP() - case "version", "--version", "-v": - runVersion() - return nil - case "help", "--help", "-h": - runHelp() - return nil - default: - fmt.Fprintf(os.Stderr, "Error: unknown command %q\n", os.Args[1]) - return fmt.Errorf("unknown command: %s", os.Args[1]) - } -} diff --git a/cmd/integration_test.go b/cmd/integration_test.go index e2e8618..0b0deef 100644 --- a/cmd/integration_test.go +++ b/cmd/integration_test.go @@ -5,133 +5,81 @@ package cmd import ( "context" - "fmt" "os" "path/filepath" - "runtime" "testing" "time" - "github.com/koopa0/koopa/internal/agent/chat" + "github.com/google/uuid" "github.com/koopa0/koopa/internal/app" + "github.com/koopa0/koopa/internal/chat" "github.com/koopa0/koopa/internal/config" + "github.com/koopa0/koopa/internal/testutil" "github.com/koopa0/koopa/internal/tui" ) -// findProjectRoot finds the project root directory by looking for go.mod. -func findProjectRoot() (string, error) { - _, filename, _, ok := runtime.Caller(0) - if !ok { - return "", fmt.Errorf("runtime.Caller failed to get caller info") - } - - dir := filepath.Dir(filename) - for { - goModPath := filepath.Join(dir, "go.mod") - if _, err := os.Stat(goModPath); err == nil { - return dir, nil - } - parent := filepath.Dir(dir) - if parent == dir { - return "", fmt.Errorf("go.mod not found in any parent directory of %s", filename) - } - dir = parent - } -} - -// TestTUI_Integration tests the TUI can be created with real runtime. -// Note: Bubble Tea TUI cannot be fully tested without a real TTY. -// These tests verify initialization and component wiring. -func TestTUI_Integration(t *testing.T) { - if os.Getenv("GEMINI_API_KEY") == "" { - t.Skip("GEMINI_API_KEY not set - skipping integration test") - } +// setupApp is a test helper that creates an App instance. +func setupApp(t *testing.T) *app.App { + t.Helper() - // Reset Flow singleton for test isolation chat.ResetFlowForTesting() cfg, err := config.Load() if err != nil { - t.Fatalf("Failed to load config: %v", err) + t.Fatalf("config.Load() error: %v", err) } - // Set absolute path for prompts directory (required for tests running from different directories) - projectRoot, err := findProjectRoot() + projectRoot, err := testutil.FindProjectRoot() if err != nil { - t.Fatalf("Failed to find project root: %v", err) + t.Fatalf("testutil.FindProjectRoot() error: %v", err) } cfg.PromptDir = filepath.Join(projectRoot, "prompts") ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - // Initialize runtime - runtime, err := app.NewChatRuntime(ctx, cfg) + a, err := app.Setup(ctx, cfg) if err != nil { - t.Fatalf("Failed to initialize runtime: %v", err) + cancel() + t.Fatalf("app.Setup() error: %v", err) } + t.Cleanup(func() { - if err := runtime.Close(); err != nil { - t.Logf("runtime close error: %v", err) + if err := a.Close(); err != nil { + t.Logf("app close error: %v", err) } + cancel() }) - // Verify TUI can be created with real Flow - tuiModel, err := tui.New(ctx, runtime.Flow, "test-session-id") - if err != nil { - t.Fatalf("Failed to create TUI: %v", err) - } - - // Verify Init returns a command - cmd := tuiModel.Init() - if cmd == nil { - t.Error("Init should return a command (blink + spinner)") - } + return a } -// TestTUI_SlashCommands tests slash command handling. -func TestTUI_SlashCommands(t *testing.T) { +// TestTUI_Integration tests the TUI can be created with real dependencies. +// Note: Bubble Tea TUI cannot be fully tested without a real TTY. +// These tests verify initialization and component wiring. +func TestTUI_Integration(t *testing.T) { if os.Getenv("GEMINI_API_KEY") == "" { t.Skip("GEMINI_API_KEY not set - skipping integration test") } - // Reset Flow singleton for test isolation - chat.ResetFlowForTesting() + a := setupApp(t) - cfg, err := config.Load() + agent, err := a.CreateAgent() if err != nil { - t.Fatalf("Failed to load config: %v", err) + t.Fatalf("CreateAgent() error: %v", err) } - // Set absolute path for prompts directory (required for tests running from different directories) - projectRoot, err := findProjectRoot() - if err != nil { - t.Fatalf("Failed to find project root: %v", err) - } - cfg.PromptDir = filepath.Join(projectRoot, "prompts") + flow := chat.NewFlow(a.Genkit, agent) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - runtime, err := app.NewChatRuntime(ctx, cfg) + tuiModel, err := tui.New(ctx, flow, uuid.New()) if err != nil { - t.Fatalf("Failed to initialize runtime: %v", err) + t.Fatalf("tui.New() error: %v", err) } - t.Cleanup(func() { - if err := runtime.Close(); err != nil { - t.Logf("runtime close error: %v", err) - } - }) - tuiModel, err := tui.New(ctx, runtime.Flow, "test-session-id") - if err != nil { - t.Fatalf("Failed to create TUI: %v", err) - } - - // Test /help command by simulating the message flow - // Note: Full TUI testing requires teatest or similar framework - view := tuiModel.View() - if view.Content == nil { - t.Error("View content should not be nil") + cmd := tuiModel.Init() + if cmd == nil { + t.Error("Init should return a command (blink + spinner)") } } diff --git a/cmd/mcp.go b/cmd/mcp.go index c0948c6..5b8a0c6 100644 --- a/cmd/mcp.go +++ b/cmd/mcp.go @@ -4,125 +4,52 @@ import ( "context" "fmt" "log/slog" - "os" "os/signal" "syscall" - "time" "github.com/koopa0/koopa/internal/app" "github.com/koopa0/koopa/internal/config" "github.com/koopa0/koopa/internal/mcp" - "github.com/koopa0/koopa/internal/security" - "github.com/koopa0/koopa/internal/tools" mcpSdk "github.com/modelcontextprotocol/go-sdk/mcp" ) -// runMCP initializes and starts the MCP server. -// This is called when the user runs `koopa mcp`. +// runMCP initializes and starts the MCP server on stdio transport. func runMCP() error { cfg, err := config.Load() if err != nil { - fmt.Fprintln(os.Stderr, err) - return err + return fmt.Errorf("loading config: %w", err) } ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() - return RunMCP(ctx, cfg, Version) -} - -// RunMCP starts the MCP server on stdio transport -// -// Architecture: -// - Creates all toolsets with necessary dependencies -// - Creates MCP Server wrapping the toolsets -// - Connects to stdio transport for Claude Desktop/Cursor -// - Signal handling is done by caller (executeMCP) -// -// Error handling: -// - Returns error if initialization fails (App, Toolsets, Server) -// - Returns error if server connection fails -// - Graceful shutdown on context cancellation -func RunMCP(ctx context.Context, cfg *config.Config, version string) error { - slog.Info("starting MCP server", "version", version) + slog.Info("starting MCP server", "version", Version) - // Initialize application - application, cleanup, err := app.InitializeApp(ctx, cfg) + a, err := app.Setup(ctx, cfg) if err != nil { - return fmt.Errorf("failed to initialize application: %w", err) + return fmt.Errorf("initializing application: %w", err) } - // Cleanup order: App.Close (goroutines) first, then cleanup (DB pool, OTel) - defer cleanup() defer func() { - if closeErr := application.Close(); closeErr != nil { - slog.Warn("app close error", "error", closeErr) + if closeErr := a.Close(); closeErr != nil { + slog.Warn("shutdown error", "error", closeErr) } }() - // Create all required tools with logger - logger := slog.Default() - - // 1. FileTools - fileTools, err := tools.NewFileTools(application.PathValidator, logger) - if err != nil { - return fmt.Errorf("failed to create file tools: %w", err) - } - - // 2. SystemTools - cmdValidator := security.NewCommand() - envValidator := security.NewEnv() - systemTools, err := tools.NewSystemTools(cmdValidator, envValidator, logger) - if err != nil { - return fmt.Errorf("failed to create system tools: %w", err) - } - - // 3. NetworkTools - networkTools, err := tools.NewNetworkTools(tools.NetworkConfig{ - SearchBaseURL: cfg.SearXNG.BaseURL, - FetchParallelism: cfg.WebScraper.Parallelism, - FetchDelay: time.Duration(cfg.WebScraper.DelayMs) * time.Millisecond, - FetchTimeout: time.Duration(cfg.WebScraper.TimeoutMs) * time.Millisecond, - }, logger) - if err != nil { - return fmt.Errorf("failed to create network tools: %w", err) - } - - // 4. KnowledgeTools (optional - requires retriever from App) - var knowledgeTools *tools.KnowledgeTools - toolCategories := []string{"file", "system", "network"} - if application.Retriever != nil { - kt, ktErr := tools.NewKnowledgeTools(application.Retriever, application.DocStore, logger) - if ktErr != nil { - slog.Warn("knowledge tools unavailable", "error", ktErr) - } else { - knowledgeTools = kt - toolCategories = append(toolCategories, "knowledge") - } - } - - // Create MCP Server with all tools mcpServer, err := mcp.NewServer(mcp.Config{ - Name: "koopa", - Version: version, - FileTools: fileTools, - SystemTools: systemTools, - NetworkTools: networkTools, - KnowledgeTools: knowledgeTools, + Name: "koopa", + Version: Version, + Logger: slog.Default(), + File: a.File, + System: a.System, + Network: a.Network, + Knowledge: a.Knowledge, }) if err != nil { - return fmt.Errorf("failed to create MCP server: %w", err) + return fmt.Errorf("creating MCP server: %w", err) } - slog.Info("MCP server initialized", - "name", "koopa", - "version", version, - "tools", toolCategories) - slog.Info("starting MCP server on stdio transport") + slog.Info("MCP server ready", "name", "koopa", "version", Version, "transport", "stdio") - // Run server on stdio transport - // This is a blocking call that handles all MCP protocol communication - // The server will run until ctx is canceled or an error occurs if err := mcpServer.Run(ctx, &mcpSdk.StdioTransport{}); err != nil { return fmt.Errorf("MCP server error: %w", err) } diff --git a/cmd/serve.go b/cmd/serve.go index 55f0673..2e19997 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -3,87 +3,87 @@ package cmd import ( "context" "errors" - "flag" "fmt" "log/slog" - "net" "net/http" - "os" "os/signal" - "strconv" - "strings" "syscall" "time" "github.com/koopa0/koopa/internal/api" "github.com/koopa0/koopa/internal/app" + "github.com/koopa0/koopa/internal/chat" "github.com/koopa0/koopa/internal/config" ) // Server timeout configuration. const ( - ReadHeaderTimeout = 10 * time.Second - ReadTimeout = 30 * time.Second - WriteTimeout = 2 * time.Minute // SSE streaming needs longer timeout - IdleTimeout = 120 * time.Second - ShutdownTimeout = 30 * time.Second + readHeaderTimeout = 10 * time.Second + readTimeout = 30 * time.Second + writeTimeout = 2 * time.Minute // SSE streaming needs longer timeout + idleTimeout = 2 * time.Minute + shutdownTimeout = 30 * time.Second ) -// RunServe starts the HTTP API server (JSON REST + Health checks). -// -// Architecture: -// - Validates required configuration (HMAC_SECRET) -// - Initializes the application runtime -// - Creates the API server with all routes -// - Signal handling is done by caller (executeServe) -func RunServe(ctx context.Context, cfg *config.Config, version, addr string) error { - logger := slog.Default() - - // Validate HMAC_SECRET for serve mode - if cfg.HMACSecret == "" { - return errors.New("HMAC_SECRET environment variable is required for serve mode (min 32 characters)") +// runServe initializes and starts the HTTP API server. +func runServe() error { + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("loading config: %w", err) } - if len(cfg.HMACSecret) < 32 { - return fmt.Errorf("HMAC_SECRET must be at least 32 characters, got %d", len(cfg.HMACSecret)) + if err = cfg.ValidateServe(); err != nil { + return fmt.Errorf("validating config: %w", err) } - logger.Info("starting HTTP API server", "version", version) + addr, err := parseServeAddr() + if err != nil { + return fmt.Errorf("parsing address: %w", err) + } + + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() - // Initialize runtime with all components - runtime, err := app.NewChatRuntime(ctx, cfg) + logger := slog.Default() + logger.Info("starting HTTP API server", "version", Version) + + a, err := app.Setup(ctx, cfg) if err != nil { - return fmt.Errorf("failed to initialize runtime: %w", err) + return fmt.Errorf("initializing application: %w", err) } defer func() { - if closeErr := runtime.Close(); closeErr != nil { - logger.Warn("runtime close error", "error", closeErr) + if closeErr := a.Close(); closeErr != nil { + logger.Warn("shutdown error", "error", closeErr) } }() - // Create API server (JSON REST + Health checks) + agent, err := a.CreateAgent() + if err != nil { + return fmt.Errorf("creating agent: %w", err) + } + + flow := chat.NewFlow(a.Genkit, agent) + apiServer, err := api.NewServer(api.ServerConfig{ Logger: logger, - Genkit: runtime.App.Genkit, - ModelName: cfg.FullModelName(), - ChatFlow: runtime.Flow, - SessionStore: runtime.App.SessionStore, + ChatAgent: agent, + ChatFlow: flow, + SessionStore: a.SessionStore, CSRFSecret: []byte(cfg.HMACSecret), CORSOrigins: cfg.CORSOrigins, IsDev: cfg.PostgresSSLMode == "disable", TrustProxy: cfg.TrustProxy, }) if err != nil { - return fmt.Errorf("failed to create API server: %w", err) + return fmt.Errorf("creating API server: %w", err) } - // Create HTTP server srv := &http.Server{ Addr: addr, Handler: apiServer.Handler(), - ReadHeaderTimeout: ReadHeaderTimeout, - ReadTimeout: ReadTimeout, - WriteTimeout: WriteTimeout, - IdleTimeout: IdleTimeout, + ReadHeaderTimeout: readHeaderTimeout, + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + IdleTimeout: idleTimeout, } logger.Info("HTTP server ready", @@ -100,10 +100,10 @@ func RunServe(ctx context.Context, cfg *config.Config, version, addr string) err select { case <-ctx.Done(): logger.Info("shutting down HTTP server") - shutdownCtx, cancel := context.WithTimeout(context.Background(), ShutdownTimeout) - defer cancel() + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer shutdownCancel() if err := srv.Shutdown(shutdownCtx); err != nil { - return fmt.Errorf("server shutdown failed: %w", err) + return fmt.Errorf("shutting down server: %w", err) } <-errCh return nil @@ -111,90 +111,6 @@ func RunServe(ctx context.Context, cfg *config.Config, version, addr string) err if errors.Is(err, http.ErrServerClosed) { return nil } - return err + return fmt.Errorf("HTTP server: %w", err) } } - -// runServe initializes and starts the HTTP API server. -// This is called when the user runs `koopa serve`. -func runServe() error { - cfg, err := config.Load() - if err != nil { - fmt.Fprintln(os.Stderr, err) - return err - } - - addr, err := parseServeAddr() - if err != nil { - return err - } - - ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - defer cancel() - - return RunServe(ctx, cfg, Version, addr) -} - -// parseServeAddr parses and validates the server address from command line arguments. -// Uses flag.FlagSet for standard Go flag parsing, supporting: -// - koopa serve :8080 (positional) -// - koopa serve --addr :8080 (flag) -// - koopa serve -addr :8080 (single dash) -func parseServeAddr() (string, error) { - const defaultAddr = "127.0.0.1:3400" - - serveFlags := flag.NewFlagSet("serve", flag.ContinueOnError) - serveFlags.SetOutput(os.Stderr) - - addr := serveFlags.String("addr", defaultAddr, "Server address (host:port)") - - args := []string{} - if len(os.Args) > 2 { - args = os.Args[2:] - } - - // Check for positional argument first (koopa serve :8080) - if len(args) > 0 && !strings.HasPrefix(args[0], "-") { - *addr = args[0] - args = args[1:] - } - - if err := serveFlags.Parse(args); err != nil { - return "", fmt.Errorf("failed to parse serve flags: %w", err) - } - - if err := validateAddr(*addr); err != nil { - return "", fmt.Errorf("invalid address %q: %w", *addr, err) - } - - return *addr, nil -} - -// validateAddr validates the server address format. -func validateAddr(addr string) error { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return fmt.Errorf("must be in host:port format: %w", err) - } - - if host != "" && host != "localhost" { - if ip := net.ParseIP(host); ip == nil { - if strings.ContainsAny(host, " \t\n") { - return fmt.Errorf("invalid host: %s", host) - } - } - } - - if port == "" { - return fmt.Errorf("port is required") - } - portNum, err := strconv.Atoi(port) - if err != nil { - return fmt.Errorf("port must be numeric: %w", err) - } - if portNum < 0 || portNum > 65535 { - return fmt.Errorf("port must be 0-65535 (0 = auto-assign), got %d", portNum) - } - - return nil -} diff --git a/cmd/version.go b/cmd/version.go index b0511cc..00c630b 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -2,35 +2,14 @@ package cmd import "fmt" +// Version is set at build time via ldflags: +// +// go build -ldflags "-X github.com/koopa0/koopa/cmd.Version=1.0.0" +// +// Default value "dev" indicates a development build. +var Version = "dev" + // runVersion displays version information. func runVersion() { fmt.Printf("Koopa v%s\n", Version) } - -// runHelp displays the help message. -func runHelp() { - fmt.Println("Koopa - Your terminal AI personal assistant") - fmt.Println() - fmt.Println("Usage:") - fmt.Println(" koopa cli Start interactive chat mode") - fmt.Println(" koopa serve [addr] Start HTTP API server (default: 127.0.0.1:3400)") - fmt.Println(" koopa mcp Start MCP server (for Claude Desktop/Cursor)") - fmt.Println(" koopa --version Show version information") - fmt.Println(" koopa --help Show this help") - fmt.Println() - fmt.Println("CLI Commands (in interactive mode):") - fmt.Println(" /help Show available commands") - fmt.Println(" /version Show version") - fmt.Println(" /clear Clear conversation history") - fmt.Println(" /exit, /quit Exit Koopa") - fmt.Println() - fmt.Println("Shortcuts:") - fmt.Println(" Ctrl+D Exit Koopa") - fmt.Println(" Ctrl+C Cancel current input") - fmt.Println() - fmt.Println("Environment Variables:") - fmt.Println(" GEMINI_API_KEY Required: Gemini API key") - fmt.Println(" DEBUG Optional: Enable debug logging") - fmt.Println() - fmt.Println("Learn more: https://github.com/koopa0/koopa") -} diff --git a/db/migrate.go b/db/migrate.go index bbd2469..06ceb71 100644 --- a/db/migrate.go +++ b/db/migrate.go @@ -31,44 +31,38 @@ func Migrate(connURL string) error { // Create source driver from embedded filesystem source, err := iofs.New(migrationsFS, "migrations") if err != nil { - slog.Error("failed to create migration source", "error", err) - return fmt.Errorf("failed to create migration source: %w", err) + return fmt.Errorf("creating migration source: %w", err) } // Convert postgres:// or postgresql:// to pgx5:// scheme for golang-migrate pgx v5 driver dbURL, err := convertToMigrateURL(connURL) if err != nil { - slog.Error("invalid database URL", "error", err) return err } // Create migrate instance with pgx5 driver m, err := migrate.NewWithSourceInstance("iofs", source, dbURL) if err != nil { - slog.Error("failed to connect to database for migrations", "error", err) - return fmt.Errorf("failed to create migrate instance: %w", err) + return fmt.Errorf("creating migrate instance: %w", err) } + // best-effort: close errors are non-fatal during migration teardown defer func() { srcErr, dbErr := m.Close() if srcErr != nil { - slog.Warn("failed to close migration source", "error", srcErr) + slog.Warn("closing migration source", "error", srcErr) } if dbErr != nil { - slog.Warn("failed to close migration database connection", "error", dbErr) + slog.Warn("closing migration database connection", "error", dbErr) } }() // Check for dirty state before running migrations version, dirty, verErr := m.Version() if verErr != nil && !errors.Is(verErr, migrate.ErrNilVersion) { - slog.Error("failed to check migration version", "error", verErr) - return fmt.Errorf("failed to check migration version: %w", verErr) + return fmt.Errorf("checking migration version: %w", verErr) } if dirty { - slog.Error("database is in dirty migration state - manual intervention required", - "version", version, - "hint", fmt.Sprintf("inspect schema and run: migrate force %d", version)) - return fmt.Errorf("database in dirty state (version=%d), manual cleanup required", version) + return fmt.Errorf("database in dirty state (version=%d): inspect schema and run: migrate force %d", version, version) } // Run migrations @@ -78,16 +72,13 @@ func Migrate(connURL string) error { return nil } - // Check for dirty state after failure + // Include dirty state info in error if migration left database dirty postVersion, postDirty, postErr := m.Version() if postErr == nil && postDirty { - slog.Error("migration failed - database now in dirty state", - "version", postVersion, - "hint", fmt.Sprintf("fix the migration and run: migrate force %d", postVersion)) + return fmt.Errorf("running migrations (database now dirty at version %d, fix and run: migrate force %d): %w", postVersion, postVersion, err) } - slog.Error("failed to run migrations", "error", err) - return fmt.Errorf("failed to run migrations: %w", err) + return fmt.Errorf("running migrations: %w", err) } finalVersion, finalDirty, verErr := m.Version() @@ -108,7 +99,7 @@ func convertToMigrateURL(connURL string) (string, error) { // Parse URL to validate and extract components u, err := url.Parse(connURL) if err != nil { - return "", fmt.Errorf("failed to parse database URL: %w", err) + return "", fmt.Errorf("parsing database URL: %w", err) } // Validate scheme diff --git a/db/migrations/000001_init_schema.down.sql b/db/migrations/000001_init_schema.down.sql index a715372..18a34c2 100644 --- a/db/migrations/000001_init_schema.down.sql +++ b/db/migrations/000001_init_schema.down.sql @@ -2,31 +2,18 @@ -- Drops all objects created by 000001_init_schema.up.sql in reverse order -- ============================================================================ --- Drop Messages Table (including triggers and indexes) +-- Drop Messages Table -- ============================================================================ -DROP TRIGGER IF EXISTS update_message_updated_at ON message; -DROP INDEX IF EXISTS idx_message_content_gin; -DROP INDEX IF EXISTS idx_message_status; -DROP INDEX IF EXISTS idx_incomplete_messages; -DROP INDEX IF EXISTS idx_message_session_seq; -DROP INDEX IF EXISTS idx_message_session_id; -DROP TABLE IF EXISTS message; +DROP TABLE IF EXISTS messages; -- ============================================================================ --- Drop Sessions Table (including triggers and indexes) +-- Drop Sessions Table (including indexes) -- ============================================================================ -DROP TRIGGER IF EXISTS update_sessions_updated_at ON sessions; DROP INDEX IF EXISTS idx_sessions_updated_at; DROP TABLE IF EXISTS sessions; --- ============================================================================ --- Drop Helper Functions --- ============================================================================ - -DROP FUNCTION IF EXISTS update_updated_at_column(); - -- ============================================================================ -- Drop Documents Table (including indexes) -- ============================================================================ diff --git a/db/migrations/000001_init_schema.up.sql b/db/migrations/000001_init_schema.up.sql index 49cc5a5..45b6705 100644 --- a/db/migrations/000001_init_schema.up.sql +++ b/db/migrations/000001_init_schema.up.sql @@ -13,7 +13,7 @@ CREATE EXTENSION IF NOT EXISTS vector; CREATE TABLE IF NOT EXISTS documents ( id TEXT PRIMARY KEY, content TEXT NOT NULL, - embedding vector(768) NOT NULL, -- text-embedding-004 dimension + embedding vector(768) NOT NULL, -- gemini-embedding-001 truncated via OutputDimensionality source_type TEXT, -- Metadata column for filtering metadata JSONB -- Additional metadata in JSON format ); @@ -30,19 +30,6 @@ CREATE INDEX IF NOT EXISTS idx_documents_source_type ON documents(source_type); CREATE INDEX IF NOT EXISTS idx_documents_metadata_gin ON documents USING GIN (metadata jsonb_path_ops); --- ============================================================================ --- Helper Functions --- ============================================================================ - --- Auto-update updated_at timestamp -CREATE OR REPLACE FUNCTION update_updated_at_column() -RETURNS TRIGGER AS $$ -BEGIN - NEW.updated_at = NOW(); - RETURN NEW; -END; -$$ LANGUAGE plpgsql; - -- ============================================================================ -- Sessions Table -- ============================================================================ @@ -51,65 +38,24 @@ CREATE TABLE IF NOT EXISTS sessions ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), title TEXT, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - model_name TEXT, - system_prompt TEXT, - message_count INTEGER DEFAULT 0 + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON sessions(updated_at DESC); --- Use DO block for trigger (no IF NOT EXISTS syntax for triggers) -DO $$ -BEGIN - IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_sessions_updated_at') THEN - CREATE TRIGGER update_sessions_updated_at - BEFORE UPDATE ON sessions - FOR EACH ROW - EXECUTE FUNCTION update_updated_at_column(); - END IF; -END $$; - -- ============================================================================ -- Messages Table -- ============================================================================ -CREATE TABLE IF NOT EXISTS message ( +CREATE TABLE IF NOT EXISTS messages ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), session_id UUID NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, role TEXT NOT NULL, content JSONB NOT NULL, sequence_number INTEGER NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - status TEXT NOT NULL DEFAULT 'completed' - CHECK (status IN ('streaming', 'completed', 'failed')), - updated_at TIMESTAMPTZ DEFAULT NOW(), + -- UNIQUE constraint automatically creates index on (session_id, sequence_number) CONSTRAINT unique_message_sequence UNIQUE (session_id, sequence_number), CONSTRAINT message_role_check CHECK (role IN ('user', 'assistant', 'system', 'tool')) ); - -CREATE INDEX IF NOT EXISTS idx_message_session_id ON message(session_id); -CREATE INDEX IF NOT EXISTS idx_message_session_seq ON message(session_id, sequence_number); -CREATE INDEX IF NOT EXISTS idx_incomplete_messages ON message(session_id, updated_at) - WHERE status IN ('streaming', 'failed'); - --- Index for querying failed/streaming messages -CREATE INDEX IF NOT EXISTS idx_message_status ON message(session_id, status) - WHERE status != 'completed'; - --- Index for message.content (ai.Part array stored as JSONB) --- Enables fast queries like: WHERE content @> '[{"text": "search term"}]' -CREATE INDEX IF NOT EXISTS idx_message_content_gin - ON message USING GIN (content jsonb_path_ops); - --- Use DO block for trigger (no IF NOT EXISTS syntax for triggers) -DO $$ -BEGIN - IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_message_updated_at') THEN - CREATE TRIGGER update_message_updated_at - BEFORE UPDATE ON message - FOR EACH ROW - EXECUTE FUNCTION update_updated_at_column(); - END IF; -END $$; diff --git a/db/queries/documents.sql b/db/queries/documents.sql deleted file mode 100644 index 3773879..0000000 --- a/db/queries/documents.sql +++ /dev/null @@ -1,63 +0,0 @@ --- Documents queries for sqlc --- Generated code will be in internal/sqlc/documents.sql.go - --- name: UpsertDocument :exec -INSERT INTO documents (id, content, embedding, source_type, metadata) -VALUES ($1, $2, $3, $4, $5) -ON CONFLICT (id) DO UPDATE SET - content = EXCLUDED.content, - embedding = EXCLUDED.embedding, - source_type = EXCLUDED.source_type, - metadata = EXCLUDED.metadata; - --- name: SearchDocuments :many -SELECT id, content, metadata, - (1 - (embedding <=> sqlc.arg(query_embedding)::vector))::float8 AS similarity -FROM documents -WHERE metadata @> sqlc.arg(filter_metadata)::jsonb -ORDER BY similarity DESC -LIMIT sqlc.arg(result_limit); - --- name: SearchDocumentsAll :many -SELECT id, content, metadata, - (1 - (embedding <=> sqlc.arg(query_embedding)::vector))::float8 AS similarity -FROM documents -ORDER BY similarity DESC -LIMIT sqlc.arg(result_limit); - --- name: CountDocuments :one -SELECT COUNT(*) -FROM documents -WHERE metadata @> $1::jsonb; - --- name: CountDocumentsAll :one -SELECT COUNT(*) -FROM documents; - --- name: GetDocument :one -SELECT id, content, metadata -FROM documents -WHERE id = $1; - --- name: DeleteDocument :exec -DELETE FROM documents -WHERE id = $1; - --- ===== Optimized RAG Queries (SQL-level filtering) ===== - --- name: SearchBySourceType :many --- Generic search by source_type using dedicated indexed column -SELECT id, content, metadata, - (1 - (embedding <=> sqlc.arg(query_embedding)::vector))::float8 AS similarity -FROM documents -WHERE source_type = sqlc.arg(source_type)::text -ORDER BY similarity DESC -LIMIT sqlc.arg(result_limit); - --- name: ListDocumentsBySourceType :many --- List all documents by source_type using dedicated indexed column --- Used for listing indexed files without needing embeddings -SELECT id, content, metadata -FROM documents -WHERE source_type = sqlc.arg(source_type)::text -LIMIT sqlc.arg(result_limit); diff --git a/db/queries/sessions.sql b/db/queries/sessions.sql index 2258b3e..849204b 100644 --- a/db/queries/sessions.sql +++ b/db/queries/sessions.sql @@ -1,37 +1,26 @@ --- Sessions queries for sqlc +-- Sessions and messages queries for sqlc -- Generated code will be in internal/sqlc/sessions.sql.go -- name: CreateSession :one -INSERT INTO sessions (title, model_name, system_prompt) -VALUES ($1, $2, $3) +INSERT INTO sessions (title) +VALUES ($1) RETURNING *; -- name: Session :one -SELECT id, title, created_at, updated_at, model_name, system_prompt, message_count +SELECT id, title, created_at, updated_at FROM sessions WHERE id = $1; --- name: ListSessions :many -SELECT id, title, created_at, updated_at, model_name, system_prompt, message_count +-- name: Sessions :many +SELECT id, title, created_at, updated_at FROM sessions ORDER BY updated_at DESC LIMIT sqlc.arg(result_limit) OFFSET sqlc.arg(result_offset); --- name: ListSessionsWithMessages :many --- Only list sessions that have messages or titles (not empty sessions) --- This is used for sidebar to hide "New Chat" placeholder sessions -SELECT id, title, created_at, updated_at, model_name, system_prompt, message_count -FROM sessions -WHERE message_count > 0 OR title IS NOT NULL -ORDER BY updated_at DESC -LIMIT sqlc.arg(result_limit) -OFFSET sqlc.arg(result_offset); - -- name: UpdateSessionUpdatedAt :exec UPDATE sessions -SET updated_at = NOW(), - message_count = sqlc.arg(message_count) +SET updated_at = NOW() WHERE id = sqlc.arg(session_id); -- name: UpdateSessionTitle :exec @@ -47,57 +36,24 @@ WHERE id = $1; -- name: AddMessage :exec -- Add a message to a session -INSERT INTO message (session_id, role, content, sequence_number) +INSERT INTO messages (session_id, role, content, sequence_number) VALUES ($1, $2, $3, $4); -- name: Messages :many -- Get all messages for a session ordered by sequence -SELECT * -FROM message +SELECT id, session_id, role, content, sequence_number, created_at +FROM messages WHERE session_id = sqlc.arg(session_id) ORDER BY sequence_number ASC LIMIT sqlc.arg(result_limit) OFFSET sqlc.arg(result_offset); --- name: GetMaxSequenceNumber :one --- Get max sequence number for a session +-- name: MaxSequenceNumber :one +-- Get max sequence number for a session (returns 0 if no messages) SELECT COALESCE(MAX(sequence_number), 0)::integer AS max_seq -FROM message +FROM messages WHERE session_id = $1; --- name: CountMessages :one --- Count messages in a session -SELECT COUNT(*)::integer AS count -FROM message -WHERE session_id = sqlc.arg(session_id); - -- name: LockSession :one -- Locks the session row to prevent concurrent modifications SELECT id FROM sessions WHERE id = $1 FOR UPDATE; - --- name: DeleteMessages :exec --- Delete all messages in a session -DELETE FROM message -WHERE session_id = sqlc.arg(session_id); - --- name: AddMessageWithID :one --- Add message with pre-assigned ID and status (for streaming) -INSERT INTO message (id, session_id, role, content, status, sequence_number) -VALUES ($1, $2, $3, $4, $5, $6) -RETURNING *; - --- name: UpdateMessageContent :exec --- Update message content and mark as completed -UPDATE message -SET content = $2, - status = 'completed', - updated_at = NOW() -WHERE id = $1; - --- name: UpdateMessageStatus :exec --- Update message status (streaming/completed/failed) -UPDATE message -SET status = $2, - updated_at = NOW() -WHERE id = $1; - diff --git a/docs/architecture-report.md b/docs/architecture-report.md deleted file mode 100644 index 6a75196..0000000 --- a/docs/architecture-report.md +++ /dev/null @@ -1,1654 +0,0 @@ -# Koopa Project Architecture Report - -> Generated: 2026-02-04 | Branch: main | Commit: 26e635b - ---- - -## 1. Project Overview - -### Module Information - -| Field | Value | -|-------|-------| -| **Go Module** | `github.com/koopa0/koopa` | -| **Go Version** | 1.25.1 | -| **License** | MIT | -| **Entry Point** | `main.go` → `cmd.Execute()` | - -### Directory Structure (3 Levels) - -``` -koopa/ -├── main.go # Entry point → cmd.Execute() -├── go.mod / go.sum # Dependencies -├── Taskfile.yml # Build tasks (task CLI) -├── docker-compose.yml # PostgreSQL + SearXNG + Redis -├── .golangci.yml # Linter config -├── .mcp.json # MCP server config (Genkit) -├── .env.example # Environment template -├── config.example.yaml # Advanced config template (~/.koopa/config.yaml) -├── GENKIT.md # Genkit framework guidelines -├── CLAUDE.md # Claude Code project rules -│ -├── cmd/ # CLI Commands (Cobra-like, manual) -│ ├── execute.go # Command dispatcher -│ ├── cli.go # Interactive TUI mode (BubbleTea) -│ ├── serve.go # HTTP API server -│ ├── mcp.go # MCP server (stdio) -│ ├── version.go # Version display -│ ├── e2e_test.go # CLI E2E tests -│ └── integration_test.go # CLI integration tests -│ -├── internal/ -│ ├── agent/ # AI Agent abstraction -│ │ ├── errors.go # Sentinel errors -│ │ └── chat/ # Chat agent (Genkit Flow) -│ │ ├── chat.go # Core agent (Execute, ExecuteStream) -│ │ ├── flow.go # Genkit StreamingFlow definition -│ │ ├── retry.go # Exponential backoff retry -│ │ ├── circuit.go # Circuit breaker pattern -│ │ ├── tokens.go # Token budget management -│ │ └── *_test.go # Tests -│ │ -│ ├── app/ # Application lifecycle (DI) -│ │ ├── app.go # App container struct -│ │ ├── runtime.go # Runtime wrapper (App + Flow + cleanup) -│ │ ├── wire.go # Wire DI providers (10-step chain) -│ │ ├── wire_gen.go # Wire generated code -│ │ └── *_test.go # Tests -│ │ -│ ├── config/ # Configuration -│ │ ├── config.go # Main Config struct + Load() + Validate() -│ │ ├── tools.go # MCP/SearXNG/WebScraper config -│ │ └── *_test.go # Tests -│ │ -│ ├── knowledge/ # Knowledge store (placeholder) -│ │ -│ ├── log/ # Logger setup -│ │ └── log.go # slog wrapper (Logger = *slog.Logger) -│ │ -│ ├── mcp/ # MCP Server implementation -│ │ ├── doc.go # Architecture documentation -│ │ ├── server.go # MCP server (10 tools) -│ │ ├── file.go # File tool MCP handlers -│ │ ├── system.go # System tool MCP handlers -│ │ ├── network.go # Network tool MCP handlers -│ │ ├── util.go # Result → MCP conversion -│ │ └── *_test.go # Tests -│ │ -│ ├── observability/ # OpenTelemetry setup -│ │ └── datadog.go # OTLP HTTP exporter → Datadog Agent -│ │ -│ ├── rag/ # RAG retriever/indexer -│ │ ├── constants.go # Source types, schema config -│ │ ├── system.go # System knowledge indexing (6 docs) -│ │ └── doc.go # Package doc -│ │ -│ ├── security/ # Input validators (5 modules) -│ │ ├── path.go # Path traversal prevention -│ │ ├── command.go # Command injection prevention -│ │ ├── env.go # Env variable access control -│ │ ├── url.go # SSRF prevention -│ │ ├── prompt.go # Prompt injection detection -│ │ └── *_test.go # Tests (incl. fuzz) -│ │ -│ ├── session/ # Session persistence (PostgreSQL) -│ │ ├── types.go # Session, Message, History structs -│ │ ├── store.go # Store (CRUD, transactions) -│ │ ├── errors.go # Sentinel errors, status constants -│ │ ├── state.go # ~/.koopa/current_session persistence -│ │ └── *_test.go # Tests -│ │ -│ ├── sqlc/ # Generated SQL code (sqlc) -│ │ ├── db.go # DBTX interface, Queries struct -│ │ ├── models.go # Document, Message, Session models -│ │ ├── documents.sql.go # Document queries -│ │ └── sessions.sql.go # Session/message queries -│ │ -│ ├── testutil/ # Test utilities -│ │ ├── db.go # Testcontainer PostgreSQL setup -│ │ ├── embedder.go # Deterministic test embedder -│ │ └── logger.go # No-op logger -│ │ -│ ├── tools/ # Tool implementations -│ │ ├── types.go # Result, Error, Status types -│ │ ├── metadata.go # DangerLevel, ToolMetadata registry -│ │ ├── emitter.go # ToolEventEmitter interface -│ │ ├── events.go # WithEvents wrapper -│ │ ├── file.go # File tools (5 tools) -│ │ ├── system.go # System tools (3 tools) -│ │ ├── network.go # Network tools (2 tools) -│ │ ├── knowledge.go # Knowledge tools (3 tools) -│ │ └── *_test.go # Tests (incl. fuzz) -│ │ -│ ├── tui/ # Terminal UI (BubbleTea) -│ │ ├── tui.go # Model + Init + Update + View -│ │ ├── keys.go # Key bindings -│ │ ├── commands.go # Streaming tea.Cmd -│ │ ├── styles.go # Lipgloss styles -│ │ └── *_test.go # Tests -│ │ -│ └── web/ # HTTP server + Web UI -│ ├── server.go # Server, route registration, security headers -│ ├── middleware.go # Recovery, Logging, MethodOverride, Session, CSRF -│ ├── handlers/ -│ │ ├── chat.go # POST /genui/send, GET /genui/stream -│ │ ├── pages.go # GET /genui (main page) -│ │ ├── sessions.go # Session/CSRF management -│ │ ├── health.go # Health/ready probes -│ │ └── *_test.go # Tests (unit + integration + fuzz) -│ ├── sse/ -│ │ └── writer.go # SSE writer (chunks, done, error, tools) -│ ├── page/ -│ │ └── chat.templ # Chat page template -│ ├── layout/ -│ │ └── app.templ # Base HTML layout -│ ├── component/ -│ │ ├── message_bubble.templ -│ │ ├── sidebar.templ -│ │ ├── chat_input.templ -│ │ ├── empty_state.templ -│ │ ├── session_placeholders.templ -│ │ └── _reference/ # templUI Pro blocks (222 blocks) -│ ├── static/ # CSS (Tailwind), JS (HTMX, Elements, Prism) -│ └── e2e/ # Playwright E2E tests -│ -├── db/ # Database layer -│ ├── migrate.go # Embedded migration runner -│ ├── migrations/ -│ │ ├── 000001_init_schema.up.sql -│ │ └── 000001_init_schema.down.sql -│ └── queries/ -│ ├── documents.sql # 8 document queries -│ └── sessions.sql # 28 session/message queries -│ -├── build/ -│ ├── frontend/ # Tailwind CSS build config -│ └── sql/ -│ └── sqlc.yaml # SQLC code generation config -│ -├── prompts/ -│ └── koopa.prompt # Dotprompt system prompt (665 lines) -│ -├── scripts/ # Utility scripts -├── searxng/ # SearXNG Docker config -└── docs/ # Documentation -``` - -### Third-Party Dependencies - -| Dependency | Version | Purpose | -|---|---|---| -| **AI/LLM** | -| `github.com/firebase/genkit/go` | v1.2.0 | LLM orchestration, Flow, tools, Dotprompt | -| **Database** | -| `github.com/jackc/pgx/v5` | v5.7.6 | PostgreSQL driver (connection pooling) | -| `github.com/pgvector/pgvector-go` | v0.3.0 | pgvector extension (vector embeddings) | -| `github.com/golang-migrate/migrate/v4` | v4.19.1 | Database schema migrations | -| **MCP** | -| `github.com/modelcontextprotocol/go-sdk` | v1.1.0 | MCP server SDK (official) | -| **TUI/CLI** | -| `charm.land/bubbletea/v2` | v2.0.0-rc.2 | Interactive terminal UI framework | -| `charm.land/bubbles/v2` | v2.0.0-rc.1 | BubbleTea components (textarea, spinner) | -| `charm.land/lipgloss/v2` | v2.0.0-beta.3 | Terminal styling/formatting | -| `github.com/charmbracelet/glamour` | v0.10.0 | Markdown → terminal rendering | -| **Web** | -| `github.com/a-h/templ` | v0.3.960 | Go HTML template compiler (SSR) | -| **Web Scraping** | -| `github.com/gocolly/colly/v2` | v2.2.0 | Web scraping framework | -| `github.com/PuerkitoBio/goquery` | v1.11.0 | HTML parsing (jQuery-like) | -| `github.com/go-shiori/go-readability` | v0.0.0-20250217 | Article extraction (Readability) | -| **Configuration** | -| `github.com/spf13/viper` | v1.21.0 | Config file + env var management | -| `github.com/google/wire` | v0.7.0 | Compile-time dependency injection | -| **Utilities** | -| `github.com/google/uuid` | v1.6.0 | UUID generation | -| `github.com/google/jsonschema-go` | v0.3.0 | JSON Schema inference for tool inputs | -| `github.com/gofrs/flock` | v0.13.0 | File-based locking | -| `golang.org/x/sync` | v0.18.0 | errgroup for lifecycle management | -| `golang.org/x/time` | v0.14.0 | rate.Limiter for rate limiting | -| **Observability** | -| `go.opentelemetry.io/otel/sdk` | v1.38.0 | OpenTelemetry tracing SDK | -| `go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp` | v1.38.0 | OTLP HTTP exporter (Datadog) | -| **Testing** | -| `github.com/stretchr/testify` | v1.11.1 | Assertions and mocking | -| `github.com/testcontainers/testcontainers-go` | v0.40.0 | Docker containers for integration tests | -| `github.com/testcontainers/testcontainers-go/modules/postgres` | v0.40.0 | PostgreSQL testcontainer | -| `github.com/playwright-community/playwright-go` | v0.5200.1 | Browser automation (E2E tests) | -| `go.uber.org/goleak` | v1.3.0 | Goroutine leak detector | - ---- - -## 2. Core Type Definitions - -### 2.1 Agent Types - -#### `internal/agent/chat/chat.go` - -```go -// Chat is Koopa's main conversational agent - FULLY IMPLEMENTED -type Chat struct { - languagePrompt string // Resolved language for prompt template - maxTurns int // Agentic loop iterations (default: 5) - ragTopK int // RAG documents to retrieve - retryConfig RetryConfig // LLM retry settings - circuitBreaker *CircuitBreaker // Failure handling - rateLimiter *rate.Limiter // Proactive rate limiting (default: 10/s, burst 30) - tokenBudget TokenBudget // Context window limits - g *genkit.Genkit - retriever ai.Retriever // For RAG retrieval - sessions *session.Store // Session persistence - logger log.Logger - tools []ai.Tool // Pre-registered tools - toolRefs []ai.ToolRef // Cached for efficiency - toolNames string // Comma-separated for logging - prompt ai.Prompt // Cached Dotprompt instance -} -// Methods: New(Config), Execute(ctx, sessionID, input), ExecuteStream(ctx, sessionID, input, callback) - -// Config contains all required parameters for Chat agent -type Config struct { - Genkit *genkit.Genkit - Retriever ai.Retriever - SessionStore *session.Store - Logger log.Logger - Tools []ai.Tool - MaxTurns int - RAGTopK int - Language string - RetryConfig RetryConfig - CircuitBreakerConfig CircuitBreakerConfig - RateLimiter *rate.Limiter - TokenBudget TokenBudget -} - -// Response represents the complete result of an agent execution -type Response struct { - FinalText string - ToolRequests []*ai.ToolRequest -} - -// StreamCallback is called for each chunk of streaming response -type StreamCallback func(ctx context.Context, chunk *ai.ModelResponseChunk) error -``` - -#### `internal/agent/chat/flow.go` - -```go -// Flow type alias for Genkit Streaming Flow -type Flow = core.Flow[Input, Output, StreamChunk] - -const FlowName = "koopa/chat" - -type Input struct { - Query string `json:"query"` - SessionID string `json:"sessionId"` -} - -type Output struct { - Response string `json:"response"` - SessionID string `json:"sessionId"` -} - -type StreamChunk struct { - Text string `json:"text"` -} - -// Singleton management: -// InitFlow(g, chat) - Initialize once via sync.Once -// GetFlow() *Flow - Get initialized flow (panics if not initialized) -// ResetFlowForTesting() - Reset for test isolation -``` - -#### `internal/agent/chat/retry.go` - -```go -type RetryConfig struct { - MaxRetries int // Default: 3 - InitialInterval time.Duration // Default: 500ms - MaxInterval time.Duration // Default: 10s -} -// Methods: DefaultRetryConfig(), executeWithRetry(), retryableError() -``` - -#### `internal/agent/chat/circuit.go` - -```go -type CircuitState int // CircuitClosed=0, CircuitOpen=1, CircuitHalfOpen=2 - -type CircuitBreakerConfig struct { - FailureThreshold int // Default: 5 - SuccessThreshold int // Default: 2 - Timeout time.Duration // Default: 30s -} - -type CircuitBreaker struct { // Thread-safe (sync.RWMutex) - mu sync.RWMutex - state CircuitState - failures int - successes int - lastFailure time.Time - failureThreshold int - successThreshold int - timeout time.Duration -} -// Methods: NewCircuitBreaker(), Allow(), Success(), Failure(), State(), Reset() -``` - -#### `internal/agent/chat/tokens.go` - -```go -type TokenBudget struct { - MaxHistoryTokens int // Default: 8000 - MaxInputTokens int // Default: 2000 - ReservedTokens int // Default: 4000 -} -// Methods: DefaultTokenBudget(), estimateTokens(), truncateHistory() -``` - -#### `internal/agent/errors.go` - -```go -var ( - ErrInvalidSession = errors.New("invalid session") - ErrExecutionFailed = errors.New("execution failed") -) -``` - -### 2.2 Session / Message Types - -#### `internal/session/types.go` - -```go -// History - Thread-safe conversation history (sync.RWMutex) - FULLY IMPLEMENTED -type History struct { - mu sync.RWMutex - messages []*ai.Message -} -// Methods: NewHistory(), SetMessages(), Messages(), Add(), AddMessage(), Count(), Clear() - -// Session represents a conversation session -type Session struct { - ID uuid.UUID - Title string - CreatedAt time.Time - UpdatedAt time.Time - ModelName string - SystemPrompt string - MessageCount int -} - -// Message represents a single conversation message -type Message struct { - ID uuid.UUID - SessionID uuid.UUID - Role string // "user" | "assistant" | "system" | "tool" - Content []*ai.Part // Genkit Part slice (stored as JSONB) - Status string // "streaming" | "completed" | "failed" - SequenceNumber int - CreatedAt time.Time -} -``` - -#### `internal/session/store.go` - -```go -// Store manages session persistence with PostgreSQL backend - FULLY IMPLEMENTED -type Store struct { - queries *sqlc.Queries - pool *pgxpool.Pool - logger *slog.Logger -} -// Methods: New(), CreateSession(), GetSession(), ListSessions(), -// ListSessionsWithMessages(), DeleteSession(), UpdateSessionTitle(), -// AddMessages(), GetMessages(), AppendMessages(), GetHistory(), -// CreateMessagePair(), GetUserMessageBefore(), GetMessageByID(), -// UpdateMessageContent(), UpdateMessageStatus() - -// MessagePair represents a user-assistant message pair for streaming -type MessagePair struct { - UserMsgID uuid.UUID - AssistantMsgID uuid.UUID - UserSeq int32 - AssistantSeq int32 -} -``` - -#### `internal/session/errors.go` - -```go -const ( - StatusStreaming = "streaming" - StatusCompleted = "completed" - StatusFailed = "failed" -) - -var ( - ErrSessionNotFound = errors.New("session not found") - ErrMessageNotFound = errors.New("message not found") -) - -const ( - DefaultHistoryLimit int32 = 100 - MaxHistoryLimit int32 = 10000 - MinHistoryLimit int32 = 10 -) -``` - -### 2.3 Tool / MCP Types - -#### `internal/tools/types.go` - -```go -type Status string -const ( - StatusSuccess Status = "success" - StatusError Status = "error" -) - -type ErrorCode string -const ( - ErrCodeSecurity ErrorCode = "SecurityError" - ErrCodeNotFound ErrorCode = "NotFound" - ErrCodePermission ErrorCode = "PermissionDenied" - ErrCodeIO ErrorCode = "IOError" - ErrCodeExecution ErrorCode = "ExecutionError" - ErrCodeTimeout ErrorCode = "TimeoutError" - ErrCodeNetwork ErrorCode = "NetworkError" - ErrCodeValidation ErrorCode = "ValidationError" -) - -type Result struct { - Status Status `json:"status"` - Data any `json:"data,omitempty"` - Error *Error `json:"error,omitempty"` -} - -type Error struct { - Code ErrorCode `json:"code"` - Message string `json:"message"` - Details any `json:"details,omitempty"` -} -``` - -#### `internal/tools/metadata.go` - -```go -type DangerLevel int -const ( - DangerLevelSafe DangerLevel = iota // read_file, list_files, get_env, web_fetch - DangerLevelWarning // write_file (reversible) - DangerLevelDangerous // delete_file, execute_command (irreversible) - DangerLevelCritical // Reserved for future -) - -type ToolMetadata struct { - Name string - Description string - RequiresConfirmation bool - DangerLevel DangerLevel - IsDangerousFunc func(params map[string]any) bool - Category string -} -// Functions: GetToolMetadata(), GetAllToolMetadata(), IsDangerous(), -// RequiresConfirmation(), GetDangerLevel(), ListToolsByDangerLevel() -``` - -#### `internal/tools/emitter.go` - -```go -// ToolEventEmitter receives tool lifecycle events - INTERFACE -type ToolEventEmitter interface { - OnToolStart(name string) - OnToolComplete(name string) - OnToolError(name string) -} -// Functions: EmitterFromContext(ctx), ContextWithEmitter(ctx, emitter) -``` - -#### `internal/tools/file.go` - -```go -type FileTools struct { // FULLY IMPLEMENTED - pathVal *security.Path - logger log.Logger -} - -type ReadFileInput struct { - Path string `json:"path" jsonschema_description:"The file path to read"` -} -type WriteFileInput struct { - Path string `json:"path"` - Content string `json:"content"` -} -type ListFilesInput struct { - Path string `json:"path"` -} -type DeleteFileInput struct { - Path string `json:"path"` -} -type GetFileInfoInput struct { - Path string `json:"path"` -} -type FileEntry struct { - Name string `json:"name"` - Type string `json:"type"` // "file" | "directory" -} -``` - -#### `internal/tools/system.go` - -```go -type SystemTools struct { // FULLY IMPLEMENTED - cmdVal *security.Command - envVal *security.Env - logger log.Logger -} - -type ExecuteCommandInput struct { - Command string `json:"command"` - Args []string `json:"args,omitempty"` -} -type GetEnvInput struct { - Key string `json:"key"` -} -type CurrentTimeInput struct{} -``` - -#### `internal/tools/network.go` - -```go -type NetworkTools struct { // FULLY IMPLEMENTED - searchBaseURL string - searchClient *http.Client - fetchParallelism int - fetchDelay time.Duration - fetchTimeout time.Duration - urlValidator *security.URL - skipSSRFCheck bool - logger log.Logger -} - -type NetworkConfig struct { - SearchBaseURL string - FetchParallelism int - FetchDelay time.Duration - FetchTimeout time.Duration -} -``` - -#### `internal/tools/knowledge.go` - -```go -type KnowledgeTools struct { // FULLY IMPLEMENTED - retriever ai.Retriever - logger log.Logger -} - -type KnowledgeSearchInput struct { - Query string `json:"query"` - TopK int `json:"topK,omitempty"` -} -``` - -#### `internal/mcp/server.go` - -```go -type Server struct { // FULLY IMPLEMENTED - mcpServer *mcp.Server - fileTools *tools.FileTools - systemTools *tools.SystemTools - networkTools *tools.NetworkTools - name string - version string -} - -type Config struct { - Name string - Version string - FileTools *tools.FileTools - SystemTools *tools.SystemTools - NetworkTools *tools.NetworkTools -} -// Methods: NewServer(Config), Run(ctx, transport), registerTools() -``` - -### 2.4 Security Types - -#### `internal/security/path.go` - -```go -type Path struct { // FULLY IMPLEMENTED - allowedDirs []string - workDir string -} -// Methods: NewPath(allowedDirs), Validate(path) (string, error) - -var ( - ErrPathOutsideAllowed = errors.New("path is outside allowed directories") - ErrSymlinkOutsideAllowed = errors.New("symbolic link points outside allowed directories") - ErrPathNullByte = errors.New("path contains null byte") -) -``` - -#### `internal/security/command.go` - -```go -type Command struct { // FULLY IMPLEMENTED - WHITELIST mode - blacklist []string - whitelist []string -} -// Methods: NewCommand(), ValidateCommand(cmd, args), QuoteCommandArgs(args) -// 53 whitelisted commands: ls, cat, grep, git, go, npm, etc. -``` - -#### `internal/security/env.go` - -```go -type Env struct { // FULLY IMPLEMENTED - sensitivePatterns []string // 87 patterns blocked -} -// Methods: NewEnv(), Validate(key), GetAllowedEnvNames() -// 23 allowed variables: PATH, HOME, GOPATH, etc. -``` - -#### `internal/security/url.go` - -```go -type URL struct { // FULLY IMPLEMENTED - allowedSchemes map[string]struct{} - blockedHosts map[string]struct{} -} -// Methods: NewURL(), Validate(rawURL), SafeTransport(), ValidateRedirect() -// Blocks: localhost, metadata endpoints, private IPs, link-local -``` - -#### `internal/security/prompt.go` - -```go -type PromptValidator struct { // FULLY IMPLEMENTED - patterns []*regexp.Regexp // 13 patterns -} - -type PromptInjectionResult struct { - Safe bool - Patterns []string -} -// Methods: NewPromptValidator(), Validate(input), IsSafe(input) -``` - -### 2.5 Config Types - -#### `internal/config/config.go` - -```go -type Config struct { // FULLY IMPLEMENTED - // AI - ModelName string - Temperature float32 - MaxTokens int - Language string - PromptDir string - // History - MaxHistoryMessages int32 - MaxTurns int - // Database - DatabasePath string - PostgresHost string - PostgresPort int - PostgresUser string - PostgresPassword string // Masked in MarshalJSON - PostgresDBName string - PostgresSSLMode string - // RAG - RAGTopK int - EmbedderModel string - // MCP - MCP MCPConfig - MCPServers map[string]MCPServer - // Tools - SearXNG SearXNGConfig - WebScraper WebScraperConfig - // Observability - Datadog DatadogConfig - // Security - HMACSecret string // Masked in MarshalJSON -} -// Methods: Load(), Validate(), MarshalJSON() -``` - -#### `internal/config/tools.go` - -```go -type MCPConfig struct { - Allowed []string - Excluded []string - Timeout int -} - -type MCPServer struct { - Command string - Args []string - Env map[string]string - Timeout int - IncludeTools []string - ExcludeTools []string -} - -type SearXNGConfig struct { - BaseURL string -} - -type WebScraperConfig struct { - Parallelism int - DelayMs int - TimeoutMs int -} -``` - -### 2.6 App / Runtime Types - -#### `internal/app/app.go` - -```go -type App struct { // FULLY IMPLEMENTED - Config *config.Config - Genkit *genkit.Genkit - Embedder ai.Embedder - DBPool *pgxpool.Pool - DocStore *postgresql.DocStore - Retriever ai.Retriever - SessionStore *session.Store - PathValidator *security.Path - Tools []ai.Tool - ctx context.Context - cancel context.CancelFunc - eg *errgroup.Group - egCtx context.Context -} -// Methods: Close(), Wait(), Go(func() error), CreateAgent(ctx) -``` - -#### `internal/app/runtime.go` - -```go -type Runtime struct { // FULLY IMPLEMENTED - App *App - Flow *chat.Flow - cleanup func() -} -// Factory: NewRuntime(ctx, cfg) → single init point for all entry points -// Methods: Close() error (App.Close() → Wire cleanup) -``` - -### 2.7 Web Handler Types - -#### `internal/web/handlers/chat.go` - -```go -// SSEWriter - INTERFACE -type SSEWriter interface { - WriteChunkRaw(msgID, htmlContent string) error - WriteDone(ctx context.Context, msgID string, comp templ.Component) error - WriteError(msgID, code, message string) error - WriteSidebarRefresh(sessionID, title string) error -} - -type ChatConfig struct { - Logger *slog.Logger - Genkit *genkit.Genkit - Flow *chat.Flow - Sessions *Sessions - SSEWriterFn func(w http.ResponseWriter) (SSEWriter, error) -} - -type Chat struct { // FULLY IMPLEMENTED - logger *slog.Logger - genkit *genkit.Genkit - flow *chat.Flow - sessions *Sessions - sseWriterFn func(w http.ResponseWriter) (SSEWriter, error) -} -// Methods: NewChat(ChatConfig), Send(w, r), Stream(w, r) - -type streamState struct { - msgID string - sessionID string - buffer strings.Builder -} -``` - -#### `internal/web/handlers/sessions.go` - -```go -type Sessions struct { // FULLY IMPLEMENTED - store *session.Store - hmacSecret []byte - isDev bool -} -// Methods: NewSessions(), GetOrCreate(r), ID(r), NewCSRFToken(sessionID), -// CheckCSRF(r), RegisterRoutes(mux) -``` - -#### `internal/web/sse/writer.go` - -```go -type Writer struct { // FULLY IMPLEMENTED - w io.Writer - flusher http.Flusher -} -// Methods: NewWriter(), WriteChunk(), WriteChunkRaw(), WriteDone(), -// WriteError(), WriteSidebarRefresh(), -// WriteToolStart(), WriteToolComplete(), WriteToolError() -``` - -### 2.8 TUI Types - -#### `internal/tui/tui.go` - -```go -type State int -const ( - StateInput State = iota // Awaiting input - StateThinking // Processing request - StateStreaming // Streaming response -) - -type Message struct { - Role string // "user" | "assistant" | "system" | "error" - Content string -} - -type TUI struct { // FULLY IMPLEMENTED - BubbleTea Model - input textarea.Model - history []string - historyIdx int - state State - lastCtrlC time.Time - spinner spinner.Model - output strings.Builder - viewBuf strings.Builder - messages []Message - streamCancel context.CancelFunc - streamEventCh <-chan streamEvent - chatFlow *chat.Flow - sessionID string - ctx context.Context - ctxCancel context.CancelFunc - width int - height int - styles Styles - markdown *markdownRenderer -} -// Implements tea.Model: Init(), Update(msg), View() -``` - ---- - -## 3. Genkit Integration - -### 3.1 Initialization - -**Location**: `internal/app/wire.go` lines 107-130 - -```go -g := genkit.Init(ctx, - genkit.WithPlugins(&googlegenai.GoogleAI{}, postgres), - genkit.WithPromptDir(promptDir), // Default: "prompts" -) -``` - -**Plugins Used**: -| Plugin | Purpose | -|--------|---------| -| `googlegenai.GoogleAI{}` | Gemini model access (chat + embeddings) | -| `postgresql.Postgres{}` | DocStore + Retriever for RAG (pgvector) | - -**Initialization Order** (Wire DI): -1. OpenTelemetry setup (tracing before Genkit) -2. Database pool (migrations run) -3. PostgreSQL plugin creation -4. Genkit Init with plugins -5. Embedder provision -6. RAG components (DocStore + Retriever) -7. Session store -8. Security validators -9. Tool registration (all 13 tools) -10. App construction - -### 3.2 Flow Definitions - -**One flow defined**: `koopa/chat` - -| Flow Name | Type | Input | Output | Stream Type | -|-----------|------|-------|--------|-------------| -| `koopa/chat` | `genkit.DefineStreamingFlow` | `Input{Query, SessionID}` | `Output{Response, SessionID}` | `StreamChunk{Text}` | - -**Flow lifecycle**: -- Singleton via `sync.Once` in `InitFlow(g, chat)` -- Access via `GetFlow()` (panics if not initialized) -- `ResetFlowForTesting()` for test isolation - -**Flow implementation** (`internal/agent/chat/flow.go` lines 87-149): -1. Parse session UUID from `input.SessionID` -2. Wrap `streamCb` into `StreamCallback` (adapts `StreamChunk` → `ai.ModelResponseChunk`) -3. Call `chat.ExecuteStream(ctx, sessionID, query, callback)` -4. Return `Output{Response, SessionID}` - -### 3.3 Tool Calling - -**13 tools registered** at app initialization via `genkit.DefineTool()`: - -| Category | Tool Name | Danger Level | -|----------|-----------|-------------| -| **File** | `read_file` | Safe | -| | `write_file` | Warning | -| | `list_files` | Safe | -| | `delete_file` | Dangerous | -| | `get_file_info` | Safe | -| **System** | `current_time` | Safe | -| | `execute_command` | Dangerous | -| | `get_env` | Safe | -| **Network** | `web_search` | Safe | -| | `web_fetch` | Safe | -| **Knowledge** | `search_history` | Safe | -| | `search_documents` | Safe | -| | `search_system_knowledge` | Safe | - -**Tool middleware**: `WithEvents` wrapper — emits `OnToolStart`/`OnToolComplete`/`OnToolError` lifecycle events via `ToolEventEmitter` interface (used for SSE tool status display). - -**No other middleware** (no rate limiting, no caching, no request/response interceptors). - -**Agentic loop**: `ai.WithMaxTurns(maxTurns)` — default 5 turns. LLM can call tools iteratively. - -### 3.4 Structured Output - -**Not used**. The flow uses plain string output: -- Input: JSON object `{query, sessionId}` -- Output: JSON object `{response, sessionId}` -- Streaming: JSON chunks `{text}` - -Tool inputs use `jsonschema_description` tags for schema inference, but no output schema validation. - -### 3.5 Session Management - -**Custom** (NOT Genkit built-in). PostgreSQL-backed via `session.Store`: - -``` -Agent.ExecuteStream() - → sessions.GetHistory(ctx, sessionID) // Load []*ai.Message from DB - → generateResponse(ctx, input, messages) // LLM call - → sessions.AppendMessages(ctx, sessionID) // Persist new messages -``` - -Messages stored as JSONB in `message.content` column, serialized from `[]*ai.Part`. - -### 3.6 Streaming - -**Two-layer architecture**: - -1. **Agent → Flow** (internal): `ai.WithStreaming(callback)` where callback receives `*ai.ModelResponseChunk` -2. **Flow → HTTP** (SSE): Go 1.23 `range-over-func` iterator over `flow.Stream(ctx, input)` - -**SSE endpoint**: `GET /genui/stream?msgId=X&session_id=Y` -- 5-minute timeout -- HTML escaping (XSS prevention) -- OOB swaps for HTMX (sidebar refresh, final message replacement) - ---- - -## 4. Data Layer - -### 4.1 Database - -| Field | Value | -|-------|-------| -| **DBMS** | PostgreSQL 17 with pgvector extension | -| **Driver** | `pgx/v5` (connection pooling) | -| **Pool Config** | MaxConns=10, MinConns=2, MaxLifetime=30m, MaxIdle=5m, HealthCheck=1m | -| **Vector Extension** | pgvector (vector(768), HNSW index, cosine distance) | -| **Query Generation** | sqlc (type-safe Go from SQL) | -| **Migration Tool** | golang-migrate v4 | -| **Docker Image** | `pgvector/pgvector:pg17` | - -### 4.2 SQLC Configuration - -**Config**: `build/sql/sqlc.yaml` -- Engine: PostgreSQL -- Driver: `pgx/v5` -- Output: `internal/sqlc/` -- Features: JSON tags, empty slices (not nil), UUID type override - -**Query Files**: - -| File | Queries | Purpose | -|------|---------|---------| -| `db/queries/documents.sql` | 8 | Document CRUD + vector search | -| `db/queries/sessions.sql` | 28 | Session/message CRUD + streaming lifecycle | - -### 4.3 Migration Files - -Single migration: `db/migrations/000001_init_schema.up.sql` - -Migrations are **embedded** via `//go:embed` and run automatically at startup in `provideDBPool()`. - -### 4.4 Complete Schema (CREATE TABLE) - -```sql --- Extension -CREATE EXTENSION IF NOT EXISTS vector; - --- Helper function -CREATE OR REPLACE FUNCTION update_updated_at_column() -RETURNS TRIGGER AS $$ -BEGIN - NEW.updated_at = NOW(); - RETURN NEW; -END; -$$ LANGUAGE plpgsql; - --- Table: documents (RAG/Knowledge Store) -CREATE TABLE IF NOT EXISTS documents ( - id TEXT PRIMARY KEY, - content TEXT NOT NULL, - embedding vector(768) NOT NULL, - source_type TEXT, - metadata JSONB -); -CREATE INDEX IF NOT EXISTS idx_documents_embedding - ON documents USING hnsw (embedding vector_cosine_ops) - WITH (m = 16, ef_construction = 64); -CREATE INDEX IF NOT EXISTS idx_documents_source_type ON documents(source_type); -CREATE INDEX IF NOT EXISTS idx_documents_metadata_gin - ON documents USING GIN (metadata jsonb_path_ops); - --- Table: sessions -CREATE TABLE IF NOT EXISTS sessions ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - title TEXT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - model_name TEXT, - system_prompt TEXT, - message_count INTEGER DEFAULT 0 -); -CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON sessions(updated_at DESC); -CREATE TRIGGER update_sessions_updated_at - BEFORE UPDATE ON sessions FOR EACH ROW - EXECUTE FUNCTION update_updated_at_column(); - --- Table: message -CREATE TABLE IF NOT EXISTS message ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - session_id UUID NOT NULL REFERENCES sessions(id) ON DELETE CASCADE, - role TEXT NOT NULL, - content JSONB NOT NULL, - sequence_number INTEGER NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - status TEXT NOT NULL DEFAULT 'completed' - CHECK (status IN ('streaming', 'completed', 'failed')), - updated_at TIMESTAMPTZ DEFAULT NOW(), - CONSTRAINT unique_message_sequence UNIQUE (session_id, sequence_number), - CONSTRAINT message_role_check CHECK (role IN ('user', 'assistant', 'system', 'tool')) -); -CREATE INDEX IF NOT EXISTS idx_message_session_id ON message(session_id); -CREATE INDEX IF NOT EXISTS idx_message_session_seq ON message(session_id, sequence_number); -CREATE INDEX IF NOT EXISTS idx_incomplete_messages ON message(session_id, updated_at) - WHERE status IN ('streaming', 'failed'); -CREATE INDEX IF NOT EXISTS idx_message_status ON message(session_id, status) - WHERE status != 'completed'; -CREATE INDEX IF NOT EXISTS idx_message_content_gin - ON message USING GIN (content jsonb_path_ops); -CREATE TRIGGER update_message_updated_at - BEFORE UPDATE ON message FOR EACH ROW - EXECUTE FUNCTION update_updated_at_column(); -``` - -### 4.5 pgvector Usage - -| Aspect | Detail | -|--------|--------| -| **Dimension** | 768 (text-embedding-004 model) | -| **Index Type** | HNSW (m=16, ef_construction=64) | -| **Distance Metric** | Cosine (`<=>` operator) | -| **Similarity Formula** | `(1 - (embedding <=> query_embedding))::float8` | -| **Table** | `documents` | -| **Source Types** | `conversation`, `file`, `system` | -| **System Knowledge** | 6 pre-indexed documents at startup | - ---- - -## 5. Channel Adapter Implementation - -### Current State: No External Channel Adapters - -The project has **three access channels**, all built-in: - -| Channel | Implementation | Transport | -|---------|---------------|-----------| -| **CLI/TUI** | `internal/tui/` (BubbleTea) | Terminal stdin/stdout | -| **Web Chat** | `internal/web/` (Templ + HTMX) | HTTP + SSE | -| **MCP** | `internal/mcp/` (MCP SDK) | stdio | - -**No Telegram, LINE, Discord, Slack, or WhatsApp adapters exist.** - -### Web Chat Details - -| Aspect | Detail | -|--------|--------| -| **Main Files** | `internal/web/server.go`, `handlers/chat.go`, `handlers/pages.go`, `handlers/sessions.go`, `sse/writer.go` | -| **Template Engine** | Templ (Go SSR) | -| **CSS Framework** | Tailwind CSS (compiled, embedded) | -| **Interactivity** | HTMX + Tailwind Plus Elements | -| **Streaming** | Server-Sent Events (SSE) | -| **Session** | HTTP cookie (30-day expiry) + CSRF tokens | -| **Routes** | `GET /genui`, `POST /genui/send`, `GET /genui/stream`, `POST /genui/sessions/*` | - -### CLI/TUI Details - -| Aspect | Detail | -|--------|--------| -| **Framework** | BubbleTea v2 | -| **State Machine** | 3 states: Input → Thinking → Streaming | -| **Streaming** | Go 1.23 range-over-func iterator, discriminated union channel | -| **Key Bindings** | Enter=submit, Shift+Enter=newline, Esc=cancel, Ctrl+C×2=quit | -| **Markdown** | Glamour rendering (graceful degradation to plain text) | -| **History** | In-memory (max 100 entries), Up/Down navigation | - ---- - -## 6. MCP Integration - -### 6.1 SDK - -| Field | Value | -|-------|-------| -| **SDK** | `github.com/modelcontextprotocol/go-sdk` v1.1.0 (official) | -| **Role** | **MCP Server** only (not client) | -| **Transport** | Stdio | -| **Entry Point** | `cmd/mcp.go` → `mcp.NewServer(config).Run(ctx, &StdioTransport{})` | - -### 6.2 MCP Server Tools (10 tools) - -Same as Genkit tools minus the 3 knowledge tools (knowledge tools require Genkit retriever, MCP server doesn't initialize full Genkit): - -| Tool | Description | -|------|-------------| -| `read_file` | Read file content (max 10MB) | -| `write_file` | Create/overwrite files | -| `list_files` | List directory contents | -| `delete_file` | Delete files | -| `get_file_info` | File metadata | -| `current_time` | System time | -| `execute_command` | Whitelisted commands | -| `get_env` | Non-sensitive env vars | -| `web_search` | SearXNG search | -| `web_fetch` | URL fetch with SSRF protection | - -### 6.3 Tool Registry Management - -- Tools registered via `mcp.AddTool()` in `Server.registerTools()` -- Input schemas auto-generated using `jsonschema.For[T](nil)` -- Results converted via `resultToMCP()` with error detail sanitization (blocks stack traces, file paths, API keys) - -### 6.4 Connected MCP Servers - -The `.mcp.json` configures one MCP server for **development** use with Claude Code: - -```json -{ - "mcpServers": { - "genkit": { - "command": "genkit", - "args": ["mcp", "--no-update-notification"] - } - } -} -``` - -`config.example.yaml` shows additional configurable MCP servers (fetch, filesystem, github) but these are **configuration templates**, not actively connected. - ---- - -## 7. Permission / Security - -### 7.1 Security Validators (5 modules) - -| Validator | File | Prevents | Mechanism | -|-----------|------|----------|-----------| -| **Path** | `security/path.go` | Directory traversal (CWE-22) | Whitelist allowed dirs, symlink resolution, null byte rejection | -| **Command** | `security/command.go` | Command injection (CWE-78) | Whitelist 53 safe commands, argument validation | -| **Env** | `security/env.go` | Info leakage | Block 87 sensitive patterns, allow 23 safe vars | -| **URL** | `security/url.go` | SSRF (CWE-918) | Block private IPs, metadata endpoints, DNS rebinding protection | -| **Prompt** | `security/prompt.go` | Prompt injection | 13 regex patterns, Unicode normalization | - -### 7.2 Tool Call Permission Check - -Security validation is **inline** — each tool validates its inputs before execution: -- `FileTools.ReadFile()` → `pathVal.Validate(input.Path)` -- `SystemTools.ExecuteCommand()` → `cmdVal.ValidateCommand(cmd, args)` -- `SystemTools.GetEnv()` → `envVal.Validate(key)` -- `NetworkTools.WebFetch()` → `urlValidator.Validate(url)` + `SafeTransport()` - -Security failures return `Result{Status: StatusError, Error: &Error{Code: ErrCodeSecurity}}` — business errors, not Go errors. This allows the LLM to handle rejections gracefully. - -### 7.3 Approval Flow - -**Not implemented**. `ToolMetadata.RequiresConfirmation` and `DangerLevel` are defined but there is no runtime approval mechanism. The metadata system is in place as infrastructure for a future approval flow. - -### 7.4 Audit Logging - -**slog-based** security event logging with structured fields: - -```go -logger.Warn("security_event", - "type", "path_traversal_attempt", - "path", unsafePath, - "allowed_dirs", allowedDirs, -) -``` - -Events logged for: path traversal, command injection, sensitive env access, SSRF blocks, symlink traversal, prompt injection. - -**No database-backed audit log**. All events go to application logs only. - ---- - -## 8. Frontend - -### 8.1 Technology Stack - -| Aspect | Technology | -|--------|-----------| -| **Template Engine** | Templ v0.3.960 (Go SSR, type-safe) | -| **CSS Framework** | Tailwind CSS (compiled, embedded in binary) | -| **Interactivity** | HTMX + HTMX SSE Extension | -| **Client-side** | Tailwind Plus Elements (dropdowns, modals) | -| **Code Highlighting** | Prism.js | -| **No SPA framework** | No Angular, React, or Vue | - -### 8.2 Pages and Components - -| Type | File | Purpose | -|------|------|---------| -| **Layout** | `layout/app.templ` | Base HTML (head, body, scripts) | -| **Page** | `page/chat.templ` | Two-column chat layout (sidebar + feed) | -| **Component** | `component/message_bubble.templ` | User/assistant message rendering | -| | `component/sidebar.templ` | Session list sidebar | -| | `component/chat_input.templ` | Message input form | -| | `component/empty_state.templ` | Empty state view | -| | `component/session_placeholders.templ` | Loading skeletons | - -### 8.3 Real-time Updates - -- **SSE** (Server-Sent Events) for streaming AI responses -- **HTMX OOB swaps** for sidebar refresh, session field updates -- **No WebSocket** implementation - -### 8.4 Dashboard - -**Not implemented**. Only a chat interface exists. - -### 8.5 Static Asset Management - -- **Production**: Assets embedded via `//go:embed` (single binary deployment) -- **Development**: `assets_dev.go` loads from filesystem (build tag: `dev`) -- **JS Libraries**: HTMX 2.x, htmx-sse extension, Tailwind Plus Elements, Prism.js — all vendored locally (no CDN) - ---- - -## 9. Deployment & Configuration - -### 9.1 Docker Compose - -```yaml -services: - postgres: - image: pgvector/pgvector:pg17 - ports: ["5432:5432"] - environment: - POSTGRES_USER: koopa - POSTGRES_PASSWORD: koopa_dev_password - POSTGRES_DB: koopa - volumes: - - postgres_data:/var/lib/postgresql/data - - ./db/migrations:/docker-entrypoint-initdb.d # Auto-init - healthcheck: pg_isready -U koopa - - searxng: - image: searxng/searxng:latest - ports: ["8888:8080"] - depends_on: [redis] - healthcheck: wget --no-verbose --tries=1 http://localhost:8080/healthz - - redis: - image: valkey/valkey:8-alpine - healthcheck: valkey-cli ping -``` - -**No Dockerfile for Koopa itself** — built as a Go binary, not containerized. - -### 9.2 Environment Variables - -| Variable | Required | Default | Purpose | -|----------|----------|---------|---------| -| `GEMINI_API_KEY` | Yes | — | Google Gemini API key | -| `DATABASE_URL` | Yes | — | PostgreSQL connection string | -| `HMAC_SECRET` | Web only | — | CSRF token signing (min 32 chars) | -| `KOOPA_MODEL_NAME` | No | `gemini-2.5-flash` | LLM model name | -| `KOOPA_TEMPERATURE` | No | `0.7` | LLM temperature | -| `KOOPA_MAX_TOKENS` | No | `2048` | Max response tokens | -| `KOOPA_MAX_HISTORY_MESSAGES` | No | `100` | Conversation history limit | -| `KOOPA_RAG_TOP_K` | No | `3` | RAG documents to retrieve | -| `DEBUG` | No | `false` | Debug logging | -| `DD_API_KEY` | No | — | Datadog APM | - -### 9.3 Configuration File - -**Location**: `~/.koopa/config.yaml` - -Used for MCP server definitions and advanced tool configuration. See `config.example.yaml` for template. - -### 9.4 Startup Flow - -``` -main.go - → cmd.Execute() - → Parse flags (--version, --help) - → Dispatch to: runCLI() | RunServe() | RunMCP() - -runCLI(): - → config.Load() - → signal.NotifyContext(SIGINT, SIGTERM) - → app.NewRuntime(ctx, cfg) - → Wire: InitializeApp() (10-step provider chain) - → chat.New(Config{...}) # Create agent - → chat.InitFlow(g, agent) # Register Genkit Flow (sync.Once) - → Load/create session ID (~/.koopa/current_session) - → tui.New(ctx, flow, sessionID) - → tea.NewProgram(model).Run() - -RunServe(): - → Validate HMAC_SECRET - → app.NewRuntime(ctx, cfg) - → web.NewServer(ServerConfig{...}) - → http.Server{...} with timeouts - → Graceful shutdown on SIGINT/SIGTERM - -RunMCP(): - → app.InitializeApp(ctx, cfg) # Wire directly - → Create FileTools, SystemTools, NetworkTools - → mcp.NewServer(Config{...}) - → server.Run(ctx, &StdioTransport{}) -``` - ---- - -## 10. Testing - -### 10.1 Test Types - -| Type | Count | Build Tag | Runner | -|------|-------|-----------|--------| -| **Unit Tests** | ~70 files | none | `go test ./...` | -| **Integration Tests** | ~10 files | `integration` | `go test -tags integration ./...` | -| **Fuzz Tests** | 4 files | none | `go test -fuzz=FuzzXxx ./...` | -| **E2E Tests** | ~5 files | `e2e` | Playwright browser automation | -| **Race Tests** | 1 file | none | `go test -race ./...` | - -### 10.2 Testing Patterns - -**Unit Test Example** (handler testing): -```go -func TestChat_Send(t *testing.T) { - fw := SetupTest(t) // Test framework with mock deps - defer fw.Cleanup() - req := httptest.NewRequest(...) - w := httptest.NewRecorder() - fw.Chat.Send(w, req) - assert.Equal(t, 200, w.Code) -} -``` - -**Integration Test Example** (real database): -```go -//go:build integration - -func TestPages_Chat_LoadsHistoryFromDatabase(t *testing.T) { - // Uses testcontainers PostgreSQL - db := testutil.SetupTestDB(t) - store := session.New(db, logger) - // ... test with real DB operations -} -``` - -**Fuzz Test Example** (security): -```go -func FuzzMessageContent(f *testing.F) { - f.Add("normal text") - f.Add("") - f.Add("'; DROP TABLE messages; --") - f.Fuzz(func(t *testing.T, content string) { - // Verify no panic, proper escaping - }) -} -``` - -### 10.3 Test Infrastructure - -| Component | Location | Purpose | -|-----------|----------|---------| -| `testutil.SetupTestDB()` | `internal/testutil/db.go` | PostgreSQL testcontainer | -| `testutil.DiscardLogger()` | `internal/testutil/logger.go` | No-op logger | -| Deterministic embedder | `internal/testutil/embedder.go` | Predictable embeddings for tests | -| Mock SSE Writer | `internal/web/handlers/chat_test.go` | Records SSE events | -| Browser fixture | `internal/web/fixture_test.go` | Playwright browser context | -| SSE event parser | `internal/testutil/` | Parse SSE format | - -### 10.4 Coverage - -Not explicitly measured (no CI coverage badge), but `coverage.out` exists at root. Test infrastructure suggests moderate-to-good coverage for core packages (agent, tools, security, handlers). - ---- - -## 11. Architecture Flow Diagram (Working) - -### Message Processing Flow (Web Chat) - -``` -User types message in browser - │ - ▼ -POST /genui/send (HTMX form submission) - │ - ├── Parse form: content, session_id, csrf_token - ├── Validate CSRF token (HMAC-SHA256) - ├── Lazy session creation (if first message) - │ - ├── Render HTML response: - │ ├── User message bubble (OOB swap into #message-feed) - │ ├── Assistant skeleton with SSE connection: - │ │
- │ └── OOB swaps: session_id field, csrf token refresh - │ - ▼ -GET /genui/stream (SSE endpoint) - │ - ├── Parse query params: msgId, session_id, query - ├── Create 5-minute timeout context - │ - ├── flow.Stream(ctx, Input{Query, SessionID}) - │ │ - │ ▼ - │ Chat.ExecuteStream(ctx, sessionID, query, callback) - │ │ - │ ├── sessions.GetHistory(ctx, sessionID) # Load from PostgreSQL - │ ├── deepCopyMessages(history) # Prevent Genkit data race - │ ├── truncateHistory(messages, tokenBudget) # Context window management - │ ├── retrieveRAGContext(ctx, query) # pgvector search (5s timeout) - │ │ - │ ├── circuitBreaker.Allow() # Check circuit state - │ ├── executeWithRetry(ctx, opts) # Exponential backoff - │ │ │ - │ │ ├── rateLimiter.Wait(ctx) # Rate limiting - │ │ └── prompt.Execute(ctx, opts...) # Genkit → Gemini API - │ │ │ - │ │ ├── ai.WithTools(toolRefs...) # 13 tools available - │ │ ├── ai.WithMaxTurns(5) # Agentic loop - │ │ ├── ai.WithStreaming(callback) # Stream chunks - │ │ └── ai.WithDocs(ragDocs...) # RAG context - │ │ │ - │ │ ▼ - │ │ [Gemini processes, may call tools] - │ │ │ - │ │ ├── Tool call → security validation → execute → Result - │ │ ├── Tool result → back to Gemini → next turn - │ │ └── Final text response - │ │ - │ ├── circuitBreaker.Success() - │ ├── sessions.AppendMessages(ctx, sessionID, [user, assistant]) - │ └── Return Response{FinalText, ToolRequests} - │ - ├── For each StreamChunk: - │ └── sseWriter.WriteChunkRaw(msgID, escapedHTML) → SSE "chunk" event - │ - ├── maybeGenerateTitle(sessionID, content) # AI title gen - ├── sseWriter.WriteSidebarRefresh(sessionID, title) # HX-Trigger - └── sseWriter.WriteDone(ctx, msgID, finalComponent) # SSE "done" event - │ - ▼ -Browser: HTMX processes SSE events - ├── "chunk" events → swap into assistant bubble (innerHTML) - ├── Sidebar refresh → re-render session list - └── "done" event → close SSE connection -``` - -### Message Processing Flow (CLI/TUI) - -``` -User types in terminal textarea - │ - ├── Enter key → submit - │ - ▼ -TUI.handleKey() → StateInput → StateThinking - │ - ├── startStream(query) → tea.Cmd - │ │ - │ ├── Create buffered channel (100 items) - │ ├── Spawn goroutine with 5-min timeout - │ └── flow.Stream(ctx, Input{Query, SessionID}) - │ │ - │ └── [Same agent flow as Web Chat above] - │ - ├── streamStartedMsg → StateStreaming - │ ├── streamTextMsg → append to output.Builder - │ ├── streamTextMsg → append to output.Builder - │ └── ... - │ - └── streamDoneMsg → StateInput - ├── Append to messages slice (max 100) - └── Clear output buffer -``` - ---- - -## 12. Gap Analysis - -| Component | Status | Details | -|-----------|--------|---------| -| **Agent Runtime (Genkit flow-based)** | ✅ 完成 | `genkit.DefineStreamingFlow`, agentic loop with 5 max turns, retry + circuit breaker + rate limiting + token budget | -| **Channel Adapter: Telegram** | ❌ 尚未開始 | 無任何 Telegram 相關程式碼 | -| **Channel Adapter: LINE** | ❌ 尚未開始 | 無任何 LINE 相關程式碼 | -| **Channel Adapter: Web Chat** | ✅ 完成 | Templ SSR + HTMX + SSE streaming, session management, CSRF protection | -| **Permission Engine (Always/RequireApproval/RoleOnly)** | 🟡 部分完成 | **已完成**: `DangerLevel` enum (Safe/Warning/Dangerous/Critical), `ToolMetadata.RequiresConfirmation` flag, inline security validators (5 modules). **缺少**: Runtime approval flow (approve/reject UI), role-based access control, permission policy engine | -| **MCP Bridge (client + permission layer)** | 🟡 部分完成 | **已完成**: MCP Server (10 tools, stdio transport, official SDK v1.1.0). **缺少**: MCP Client (connecting to external MCP servers at runtime), permission layer for MCP tool calls | -| **Event Bus (structured events for observability)** | 🟡 部分完成 | **已完成**: `ToolEventEmitter` interface with `OnToolStart/Complete/Error`, context-based propagation, SSE tool status display. **缺少**: General-purpose event bus, structured event types beyond tools, event persistence, event subscribers/listeners pattern | -| **Memory: Short-term (conversation context)** | ✅ 完成 | `session.History` (thread-safe `[]*ai.Message`), token budget truncation (8K history limit), loaded from PostgreSQL per request | -| **Memory: Long-term (PostgreSQL cross-conversation)** | ✅ 完成 | `sessions` + `message` tables, JSONB content storage, sequence numbering, message status lifecycle (streaming/completed/failed), session listing with pagination | -| **Memory: Knowledge Base (RAG with pgvector)** | ✅ 完成 | `documents` table with vector(768), HNSW index, cosine similarity, 3 source types (conversation/file/system), 6 system knowledge docs, 3 knowledge tools for agent | -| **Session Management** | ✅ 完成 | PostgreSQL-backed `session.Store`, HTTP cookie sessions (30-day), CSRF tokens (HMAC-SHA256), lazy session creation, CLI session persistence (~/.koopa/current_session) | -| **Audit Logging (PostgreSQL)** | 🟡 部分完成 | **已完成**: slog-based security event logging (path traversal, command injection, SSRF, prompt injection). **缺少**: PostgreSQL audit table, queryable audit trail, audit log retention policy | -| **Approval Flow (pending → approved/rejected)** | ❌ 尚未開始 | `ToolMetadata.RequiresConfirmation` exists as schema but no runtime implementation. No pending state, no approval UI, no notification system | -| **Dashboard Frontend (Angular)** | ❌ 尚未開始 | 無 Angular 程式碼。目前只有 Templ SSR chat interface | -| **REST API for Dashboard** | 🟡 部分完成 | **已完成**: `GET /genui` (chat page), `POST /genui/send`, `GET /genui/stream` (SSE), `POST /genui/sessions/*`, `GET /health`, `GET /ready`. **缺少**: RESTful CRUD API for sessions/messages (JSON), admin/dashboard-specific endpoints, API versioning | -| **WebSocket for real-time events** | ❌ 尚未開始 | 目前使用 SSE (Server-Sent Events) 進行即時串流,無 WebSocket 實作 | -| **Docker Compose deployment** | 🟡 部分完成 | **已完成**: PostgreSQL + pgvector, SearXNG + Redis, health checks, volume mounts. **缺少**: Koopa 應用本身的 Dockerfile, production docker-compose profile, nginx/reverse proxy | -| **CLI / Configuration** | ✅ 完成 | 三種模式 (cli/serve/mcp), Viper config (.env + YAML), signal handling + graceful shutdown, Taskfile build system | - -### Summary Table - -| Status | Count | Items | -|--------|-------|-------| -| ✅ 完成 | 7 | Agent Runtime, Web Chat, Short-term Memory, Long-term Memory, Knowledge Base (RAG), Session Management, CLI/Config | -| 🟡 部分完成 | 6 | Permission Engine, MCP Bridge, Event Bus, Audit Logging, REST API, Docker Compose | -| ❌ 尚未開始 | 5 | Telegram Adapter, LINE Adapter, Approval Flow, Dashboard (Angular), WebSocket | - ---- - -## Appendix: TODO/FIXME/HACK Comments - -| File | Comment | -|------|---------| -| `internal/web/server.go:86-96` | TODO: Implement Settings and Search handlers | -| `internal/web/server.go:112-114` | TODO: Settings and Search routes | -| `internal/agent/chat/chat.go:412` | TODO: File Genkit GitHub issue (data race workaround) | -| `internal/app/app_test.go:108` | TODO: Re-enable when Toolset migration complete | -| `cmd/e2e_test.go:198-203` | FIXME: MCP server exits immediately (requires test harness) | -| `cmd/e2e_test.go:255-260` | FIXME: MCP communication fails with EOF (needs proper test client) | - ---- - -## Appendix: Environment Configuration Template - -**`.env.example`**: -```bash -# Required (All Modes) -GEMINI_API_KEY=your-api-key-here -DATABASE_URL=postgres://koopa:koopa_dev_password@localhost:5432/koopa?sslmode=disable - -# Required (Web Mode Only) -HMAC_SECRET= # openssl rand -base64 32 - -# Optional Model Settings -KOOPA_MODEL_NAME=gemini-2.5-flash -KOOPA_TEMPERATURE=0.7 -KOOPA_MAX_TOKENS=2048 -KOOPA_MAX_HISTORY_MESSAGES=100 -KOOPA_RAG_TOP_K=3 - -# Optional -DEBUG=false -DD_API_KEY= # Datadog APM -``` - ---- - -## Appendix: Build Commands - -```bash -# Development -task build:dev # Build with filesystem assets -task css:watch # Watch CSS changes -task generate # Generate templ files -task sqlc # Generate SQL code - -# Testing -task test # All tests -task test:race # With race detector -task test:unit # Unit only -task test:integration # With testcontainers -task test:fuzz # Fuzz tests (30s each) -task test:e2e # Playwright browser tests - -# Quality -task lint # golangci-lint -task check # lint + test + build - -# Production -task build # Embedded assets binary -go build -o koopa ./... # Manual build -``` diff --git a/docs/proposals/001-phase-a-remove-templ-htmx-json-api-skeleton.md b/docs/proposals/001-phase-a-remove-templ-htmx-json-api-skeleton.md deleted file mode 100644 index 3e4f91e..0000000 --- a/docs/proposals/001-phase-a-remove-templ-htmx-json-api-skeleton.md +++ /dev/null @@ -1,306 +0,0 @@ -# Proposal 001: Phase A — Remove Templ/HTMX, Build JSON API Skeleton - -> Status: PENDING REVIEW -> Author: Claude Code -> Date: 2026-02-04 - ---- - -## Summary - -Remove all Templ/HTMX/SSE frontend code from `internal/web/`, migrate business logic to a new `internal/api/` package, and establish a JSON REST API skeleton. This is an atomic operation — `koopa serve` must work before and after. - ---- - -## 2.1 Proposed Directory Structure - -``` -internal/api/ -├── server.go # HTTP server, route registration, CORS, security headers -├── response.go # JSON response helpers: WriteJSON, WriteError, standard envelope -├── middleware/ -│ ├── recovery.go # Panic recovery (from web/middleware.go) -│ ├── logging.go # Request logging (from web/middleware.go) -│ ├── cors.go # CORS for Angular dev server (NEW) -│ └── auth.go # Session cookie + CSRF validation (from web/handlers/sessions.go + web/middleware.go) -└── v1/ - ├── chat.go # POST /api/v1/chat/send, GET /api/v1/chat/stream (JSON SSE) - ├── chat_test.go # Chat handler tests (JSON response validation) - ├── sessions.go # GET/POST/DELETE /api/v1/sessions, session auth logic - ├── sessions_test.go # Session handler tests - ├── health.go # GET /api/v1/health, GET /api/v1/ready - └── health_test.go # Health handler tests -``` - -### Design Rationale - -**Why `internal/api/` not `internal/web/` rename?** -- Clean break. No leftover references to templ/htmx imports. -- `internal/web/` will be re-used later for Angular `dist/` embedding (Phase B+). - -**Why `v1/` versioning?** -- API versioning from day one. Angular client will target `/api/v1/`. -- Future breaking changes go to `v2/` without disrupting existing clients. - -**Why `middleware/` as a sub-package?** -- Follows existing pattern separation (middleware was already logically separate in `web/middleware.go`). -- Allows middleware to be tested independently. - -**Why `response.go` at package root?** -- Shared by all `v1/` handlers. Avoids circular imports. -- Single source of truth for JSON envelope format. - ---- - -## Response Envelope - -All JSON responses follow a consistent envelope: - -```go -// internal/api/response.go - -// Envelope is the standard JSON response wrapper. -type Envelope struct { - Data any `json:"data,omitempty"` - Error *Error `json:"error,omitempty"` -} - -type Error struct { - Code string `json:"code"` - Message string `json:"message"` -} - -func WriteJSON(w http.ResponseWriter, status int, data any) { ... } -func WriteError(w http.ResponseWriter, status int, code, message string) { ... } -``` - ---- - -## Route Mapping (Old → New) - -| Old Route | New Route | Method | Handler | Notes | -|-----------|-----------|--------|---------|-------| -| `GET /genui` | _(removed)_ | — | — | Page rendering, replaced by Angular | -| `POST /genui/send` | `POST /api/v1/chat/send` | POST | `v1.Chat.Send` | Returns JSON instead of HTML | -| `GET /genui/stream` | `GET /api/v1/chat/stream` | GET | `v1.Chat.Stream` | JSON SSE (not HTML SSE) | -| `GET /genui/sessions` | `GET /api/v1/sessions` | GET | `v1.Sessions.List` | JSON array | -| `POST /genui/sessions` | `POST /api/v1/sessions` | POST | `v1.Sessions.Create` | JSON response | -| `DELETE /genui/sessions/{id}` | `DELETE /api/v1/sessions/{id}` | DELETE | `v1.Sessions.Delete` | JSON response | -| `GET /health` | `GET /api/v1/health` | GET | `v1.Health.Health` | Same logic | -| `GET /ready` | `GET /api/v1/ready` | GET | `v1.Health.Ready` | Same logic | -| — | `OPTIONS /api/v1/*` | OPTIONS | CORS middleware | NEW: Preflight | - ---- - -## Streaming Strategy - -The SSE **transport protocol** is kept (it's standard HTTP, not HTMX-specific). What changes is the **content format**: - -**Before (HTML SSE)**: -``` -event: chunk -data:
Hello
- -event: done -data:
...
-``` - -**After (JSON SSE)**: -``` -event: chunk -data: {"msgId":"abc-123","text":"Hello"} - -event: tool_start -data: {"msgId":"abc-123","tool":"web_search","message":"Searching..."} - -event: tool_complete -data: {"msgId":"abc-123","tool":"web_search","message":"Done"} - -event: done -data: {"msgId":"abc-123","sessionId":"xyz","title":"Chat Title"} - -event: error -data: {"msgId":"abc-123","code":"timeout","message":"Request timed out"} -``` - -This preserves the existing streaming architecture (Genkit Flow → SSE) while making the output consumable by Angular's `EventSource` API. The SSE writer is rewritten as a simple JSON event writer in `internal/api/v1/chat.go` — no separate `sse/` package needed. - ---- - -## Middleware Stack - -``` -Request - → CORS (new) - → Recovery (from web/middleware.go) - → Logging (from web/middleware.go) - → Auth (session cookie + CSRF, from web/middleware.go + handlers/sessions.go) - → Route handler -``` - -**CORS middleware** (`middleware/cors.go`): -- Reads allowed origins from config (`KOOPA_CORS_ORIGINS`, default: `http://localhost:4200`) -- Handles preflight `OPTIONS` requests (204 No Content) -- Sets `Access-Control-Allow-Origin`, `Allow-Methods`, `Allow-Headers`, `Allow-Credentials` - -**Auth middleware** (`middleware/auth.go`): -- Consolidates session cookie reading + CSRF validation from the old `RequireSession` + `RequireCSRF` + `Sessions` struct. -- HMAC-SHA256 token logic preserved exactly as-is. -- Pre-session CSRF pattern preserved for lazy session creation. -- CSRF token read from `X-CSRF-Token` header (instead of form field) for JSON API compatibility. - -**MethodOverride middleware**: Removed. JSON API uses proper HTTP methods directly. - ---- - -## Business Logic Migration Map - -| Source | Destination | What Migrates | -|--------|-------------|---------------| -| `web/middleware.go` RecoveryMiddleware | `api/middleware/recovery.go` | Panic catch, error response (→ JSON) | -| `web/middleware.go` LoggingMiddleware | `api/middleware/logging.go` | Request logging (unchanged) | -| `web/middleware.go` RequireSession | `api/middleware/auth.go` | Lazy session creation logic | -| `web/middleware.go` RequireCSRF | `api/middleware/auth.go` | CSRF validation logic | -| `web/handlers/sessions.go` HMAC logic | `api/middleware/auth.go` | Token gen/validation (unchanged) | -| `web/handlers/sessions.go` cookie logic | `api/middleware/auth.go` | Cookie set/read (unchanged) | -| `web/handlers/sessions.go` List/Create/Delete | `api/v1/sessions.go` | CRUD → JSON responses | -| `web/handlers/chat.go` Send | `api/v1/chat.go` | Content validation, session handling → JSON | -| `web/handlers/chat.go` Stream | `api/v1/chat.go` | Flow streaming → JSON SSE | -| `web/handlers/chat.go` title gen | `api/v1/chat.go` | AI title generation (unchanged) | -| `web/handlers/chat.go` error classify | `api/v1/chat.go` | Error classification (unchanged) | -| `web/handlers/health.go` | `api/v1/health.go` | Health probes (unchanged) | -| `web/handlers/tool_display.go` | `api/v1/chat.go` (inline) | Tool display messages (simplified) | -| `web/handlers/convert.go` extractTextContent | `api/v1/sessions.go` | Message text extraction (unchanged) | -| `web/sse/writer.go` SSE format | `api/v1/chat.go` (inline) | SSE protocol (`event:`, `data:`, flush) | - ---- - -## What Gets Deleted (No Migration) - -| File/Dir | Reason | -|----------|--------| -| `internal/web/page/` | Templ page templates | -| `internal/web/layout/` | Templ layout templates | -| `internal/web/component/` | Templ components + 222 reference blocks | -| `internal/web/static/` | CSS, JS, embedded assets | -| `internal/web/sse/` | HTML SSE writer (replaced by inline JSON SSE) | -| `internal/web/e2e/` | Playwright test assets | -| `internal/web/handlers/htmx.go` | HTMX detection helper | -| `internal/web/handlers/pages.go` | HTML page rendering | -| `internal/web/handlers/tool_emitter.go` | SSE-specific tool emitter | -| `internal/web/fixture_test.go` | Playwright fixtures | -| `internal/web/e2e_*.go` | All E2E browser tests | -| `internal/web/server_test.go` | Tests for HTML server | -| `internal/web/static/assets_test.go` | Tests for embedded assets | -| `build/frontend/` | Tailwind CSS build config | - ---- - -## Wire DI Impact - -**No changes to `wire.go` or `wire_gen.go`.** - -`web.NewServer` was never part of Wire — it's created manually in `cmd/serve.go`. The new `api.NewServer` will also be created manually in `cmd/serve.go`, using the same Runtime components. - ---- - -## cmd/serve.go Changes - -```go -// Before -webServer, err := web.NewServer(web.ServerConfig{ - Logger: logger, - Genkit: runtime.App.Genkit, - ChatFlow: runtime.Flow, - SessionStore: runtime.App.SessionStore, - CSRFSecret: []byte(cfg.HMACSecret), - Config: cfg, -}) - -// After -apiServer := api.NewServer(api.ServerConfig{ - Logger: logger, - Genkit: runtime.App.Genkit, - ChatFlow: runtime.Flow, - SessionStore: runtime.App.SessionStore, - CSRFSecret: []byte(cfg.HMACSecret), - CORSOrigins: cfg.CORSOrigins, - IsDev: cfg.Debug, -}) -``` - ---- - -## Config Changes - -**New fields in `config.Config`**: -```go -CORSOrigins []string // from KOOPA_CORS_ORIGINS, default: ["http://localhost:4200"] -``` - -**New env vars in `.env.example`**: -```bash -KOOPA_CORS_ORIGINS=http://localhost:4200 # Comma-separated allowed origins -``` - -**Removed env vars**: None. `HMAC_SECRET` stays (still used for CSRF). - ---- - -## Dependency Removal - -| Dependency | Action | Reason | -|------------|--------|--------| -| `github.com/a-h/templ` | Remove | No more templ templates | -| `github.com/playwright-community/playwright-go` | Remove | Only used in `internal/web/` E2E tests | - -Verified: `cmd/e2e_test.go` does NOT use playwright (uses `os/exec` only). - ---- - -## Taskfile.yml Changes - -| Task | Action | -|------|--------| -| `css` | Remove | -| `css:watch` | Remove | -| `generate` | Change to `sqlc generate` only (remove `templ generate`) | -| `install:templ` | Remove | -| `install:npm` | Remove (no more build/frontend) | -| `build` | Remove templ/css deps | -| `build:dev` | Remove templ/css deps | -| `test:e2e` | Remove | -| `fmt` | Remove `templ fmt` | -| `dev` | Remove templ/css generation | - ---- - -## Acceptance Criteria - -All criteria from the user's spec apply. Additionally: - -1. `go build ./...` — zero errors -2. `golangci-lint run ./...` — no errors -3. `go test ./...` — all pass (deleted web tests excluded) -4. `go vet ./...` — clean -5. `koopa serve` — starts, responds to `/api/v1/health` -6. `koopa` (CLI) — unaffected -7. `koopa mcp` — unaffected -8. CORS preflight returns 204 with correct headers -9. No `.templ` files remain -10. No `htmx` references in Go code -11. No `github.com/a-h/templ` in go.mod -12. `internal/web/` directory does not exist - ---- - -## Risk Assessment - -| Risk | Mitigation | -|------|-----------| -| Wire DI breaks | Wire is not involved — server created in cmd/serve.go | -| Import cycle | `api/` depends on `session/`, `agent/chat/`, `config/` — same as `web/` did | -| Missing business logic | Comprehensive migration map above; each handler verified | -| Streaming breaks | SSE transport preserved, only content format changes (HTML → JSON) | -| Tests fail | Migrated tests adapted to JSON assertions | -| go.mod stale deps | `go mod tidy` cleans up automatically | diff --git a/docs/proposals/001-v010-release-plan.md b/docs/proposals/001-v010-release-plan.md deleted file mode 100644 index cc4484a..0000000 --- a/docs/proposals/001-v010-release-plan.md +++ /dev/null @@ -1,315 +0,0 @@ -# Proposal 001: v0.1.0 Release Plan - -## Summary - -Prepare Koopa for its first public release by: -1. Cleaning up code quality (testify removal, skipped tests) -2. Adding multi-model support (Gemini + Ollama + OpenAI) -3. Exposing RAG tools via MCP -4. Writing README and release automation - -## Scope - -**In scope (v0.1.0):** -- testify → stdlib + cmp.Diff migration (16 files) -- Multi-model: Gemini (current) + Ollama + OpenAI via Genkit plugins -- Config-driven model/provider selection -- MCP server exposes Knowledge/RAG tools -- README with install instructions and quick start -- goreleaser for macOS (arm64/amd64) + Linux (arm64/amd64) -- Complete skipped tests where feasible - -**Out of scope (deferred to v0.2.0):** -- Claude/Anthropic support (Genkit plugin disables tool use) -- MCP Client (connecting to external MCP servers) -- Hybrid RAG (vector + keyword search) -- Extensible user-defined tool system -- Agent loop enhancement (multi-step task decomposition) - -## Detailed Design - -### 1. testify Migration (16 files) - -**What changes:** -Replace all `assert.*` / `require.*` calls with stdlib patterns: - -```go -// Before -assert.Equal(t, want, got) -require.NoError(t, err) - -// After -if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("FuncName() mismatch (-want +got):\n%s", diff) -} -if err != nil { - t.Fatalf("FuncName() unexpected error: %v", err) -} -``` - -**Files affected (16):** -- `internal/agent/chat/integration_rag_test.go` -- `internal/agent/chat/integration_test.go` -- `internal/agent/chat/integration_streaming_test.go` -- `internal/agent/chat/flow_test.go` -- `internal/agent/chat/chat_test.go` -- `internal/tools/system_integration_test.go` -- `internal/tools/register_test.go` -- `internal/tools/file_integration_test.go` -- `internal/tools/network_integration_test.go` -- `internal/session/integration_test.go` -- `internal/rag/system_test.go` -- `internal/mcp/integration_test.go` -- `internal/observability/datadog_test.go` -- `internal/testutil/postgres.go` -- `internal/testutil/sse.go` -- `cmd/e2e_test.go` - -After migration, remove `github.com/stretchr/testify` from `go.mod`. - -**Risk:** Low. Mechanical replacement, no logic changes. - -### 2. Multi-Model Support - -**Current state:** -- `wire.go:119` hardcodes `&googlegenai.GoogleAI{}` as the only AI plugin -- `wire.go:134` hardcodes `googlegenai.GoogleAIEmbedder()` as embedder -- `prompts/koopa.prompt:2` hardcodes `model: googleai/gemini-2.5-flash` - -**Design:** - -#### 2a. Config Changes (`internal/config/config.go`) - -Add provider field: - -```go -type Config struct { - // AI configuration - Provider string `mapstructure:"provider" json:"provider"` // "gemini" (default), "ollama", "openai" - ModelName string `mapstructure:"model_name" json:"model_name"` // e.g. "gemini-2.5-flash", "llama3.3", "gpt-4o" - // ... existing fields ... - - // Ollama configuration - OllamaHost string `mapstructure:"ollama_host" json:"ollama_host"` // default: "http://localhost:11434" - - // OpenAI configuration (env: OPENAI_API_KEY) - // No config field needed - key read by Genkit plugin from env -} -``` - -Defaults: -```yaml -provider: gemini -model_name: gemini-2.5-flash -ollama_host: http://localhost:11434 -``` - -Env overrides: -``` -KOOPA_PROVIDER=ollama -KOOPA_MODEL_NAME=llama3.3 -KOOPA_OLLAMA_HOST=http://localhost:11434 -OPENAI_API_KEY=sk-... # for openai provider -``` - -#### 2b. Dynamic Plugin Loading (`internal/app/wire.go`) - -Replace hardcoded GoogleAI with provider switch: - -```go -func provideGenkit(ctx context.Context, cfg *config.Config, _ OtelShutdown, postgres *postgresql.Postgres) (*genkit.Genkit, error) { - plugins := []genkit.Plugin{postgres} // PostgreSQL always needed for RAG - - switch cfg.Provider { - case "gemini", "": - plugins = append(plugins, &googlegenai.GoogleAI{}) - case "ollama": - plugins = append(plugins, &ollama.Ollama{ServerAddress: cfg.OllamaHost}) - case "openai": - plugins = append(plugins, &openai.OpenAI{}) - default: - return nil, fmt.Errorf("unsupported provider: %q", cfg.Provider) - } - - g := genkit.Init(ctx, - genkit.WithPlugins(plugins...), - genkit.WithPromptDir(promptDir), - ) - return g, nil -} -``` - -#### 2c. Dynamic Model in Dotprompt - -The dotprompt `model:` field determines which model Genkit uses. Two options: - -**Option A: Override at runtime (preferred)** -Keep the dotprompt as-is with a default model. Override the model when executing the prompt via `ai.WithModel()` option. This requires checking if Genkit's prompt execution API supports model override. - -**Option B: Generate dotprompt at startup** -Write the dotprompt file dynamically based on config. This is fragile and not recommended. - -**Option C: Multiple dotprompt files** -Create `prompts/koopa-gemini.prompt`, `prompts/koopa-ollama.prompt`, etc. Select based on config. Duplicates prompt content. - -**Recommendation: Option A** — override at execution time. The `ai.WithModel()` option in `ai.PromptExecuteOption` allows this. The dotprompt file remains the default/fallback. - -Implementation in `chat.go`: -```go -// In generateResponse(), add model override to opts -if modelOverride != "" { - model := genkit.LookupModel(c.g, modelOverride) - if model != nil { - opts = append(opts, ai.WithModel(model)) - } -} -``` - -The model name format follows Genkit convention: -- Gemini: `googleai/gemini-2.5-flash` -- Ollama: `ollama/llama3.3` -- OpenAI: `openai/gpt-4o` - -Config `model_name` stores the short name (`gemini-2.5-flash`), and the provider prefix is derived from `cfg.Provider`. - -#### 2d. Embedder Selection - -Embedder is needed for RAG (pgvector). Options per provider: - -| Provider | Embedder | Model | -|----------|----------|-------| -| gemini | `googlegenai.GoogleAIEmbedder` | text-embedding-004 | -| ollama | `ollama.Embedder` | nomic-embed-text | -| openai | `openai.Embedder` | text-embedding-3-small | - -```go -func provideEmbedder(g *genkit.Genkit, cfg *config.Config) ai.Embedder { - switch cfg.Provider { - case "ollama": - return ollama.Embedder(g, cfg.EmbedderModel) - case "openai": - return openai.Embedder(g, cfg.EmbedderModel) - default: - return googlegenai.GoogleAIEmbedder(g, cfg.EmbedderModel) - } -} -``` - -**Note:** Switching embedder provider changes vector dimensions. Existing pgvector data from one embedder is incompatible with another. This is acceptable for v0.1.0 (users start fresh). Document this limitation. - -#### 2e. Validation - -- `gemini`: Require `GEMINI_API_KEY` env var -- `ollama`: Require `ollama_host` reachable (health check at startup) -- `openai`: Require `OPENAI_API_KEY` env var - -### 3. MCP RAG Tools - -**Current state:** MCP server exposes File, System, Network tools but NOT Knowledge tools. - -**Change:** Register knowledge tools in MCP server alongside existing tools. - -**File:** `internal/mcp/server.go` (or wherever MCP tools are registered) - -Add the 3 knowledge tools: -- `search_history` — search past conversations -- `search_documents` — search indexed documents -- `search_system_knowledge` — search system knowledge base - -**Risk:** Low. Tools already exist and are tested. Just wire them into MCP registration. - -### 4. README - -Structure: -``` -# Koopa -One-line description. - -## Features -- Terminal AI assistant (TUI) -- JSON REST API with SSE streaming -- MCP server for IDE integration -- 13 built-in tools (file, system, network, knowledge) -- RAG with pgvector -- Multi-model (Gemini, Ollama, OpenAI) - -## Quick Start -### Prerequisites -- Go 1.23+ -- PostgreSQL 17 with pgvector -- Docker (for SearXNG + Redis) - -### Install -go install github.com/koopa0/koopa@latest -# or download from releases - -### Setup -docker compose up -d -export GEMINI_API_KEY="your-key" -koopa - -### Using Ollama (local models) -ollama pull llama3.3 -export KOOPA_PROVIDER=ollama -export KOOPA_MODEL_NAME=llama3.3 -koopa - -## Configuration -## Architecture -## License -``` - -### 5. Release Automation (goreleaser) - -Create `.goreleaser.yml`: -- Platforms: macOS (arm64, amd64), Linux (arm64, amd64) -- Binary name: `koopa` -- GitHub release with checksums -- Homebrew tap (optional, can add later) - -### 6. Skipped Tests - -7 skipped tests identified: -- 5 in `internal/api/session_test.go` — need PostgreSQL integration -- 2 in `cmd/e2e_test.go` — need MCP test harness - -**Action:** -- `session_test.go`: Convert to testcontainers-based integration tests -- `e2e_test.go`: Assess feasibility. If MCP test harness is complex, document as known limitation. - -## Implementation Order - -| Step | Task | Depends On | Estimated Files | -|------|------|------------|-----------------| -| 1 | testify migration | none | 16 test files | -| 2 | Remove testify from go.mod | step 1 | go.mod, go.sum | -| 3 | Config: add Provider, OllamaHost fields | none | config.go, validation.go, config_test.go | -| 4 | wire.go: dynamic plugin loading | step 3 | wire.go, wire_gen.go | -| 5 | chat.go: model override at execution | step 4 | chat.go, chat_test.go | -| 6 | Embedder: provider-based selection | step 4 | wire.go | -| 7 | MCP: register knowledge tools | none | mcp/server.go or tools registration | -| 8 | go.mod: add ollama, openai plugins | step 4 | go.mod, go.sum | -| 9 | Skipped tests | none | session_test.go, e2e_test.go | -| 10 | README | after all features | README.md | -| 11 | goreleaser config | none | .goreleaser.yml | -| 12 | Final verification | all steps | - | - -## Risks - -| Risk | Impact | Mitigation | -|------|--------|------------| -| Genkit Ollama plugin bugs | Tool calling fails | Test with popular models (llama3.3, qwen2.5) early | -| Embedder dimension mismatch | RAG breaks when switching provider | Document: switching provider requires fresh DB | -| Wire regeneration | Build breaks | Run `wire ./internal/app/` after wire.go changes | -| Dotprompt model override | May not work as expected | Verify `ai.WithModel()` in prompt execution; fallback to Option C if needed | - -## Open Questions - -1. Should `config.yaml` support per-command provider override? (e.g., TUI uses Ollama but API uses Gemini) - - **Recommendation:** No. Single provider per instance for v0.1.0. - -2. Should we add a `koopa config` CLI command to set provider/model interactively? - - **Recommendation:** Defer. Environment variables and config file are sufficient. - -3. Default embedder model names for Ollama/OpenAI — need to verify exact model IDs. - - **Action:** Check during implementation. diff --git a/docs/proposals/002-v020-comprehensive-refactoring.md b/docs/proposals/002-v020-comprehensive-refactoring.md deleted file mode 100644 index 3b078ad..0000000 --- a/docs/proposals/002-v020-comprehensive-refactoring.md +++ /dev/null @@ -1,409 +0,0 @@ -# Proposal 002: v0.2.0 Comprehensive Refactoring - -## Overview - -Based on a full-codebase comprehension (5 parallel agents analyzed every layer), this proposal defines a phased refactoring plan to bring Koopa to production quality. - -**Scope**: From `main.go` to every internal package. All 3 modes (CLI, API, MCP) preserved. -**Priority**: CLI + HTTP API first, MCP in parallel where non-conflicting. -**Framework**: Genkit as core AI orchestration, Bubble Tea v2 for TUI. -**Storage**: PostgreSQL + pgvector retained as required dependency. - ---- - -## Phase 1: Critical Bug Fixes (Low Risk, Immediate) - -No architecture changes. Pure bug fixes with existing patterns. - -### 1.1 Fix `listenForStream()` Stack Overflow - -**File**: `internal/tui/commands.go` -**Problem**: Recursive call on empty events can overflow stack. -**Fix**: Replace recursion with `for` loop. - -```go -// Before (recursive): -default: - return listenForStream(eventCh)() - -// After (iterative): -for { - event, ok := <-eventCh - if !ok { return streamErrorMsg{err: ...} } - if event.err != nil { return streamErrorMsg{...} } - if event.done { return streamDoneMsg{...} } - if event.text != "" { return streamTextMsg{...} } - // empty event: loop instead of recurse -} -``` - -### 1.2 Fix `truncateHistory` Break Logic - -**File**: `internal/agent/chat/tokens.go` -**Problem**: Uses `break` instead of `continue`, drops old small messages after first large one. -**Fix**: Change `break` to `continue`. - -### 1.3 Remove Ghost `knowledge_store` Reference - -**File**: `internal/api/chat.go:59` -**Problem**: UI message references a tool that doesn't exist. -**Fix**: Remove the `knowledge_store` entry from toolDisplayMessages map. - -### 1.4 Remove Dead `NormalizeMaxHistoryMessages` - -**File**: `internal/config/validation.go` -**Problem**: Function defined but never called. -**Fix**: Delete the function (validation already handles range checks in `Validate()`). - -### 1.5 Fix RAG Non-Atomic UPSERT - -**File**: `internal/rag/system.go` -**Problem**: Delete + Insert without transaction; delete failure silently swallowed. -**Fix**: Wrap in PostgreSQL transaction. - ---- - -## Phase 2: Remove Wire, Simplify DI (Medium Risk) - -Wire adds indirection and `wire_gen.go` is hand-maintained despite claiming to be generated. - -### 2.1 Replace Wire with Manual DI - -**Delete**: `internal/app/wire.go`, `internal/app/wire_gen.go` -**Modify**: `internal/app/app.go` - -Replace Wire-generated `InitializeApp()` with explicit construction: - -```go -func InitializeApp(ctx context.Context, cfg *config.Config) (*App, func(), error) { - // 1. OTel setup - otelShutdown, err := observability.Setup(ctx, cfg.Datadog) - - // 2. DB pool + migrations - pool, err := connectDB(ctx, cfg) - - // 3. Genkit init - g := initGenkit(ctx, cfg) - - // 4. Embedder - embedder := lookupEmbedder(g, cfg) - - // 5. RAG components - docStore, retriever := initRAG(ctx, g, pool, embedder, cfg) - - // 6. Session store - sessionStore := session.NewStore(pool) - - // 7. Security validators - pathValidator := security.NewPath(...) - - // 8. Tools - tools := registerTools(g, pathValidator, ...) - - // 9. Construct App - app := &App{...} - - cleanup := func() { - pool.Close() - otelShutdown(ctx) - } - - return app, cleanup, nil -} -``` - -### 2.2 Unify Tool Registration - -**Problem**: Tools are created twice — once for Genkit (wire.go), once for MCP (cmd/mcp.go). -**Fix**: Create tools once, register to both Genkit and MCP from same instances. - -```go -// internal/tools/registry.go (new file) -type Registry struct { - File *FileTools - System *SystemTools - Network *NetworkTools - Knowledge *KnowledgeTools -} - -func NewRegistry(pathValidator *security.Path, ...) *Registry { ... } - -// Register to Genkit -func (r *Registry) RegisterGenkit(g *genkit.Genkit) ([]ai.Tool, error) { ... } -``` - -MCP server accepts `*tools.Registry` instead of individual tool structs. - -### 2.3 Refactor App Struct - -Split the God Object into focused structs: - -```go -type App struct { - Config *config.Config - AI *AIComponents // Genkit, Embedder - Storage *StorageComponents // DBPool, DocStore, Retriever, SessionStore - Tools *tools.Registry - - // Lifecycle - ctx context.Context - cancel context.CancelFunc - eg *errgroup.Group -} - -type AIComponents struct { - Genkit *genkit.Genkit - Embedder ai.Embedder -} - -type StorageComponents struct { - Pool *pgxpool.Pool - DocStore *postgresql.DocStore - Retriever ai.Retriever - SessionStore *session.Store -} -``` - ---- - -## Phase 3: TUI Upgrade to Bubble Tea v2 (High Value) - -### 3.1 Dependency Upgrade - -``` -charm.land/bubbletea/v2 -charm.land/bubbles/v2 -charm.land/lipgloss/v2 -github.com/charmbracelet/glamour (for markdown rendering) -``` - -### 3.2 Rewrite TUI with v2 Patterns - -**Key changes**: -- `View()` returns `tea.View` struct (declarative altscreen, cursor) -- `tea.KeyPressMsg` replaces `tea.KeyMsg` -- `tea.PasteMsg` for multi-line paste -- Synchronized output (Mode 2026) reduces flicker during streaming - -**Streaming pattern** (replaces recursive listenForStream): -```go -func waitForStream(ch <-chan streamEvent) tea.Cmd { - return func() tea.Msg { - for { - event, ok := <-ch - if !ok { return streamDoneMsg{} } - if event.err != nil { return streamErrorMsg{event.err} } - if event.done { return streamDoneMsg{event.output} } - if event.text != "" { return streamChunkMsg{event.text} } - } - } -} -``` - -### 3.3 Add Glamour Markdown Rendering - -LLM responses contain markdown. Currently rendered as plain text. -- Use `glamour` for completed messages -- Raw text during active streaming (partial markdown may not parse) -- Re-render with glamour when stream completes - -### 3.4 Improve TUI Components - -- `bubbles/viewport` for scrollable chat history -- `bubbles/textarea` for multi-line input (shift+enter for newlines in v2) -- `bubbles/spinner` for "thinking" indicator -- `bubbles/help` for keyboard shortcuts -- `lipgloss` for message bubble styling (user vs assistant) - ---- - -## Phase 4: API Improvements + Angular UI Spec (Medium Risk) - -### 4.1 Fix Session Ownership - -Add session-cookie based authorization: -```go -func (sm *sessionManager) authorizeSession(r *http.Request, sessionID uuid.UUID) error { - cookieSessionID := sm.getSessionID(r) - if cookieSessionID != sessionID { - return ErrForbidden - } - return nil -} -``` - -### 4.2 Security Headers - -Add missing HSTS header: -```go -w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") -``` - -### 4.3 Add API Versioning - -Prefix all routes with `/api/v1/`: -``` -/api/v1/csrf-token -/api/v1/sessions -/api/v1/sessions/{id} -/api/v1/sessions/{id}/messages -/api/v1/chat -/api/v1/chat/stream -``` - -### 4.4 Standardize Response Envelope - -```json -{ - "data": { ... }, - "error": null -} -``` - -or on error: -```json -{ - "data": null, - "error": { "code": "not_found", "message": "session not found" } -} -``` - -### 4.5 Angular UI Spec Document - -Create `docs/api-spec.md` with: -- Complete endpoint documentation (request/response schemas) -- SSE event types and payloads -- Authentication flow (CSRF tokens, session cookies) -- WebSocket upgrade path (future consideration) -- Error code reference - ---- - -## Phase 5: RAG & Knowledge Improvements - -### 5.1 Update System Knowledge Content - -Current 6 documents are outdated. Update to match actual registered tools. - -### 5.2 Implement Knowledge Store Tool - -The UI references `knowledge_store` but it doesn't exist. Either: -- Option A: Implement it (user can save conversations/documents to knowledge base) -- Option B: Remove UI reference (done in Phase 1.3) - -Decision deferred to user. - -### 5.3 Improve Token Estimation - -Replace `rune/2` with proper estimation: -```go -func estimateTokens(text string) int { - // Use tiktoken-go or similar for accurate counting - // Fallback to heuristic per language detection -} -``` - -Or use Genkit's built-in token counter if available. - ---- - -## Phase 6: Security Hardening - -### 6.1 Fix Prompt Injection Homoglyph Bypass - -Add Unicode NFKD normalization before pattern matching: -```go -import "golang.org/x/text/unicode/normalize" - -normalized := normalize.NFKD.String(input) -// Then apply regex patterns -``` - -### 6.2 Enforce SafeTransport Usage - -`SafeTransport()` exists but is never used in production. Wire it into network tools. - -### 6.3 Command Whitelist Gaps - -Add missing blocked patterns: -- `npm install` (postinstall hooks execute code) -- `go test -run` (executes test code) - ---- - -## Phase 7: Testing Overhaul - -### 7.1 Remove Pointless Tests - -Tests that only call `t.Skip()` with no body add no value. -(Most already removed in Phase A cleanup.) - -### 7.2 Add Missing Critical Tests - -| Test | Module | Why | -|------|--------|-----| -| Stream empty event handling | tui | Validates Phase 1.1 fix | -| Concurrent stream start/cancel | tui | Race condition coverage | -| deepCopyMessages mutation | chat | Direct unit test (currently indirect) | -| MCP JSON-RPC protocol | mcp | No protocol-level tests exist | -| Session ownership check | api | Validates Phase 4.1 fix | - -### 7.3 TUI Testing with teatest - -Use `charmbracelet/x/exp/teatest` for golden file testing of TUI output. - ---- - -## Execution Order - -Phases can be partially parallelized: - -``` -Phase 1 (bug fixes) ─────────────────────────→ immediate -Phase 2 (wire removal, DI) ──────────────────→ after Phase 1 -Phase 3 (TUI v2) ──────┐ after Phase 2 -Phase 4 (API + spec) ──┤ parallel after Phase 2 -Phase 5 (RAG) ─────────┘ after Phase 2 -Phase 6 (security) ──────────────────────────→ after Phase 4 -Phase 7 (testing) ───────────────────────────→ continuous -``` - -**CLI and API work don't conflict** — Phase 3 (TUI) and Phase 4 (API) can run in parallel after Phase 2 establishes the new DI structure. - ---- - -## Out of Scope - -- Angular Web UI implementation (separate project) -- New AI providers beyond gemini/ollama/openai -- Multi-user authentication system -- WebSocket support (future consideration) -- Kubernetes deployment manifests - ---- - -## Risk Assessment - -| Phase | Risk | Mitigation | -|-------|------|------------| -| 1 (bugs) | Low | Pure fixes, existing tests validate | -| 2 (wire) | Medium | Manual DI is well-understood; test suite validates | -| 3 (TUI v2) | High | v2 is RC.2, API may change; pin specific version | -| 4 (API) | Medium | Existing tests cover endpoints; add ownership tests | -| 5 (RAG) | Low | Isolated module, minimal cross-dependencies | -| 6 (security) | Medium | Security changes need careful review | -| 7 (testing) | Low | Only adds tests, no production code risk | - ---- - -## Success Criteria - -- [ ] `go build ./...` passes -- [ ] `go vet ./...` passes -- [ ] `golangci-lint run ./...` — 0 issues -- [ ] `go test -race ./...` — all pass -- [ ] All 4 red issues resolved -- [ ] Wire removed, manual DI working -- [ ] TUI running on Bubble Tea v2 -- [ ] API spec document complete -- [ ] No regressions in existing functionality diff --git a/docs/proposals/003-naming-conventions-and-dead-code.md b/docs/proposals/003-naming-conventions-and-dead-code.md deleted file mode 100644 index 455415a..0000000 --- a/docs/proposals/003-naming-conventions-and-dead-code.md +++ /dev/null @@ -1,324 +0,0 @@ -# Proposal 003: Naming Conventions & Dead Code Cleanup - -## Overview - -Based on full-codebase go-reviewer and security-reviewer findings (post-Proposal 002), this proposal addresses Go naming convention violations and dead code in the session layer. - -**Scope**: `session.Store` method renames, dead code removal, `clientIP()` security fix. -**Risk**: Medium — public API rename across multiple consumers, but zero interface breakage. -**Tier**: 2 (existing feature modification, no new packages or types). - -### Reviewer Status - -- go-reviewer: 0 BLOCKING, 3 WARNING, 3 SUGGESTION — all addressed below -- security-reviewer: 0 CRITICAL, 1 HIGH, 3 MEDIUM — all addressed below - ---- - -## Phase 1: Remove Dead Code (Low Risk) - -Remove 2 exported methods on `session.Store` that have **zero callers and zero tests**, plus their associated sentinel error. - -### 1.1 Delete `GetUserMessageBefore` - -**File**: `internal/session/store.go:611-658` -**Evidence**: Comprehend agent confirmed 0 production callers, 0 test callers. -**SQL**: `db/queries/sessions.sql:104-113` — `-- name: GetUserMessageBefore :one` - -Delete the Store method and remove the SQL query. - -### 1.2 Delete `GetMessageByID` - -**File**: `internal/session/store.go:660-676` -**Evidence**: Comprehend agent confirmed 0 production callers, 0 test callers. -**SQL**: `db/queries/sessions.sql:115-119` — `-- name: GetMessageByID :one` - -Same approach — delete Store method and SQL query. - -### 1.3 Delete `ErrMessageNotFound` - -**File**: `internal/session/errors.go:42-43` - -After removing the two methods above, `ErrMessageNotFound` has zero remaining references. Remove the sentinel error to avoid dead code. -(go-reviewer SUGGESTION #1 — confirmed via grep: all 5 usages are within the two deleted methods) - -### 1.4 Update doc.go - -Remove any references to deleted methods in `internal/session/doc.go`. - -### 1.5 Regenerate sqlc - -```bash -cd build/sql && sqlc generate -``` - -This removes the generated code for the two deleted SQL queries. - -### Verification - -```bash -go build ./... # Confirms no broken references -go vet ./... # Static analysis -go test ./... # No test regression (methods had 0 tests) -``` - ---- - -## Phase 2: Rename Store Methods (Medium Risk) - -Remove `Get` prefix from 3 `session.Store` getter methods per `naming.md`: -> "No `Get` prefix for getters" - -### 2.1 Rename SQL Queries - -**File**: `db/queries/sessions.sql` - -| Before | After | Line | -|--------|-------|------| -| `-- name: GetSession :one` | `-- name: Session :one` | 9 | -| `-- name: GetMessages :many` | `-- name: Messages :many` | 53 | - -**Note**: `GetHistory` has no SQL query (it's composed in Go from `Session` + `Messages`). - -### 2.2 Regenerate sqlc - -```bash -cd build/sql && sqlc generate -``` - -This produces cascading changes in `internal/sqlc/sessions.sql.go`: -- `func (q *Queries) GetSession(...)` -> `func (q *Queries) Session(...)` -- `func (q *Queries) GetMessages(...)` -> `func (q *Queries) Messages(...)` -- `type GetMessagesParams` -> `type MessagesParams` - -### 2.3 Rename `session.Store` Methods - -**File**: `internal/session/store.go` - -| Before | After | Line | -|--------|-------|------| -| `func (s *Store) GetSession(...)` | `func (s *Store) Session(...)` | 90 | -| `func (s *Store) GetMessages(...)` | `func (s *Store) Messages(...)` | 323 | -| `func (s *Store) GetHistory(...)` | `func (s *Store) History(...)` | 404 | - -Internal self-calls within `store.go`: -- Line 406: `s.GetSession(ctx, sessionID)` -> `s.Session(ctx, sessionID)` -- Line 417: `s.GetMessages(ctx, sessionID, limit, 0)` -> `s.Messages(ctx, sessionID, limit, 0)` - -Internal sqlc calls: -- Line 91: `s.queries.GetSession(ctx, sessionID)` -> `s.queries.Session(ctx, sessionID)` -- Line 324: `s.queries.GetMessages(ctx, sqlc.GetMessagesParams{...})` -> `s.queries.Messages(ctx, sqlc.MessagesParams{...})` - -### 2.4 Update Production Callers - -| File | Line | Before | After | -|------|------|--------|-------| -| `cmd/cli.go` | 64 | `store.GetSession(ctx, *currentID)` | `store.Session(ctx, *currentID)` | -| `internal/api/session.go` | 57 | `sm.store.GetSession(r.Context(), sessionID)` | `sm.store.Session(r.Context(), sessionID)` | -| `internal/api/session.go` | 268 | `sm.store.GetSession(r.Context(), sessionID)` | `sm.store.Session(r.Context(), sessionID)` | -| `internal/api/session.go` | 313 | `sm.store.GetSession(r.Context(), id)` | `sm.store.Session(r.Context(), id)` | -| `internal/api/chat.go` | 321 | `h.sessions.store.GetSession(ctx, sessionUUID)` | `h.sessions.store.Session(ctx, sessionUUID)` | -| `internal/api/session.go` | 340 | `sm.store.GetMessages(r.Context(), id, 100, 0)` | `sm.store.Messages(r.Context(), id, 100, 0)` | -| `internal/agent/chat/chat.go` | 248 | `c.sessions.GetHistory(ctx, sessionID)` | `c.sessions.History(ctx, sessionID)` | - -**7 production call sites total.** - -### 2.5 Update Test Files - -Approximately ~40 test call sites across: -- `internal/session/integration_test.go` (~30 sites) -- `internal/session/benchmark_test.go` (~10 sites) - -All are direct `store.GetSession(...)`, `store.GetMessages(...)`, `store.GetHistory(...)` calls -> rename to `store.Session(...)`, `store.Messages(...)`, `store.History(...)`. - -**Note on `internal/api/session_test.go`**: 4 test functions contain `GetSession`/`GetSessionMessages` in their names (`TestGetSession_InvalidUUID`, `TestGetSession_MissingID`, `TestGetSessionMessages_InvalidUUID`, `TestGetSessionMessages_OwnershipDenied`). These test the HTTP handler methods (`sm.getSession`, `sm.getSessionMessages`) which are *unexported* and *not* being renamed. The test function names are **intentionally left unchanged** as they reference the handler, not the Store method. -(go-reviewer SUGGESTION #2 — addressed) - -### 2.6 Update Documentation - -- `internal/session/doc.go` — update method name references -- `internal/session/errors.go:34` — update doc comment example `store.GetSession(ctx, id)` -> `store.Session(ctx, id)` -- `internal/agent/chat/doc.go` — update `GetHistory` reference - -(go-reviewer WARNING #1 — `errors.go` doc comment added to list) - -### Verification - -```bash -cd build/sql && sqlc generate # Regenerate -go build ./... # Compile check -go vet ./... # Static analysis -golangci-lint run ./... # 0 issues -go test -race ./... # Full test suite -``` - ---- - -## Phase 3: Fix clientIP() Proxy Header Trust (Low-Medium Risk) - -### Problem - -`internal/api/middleware.go:363-383` — `clientIP()` unconditionally trusts `X-Forwarded-For` header. Any client can spoof this header to bypass rate limiting. -(security-reviewer H2 — current vulnerability being fixed) - -### Fix - -Add a `TrustProxy` configuration option. When `false` (default), ignore proxy headers entirely. When `true`, prefer `X-Real-IP` over `X-Forwarded-For` to prevent spoofing. -(security-reviewer H1 — switched to X-Real-IP priority to avoid leftmost-XFF spoofing) - -#### 3.1 Add config chain - -**File**: `internal/config/config.go` -- Add `TrustProxy bool` field to config struct -- Add `viper.SetDefault("trust_proxy", false)` in `setDefaults` -- Add `mustBind("trust_proxy", "KOOPA_TRUST_PROXY")` in `bindEnvVariables` - -(security-reviewer M2 — full config wiring) - -**File**: `config.example.yaml` -- Add `trust_proxy: false` with comment explaining when to set `true` - -(security-reviewer M1 — migration note for existing proxy deployments) - -#### 3.2 Wire through ServerConfig - -**File**: `internal/api/server.go` - -```go -type ServerConfig struct { - // ... existing fields ... - TrustProxy bool // Trust X-Real-IP/X-Forwarded-For headers (set true behind reverse proxy) -} -``` - -**File**: `cmd/serve.go` — pass `cfg.TrustProxy` into `api.ServerConfig` - -#### 3.3 Update clientIP function - -**File**: `internal/api/middleware.go` - -```go -// clientIP extracts the client IP from the request. -// If trustProxy is true, checks X-Real-IP first (single-valued, set by proxy), -// then X-Forwarded-For. Validates extracted IP with net.ParseIP. -// Otherwise, uses RemoteAddr only (safe default for direct exposure). -func clientIP(r *http.Request, trustProxy bool) string { - if trustProxy { - // Prefer X-Real-IP: single value set by the reverse proxy, not spoofable - if xri := r.Header.Get("X-Real-IP"); xri != "" { - ip := strings.TrimSpace(xri) - if net.ParseIP(ip) != nil { - return ip - } - } - // Fallback: X-Forwarded-For (first entry, client-provided — less trustworthy) - if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - raw, _, _ := strings.Cut(xff, ",") - ip := strings.TrimSpace(raw) - if net.ParseIP(ip) != nil { - return ip - } - } - } - - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return r.RemoteAddr - } - return ip -} -``` - -Key changes from original proposal (per security review): -1. **X-Real-IP checked first** — single value set by trusted proxy, not appendable by client (H1 fix) -2. **`net.ParseIP` validation** — rejects garbage values, falls back to RemoteAddr (M3 fix) -3. **`trustProxy` parameter** — gated by config, default `false` (existing fix) - -(security-reviewer H1 + M3 — addressed) - -#### 3.4 Update rateLimitMiddleware - -**File**: `internal/api/middleware.go` - -Pass `trustProxy` through closure to `clientIP`: - -```go -func rateLimitMiddleware(rl *rateLimiter, trustProxy bool, logger *slog.Logger) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ip := clientIP(r, trustProxy) - // ... rest unchanged - }) - } -} -``` - -Update call site in `server.go:83` to pass `cfg.TrustProxy`. - -#### 3.5 Update tests - -**File**: `internal/api/middleware_test.go` - -- Update all 6 `clientIP(r)` calls to `clientIP(r, true)` for existing proxy-trust test cases -- Add new test cases for `clientIP(r, false)` verifying proxy headers are ignored -- Add test case for invalid IP in proxy header (falls back to RemoteAddr) - -(go-reviewer WARNING #2 + SUGGESTION #3 — middleware_test.go added to file list) - -### Verification - -```bash -go build ./... -go vet ./... -golangci-lint run ./... -go test -race ./... -``` - ---- - -## Out of Scope - -- RAG non-atomic upsert — documented, low-risk, runs only at startup -- Secret masking improvements — cosmetic, no security impact -- sqlc `emit_interface` change — unnecessary (no consumers need an interface) -- `GetMaxSequenceNumber` rename — not a getter on Store, it's an internal sqlc query name -- `TestGetSession_*` test function renames in `api/session_test.go` — test HTTP handlers, not Store - ---- - -## Implementation Order - -| Step | Phase | Files Modified | Risk | -|------|-------|---------------|------| -| 1 | Phase 1 | store.go, sessions.sql, errors.go, doc.go | Low | -| 2 | sqlc regenerate | `cd build/sql && sqlc generate` | None | -| 3 | Phase 2 | sessions.sql, store.go, cli.go, session.go, chat.go, doc.go, errors.go, tests | Medium | -| 4 | sqlc regenerate | `cd build/sql && sqlc generate` | None | -| 5 | Phase 3 | config.go, server.go, serve.go, middleware.go, middleware_test.go, config.example.yaml | Low-Medium | -| 6 | Full verification | — | — | - -Each phase is independently verifiable. Phases 1+2 can be combined into a single sqlc regeneration if committed together. -(go-reviewer SUGGESTION #4 — sqlc path made consistent, combination noted) - ---- - -## Risk Assessment - -| Risk | Mitigation | -|------|-----------| -| Missed call site breaks build | `go build ./...` catches immediately | -| sqlc param type rename missed | Compiler error on `sqlc.GetMessagesParams` | -| Test name references old method | Find-and-replace + `go test` catches | -| `clientIP` signature change breaks callers | Only 1 production call site + 6 test calls; compiler catches | -| TrustProxy default=false behind proxy | Documented migration: set `KOOPA_TRUST_PROXY=true` or `trust_proxy: true` | - ---- - -## Estimated Edit Count - -| Phase | Files | Edit Sites | -|-------|-------|-----------| -| Phase 1 (dead code) | 4 | ~7 | -| Phase 2 (rename) | 9 | ~50 | -| Phase 3 (clientIP) | 6 | ~15 | -| **Total** | **~15** | **~72** | diff --git a/go.mod b/go.mod index 4e9ed08..554d324 100644 --- a/go.mod +++ b/go.mod @@ -24,8 +24,8 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0 go.opentelemetry.io/otel/sdk v1.38.0 go.uber.org/goleak v1.3.0 - golang.org/x/sync v0.18.0 golang.org/x/time v0.14.0 + google.golang.org/genai v1.41.0 ) require ( @@ -170,12 +170,12 @@ require ( golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/oauth2 v0.33.0 // indirect + golang.org/x/sync v0.18.0 // indirect golang.org/x/sys v0.38.0 // indirect golang.org/x/term v0.37.0 // indirect golang.org/x/text v0.31.0 // indirect google.golang.org/api v0.252.0 // indirect google.golang.org/appengine v1.6.8 // indirect - google.golang.org/genai v1.41.0 // indirect google.golang.org/genproto v0.0.0-20251014184007-4626949a642f // indirect google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect diff --git a/internal/agent/chat/chat_test.go b/internal/agent/chat/chat_test.go deleted file mode 100644 index 8592aea..0000000 --- a/internal/agent/chat/chat_test.go +++ /dev/null @@ -1,565 +0,0 @@ -package chat - -import ( - "context" - "errors" - "log/slog" - "strings" - "testing" - - "github.com/firebase/genkit/go/ai" -) - -// TestNew_ValidationErrors tests constructor validation -func TestNew_ValidationErrors(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - cfg Config - errContains string - }{ - { - name: "nil genkit", - cfg: Config{}, - errContains: "genkit instance is required", - }, - { - name: "nil retriever", - cfg: Config{ - Genkit: nil, // Still nil, so we'll get Genkit error first - }, - errContains: "genkit instance is required", - }, - { - name: "nil logger - requires all previous deps", - cfg: Config{ - // Missing Genkit - }, - errContains: "genkit instance is required", - }, - { - name: "empty tools - requires all previous deps", - cfg: Config{ - // Missing Genkit - }, - errContains: "genkit instance is required", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - _, err := New(tt.cfg) - if err == nil { - t.Fatal("New() expected error, got nil") - } - if !strings.Contains(err.Error(), tt.errContains) { - t.Errorf("New() error = %q, want to contain %q", err.Error(), tt.errContains) - } - }) - } -} - -// TestConstants tests package constants -func TestConstants(t *testing.T) { - t.Parallel() - - t.Run("Name constant", func(t *testing.T) { - t.Parallel() - if Name != "chat" { - t.Errorf("Name = %q, want %q", Name, "chat") - } - }) - - t.Run("Description is not empty", func(t *testing.T) { - t.Parallel() - if Description == "" { - t.Error("Description is empty, want non-empty") - } - }) - - t.Run("KoopaPromptName is set", func(t *testing.T) { - t.Parallel() - if KoopaPromptName != "koopa" { - t.Errorf("KoopaPromptName = %q, want %q", KoopaPromptName, "koopa") - } - }) -} - -// TestStreamCallback_Type tests the StreamCallback type definition -func TestStreamCallback_Type(t *testing.T) { - t.Parallel() - - t.Run("nil callback is valid", func(t *testing.T) { - t.Parallel() - var callback StreamCallback - if callback != nil { - t.Errorf("nil callback = %v, want nil", callback) - } - }) - - t.Run("callback can be assigned", func(t *testing.T) { - t.Parallel() - called := false - callback := StreamCallback(func(_ context.Context, _ *ai.ModelResponseChunk) error { - called = true - return nil - }) - if callback == nil { - t.Fatal("callback is nil, want non-nil") - } - err := callback(context.Background(), nil) - if err != nil { - t.Errorf("callback() unexpected error: %v", err) - } - if !called { - t.Error("callback was not called") - } - }) - - t.Run("callback can return error", func(t *testing.T) { - t.Parallel() - expectedErr := errors.New("test error") - callback := StreamCallback(func(_ context.Context, _ *ai.ModelResponseChunk) error { - return expectedErr - }) - err := callback(context.Background(), nil) - if !errors.Is(err, expectedErr) { - t.Errorf("callback() = %v, want %v", err, expectedErr) - } - }) -} - -// TestConfig_Structure tests the Config struct -func TestConfig_Structure(t *testing.T) { - t.Parallel() - - t.Run("zero value has nil fields", func(t *testing.T) { - t.Parallel() - var cfg Config - if cfg.Genkit != nil { - t.Errorf("cfg.Genkit = %v, want nil", cfg.Genkit) - } - if cfg.Retriever != nil { - t.Errorf("cfg.Retriever = %v, want nil", cfg.Retriever) - } - if cfg.SessionStore != nil { - t.Errorf("cfg.SessionStore = %v, want nil", cfg.SessionStore) - } - if cfg.Logger != nil { - t.Errorf("cfg.Logger = %v, want nil", cfg.Logger) - } - if cfg.Tools != nil { - t.Errorf("cfg.Tools = %v, want nil", cfg.Tools) - } - }) -} - -// TestChat_RetrieveRAGContext_SkipsWhenTopKZero tests RAG context retrieval -func TestChat_RetrieveRAGContext_SkipsWhenTopKZero(t *testing.T) { - t.Parallel() - - t.Run("returns nil when topK is zero", func(t *testing.T) { - t.Parallel() - c := &Chat{ - ragTopK: 0, - logger: slog.Default(), - } - docs := c.retrieveRAGContext(context.Background(), "test query") - if docs != nil { - t.Errorf("retrieveRAGContext() = %v, want nil", docs) - } - }) - - t.Run("returns nil when topK is negative", func(t *testing.T) { - t.Parallel() - c := &Chat{ - ragTopK: -1, - logger: slog.Default(), - } - docs := c.retrieveRAGContext(context.Background(), "test query") - if docs != nil { - t.Errorf("retrieveRAGContext() = %v, want nil", docs) - } - }) -} - -// ============================================================================= -// Edge Case Tests for Real Scenarios -// ============================================================================= - -// TestChat_EmptyResponseHandling tests that empty model responses are handled gracefully. -func TestChat_EmptyResponseHandling(t *testing.T) { - t.Parallel() - - t.Run("empty string triggers fallback", func(t *testing.T) { - t.Parallel() - // Test the logic of empty response detection - responseText := "" - if strings.TrimSpace(responseText) == "" { - responseText = FallbackResponseMessage - } - if !strings.Contains(responseText, "apologize") { - t.Errorf("fallback response = %q, want to contain %q", responseText, "apologize") - } - if responseText == "" { - t.Error("fallback response is empty, want non-empty") - } - }) - - t.Run("whitespace-only triggers fallback", func(t *testing.T) { - t.Parallel() - responseText := " \n\t " - if strings.TrimSpace(responseText) == "" { - responseText = FallbackResponseMessage - } - if !strings.Contains(responseText, "apologize") { - t.Errorf("fallback response = %q, want to contain %q", responseText, "apologize") - } - }) - - t.Run("valid response is preserved", func(t *testing.T) { - t.Parallel() - responseText := "Hello, I'm here to help!" - originalText := responseText - if strings.TrimSpace(responseText) == "" { - responseText = FallbackResponseMessage - } - if responseText != originalText { - t.Errorf("responseText = %q, want %q", responseText, originalText) - } - }) -} - -// TestChat_ContextCancellation tests graceful handling of context cancellation. -func TestChat_ContextCancellation(t *testing.T) { - t.Parallel() - - t.Run("canceled context is detected", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithCancel(context.Background()) - cancel() // Cancel immediately - - // Verify context is canceled - if !errors.Is(ctx.Err(), context.Canceled) { - t.Errorf("ctx.Err() = %v, want context.Canceled", ctx.Err()) - } - }) - - t.Run("deadline exceeded is different from canceled", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), 0) - defer cancel() - - // Wait for timeout - <-ctx.Done() - - // DeadlineExceeded is different from Canceled - if !errors.Is(ctx.Err(), context.DeadlineExceeded) { - t.Errorf("ctx.Err() = %v, want context.DeadlineExceeded", ctx.Err()) - } - if errors.Is(ctx.Err(), context.Canceled) { - t.Errorf("ctx.Err() = context.Canceled, want context.DeadlineExceeded") - } - }) -} - -// TestChat_MaxTurnsProtection tests that conversation doesn't loop infinitely. -// Safety: Prevents runaway agent loops that could exhaust resources. -func TestChat_MaxTurnsProtection(t *testing.T) { - t.Parallel() - - t.Run("max turns concept validation", func(t *testing.T) { - t.Parallel() - // In a real agent loop, we would track turns - maxTurns := 10 - currentTurn := 0 - - // Simulate turn counting - for i := 0; i < 100; i++ { - currentTurn++ - if currentTurn >= maxTurns { - break - } - } - - if currentTurn != maxTurns { - t.Errorf("currentTurn = %d, want %d (should stop at max turns)", currentTurn, maxTurns) - } - }) -} - -// TestChat_ToolFailureRecovery tests that the agent can continue after tool failures. -// Resilience: Agent should gracefully handle tool execution errors. -func TestChat_ToolFailureRecovery(t *testing.T) { - t.Parallel() - - t.Run("tool error is wrapped", func(t *testing.T) { - t.Parallel() - toolErr := errors.New("tool failed: file not found") - wrappedErr := errors.New("tool execution failed: " + toolErr.Error()) - if !strings.Contains(wrappedErr.Error(), "tool execution failed") { - t.Errorf("wrappedErr = %q, want to contain %q", wrappedErr.Error(), "tool execution failed") - } - if !strings.Contains(wrappedErr.Error(), "file not found") { - t.Errorf("wrappedErr = %q, want to contain %q", wrappedErr.Error(), "file not found") - } - }) - - t.Run("tool error does not crash agent", func(t *testing.T) { - t.Parallel() - // Simulate error handling that doesn't propagate - var lastErr error - handleToolError := func(err error) { - lastErr = err // Log but don't crash - } - - handleToolError(errors.New("tool failed")) - if lastErr == nil { - t.Error("lastErr is nil, want non-nil") - } - // Agent continues running - }) -} - -// ============================================================================= -// deepCopyMessages / deepCopyPart / shallowCopyMap Tests -// ============================================================================= - -func TestDeepCopyMessages_NilInput(t *testing.T) { - t.Parallel() - got := deepCopyMessages(nil) - if got != nil { - t.Errorf("deepCopyMessages(nil) = %v, want nil", got) - } -} - -func TestDeepCopyMessages_EmptySlice(t *testing.T) { - t.Parallel() - got := deepCopyMessages([]*ai.Message{}) - if got == nil { - t.Fatal("deepCopyMessages(empty) = nil, want non-nil empty slice") - } - if len(got) != 0 { - t.Errorf("deepCopyMessages(empty) len = %d, want 0", len(got)) - } -} - -func TestDeepCopyMessages_MutateOriginalText(t *testing.T) { - t.Parallel() - - original := []*ai.Message{ - ai.NewUserMessage(ai.NewTextPart("hello world")), - } - - copied := deepCopyMessages(original) - - // Mutate the original message's content slice - original[0].Content[0].Text = "MUTATED" - - if copied[0].Content[0].Text != "hello world" { - t.Errorf("deepCopyMessages() copy was affected by original mutation: got %q, want %q", - copied[0].Content[0].Text, "hello world") - } -} - -func TestDeepCopyMessages_MutateOriginalContentSlice(t *testing.T) { - t.Parallel() - - original := []*ai.Message{ - ai.NewUserMessage(ai.NewTextPart("first"), ai.NewTextPart("second")), - } - - copied := deepCopyMessages(original) - - // Append to original's content slice — should not affect copy - original[0].Content = append(original[0].Content, ai.NewTextPart("third")) - - if len(copied[0].Content) != 2 { - t.Errorf("deepCopyMessages() copy content len = %d, want 2", len(copied[0].Content)) - } -} - -func TestDeepCopyMessages_PreservesRole(t *testing.T) { - t.Parallel() - - original := []*ai.Message{ - ai.NewUserMessage(ai.NewTextPart("q")), - ai.NewModelMessage(ai.NewTextPart("a")), - } - - copied := deepCopyMessages(original) - - if copied[0].Role != ai.RoleUser { - t.Errorf("deepCopyMessages()[0].Role = %q, want %q", copied[0].Role, ai.RoleUser) - } - if copied[1].Role != ai.RoleModel { - t.Errorf("deepCopyMessages()[1].Role = %q, want %q", copied[1].Role, ai.RoleModel) - } -} - -func TestDeepCopyMessages_Metadata(t *testing.T) { - t.Parallel() - - original := []*ai.Message{{ - Role: ai.RoleUser, - Content: []*ai.Part{ai.NewTextPart("test")}, - Metadata: map[string]any{"key": "value"}, - }} - - copied := deepCopyMessages(original) - - // Mutate original metadata - original[0].Metadata["key"] = "MUTATED" - - if copied[0].Metadata["key"] != "value" { - t.Errorf("deepCopyMessages() metadata was affected by mutation: got %q, want %q", - copied[0].Metadata["key"], "value") - } -} - -func TestDeepCopyPart_NilInput(t *testing.T) { - t.Parallel() - got := deepCopyPart(nil) - if got != nil { - t.Errorf("deepCopyPart(nil) = %v, want nil", got) - } -} - -func TestDeepCopyPart_TextPart(t *testing.T) { - t.Parallel() - - original := ai.NewTextPart("hello") - copied := deepCopyPart(original) - - original.Text = "MUTATED" - - if copied.Text != "hello" { - t.Errorf("deepCopyPart() text affected by mutation: got %q, want %q", copied.Text, "hello") - } -} - -func TestDeepCopyPart_ToolRequest(t *testing.T) { - t.Parallel() - - original := &ai.Part{ - Kind: ai.PartToolRequest, - ToolRequest: &ai.ToolRequest{ - Name: "read_file", - Input: map[string]any{"path": "/tmp/test"}, - }, - } - - copied := deepCopyPart(original) - - // Mutate original ToolRequest name - original.ToolRequest.Name = "MUTATED" - - if copied.ToolRequest.Name != "read_file" { - t.Errorf("deepCopyPart() ToolRequest.Name affected by mutation: got %q, want %q", - copied.ToolRequest.Name, "read_file") - } -} - -func TestDeepCopyPart_ToolResponse(t *testing.T) { - t.Parallel() - - original := &ai.Part{ - Kind: ai.PartToolResponse, - ToolResponse: &ai.ToolResponse{ - Name: "read_file", - Output: "file contents", - }, - } - - copied := deepCopyPart(original) - - original.ToolResponse.Name = "MUTATED" - - if copied.ToolResponse.Name != "read_file" { - t.Errorf("deepCopyPart() ToolResponse.Name affected by mutation: got %q, want %q", - copied.ToolResponse.Name, "read_file") - } -} - -func TestDeepCopyPart_Resource(t *testing.T) { - t.Parallel() - - original := &ai.Part{ - Kind: ai.PartMedia, - Resource: &ai.ResourcePart{Uri: "https://example.com/image.png"}, - } - - copied := deepCopyPart(original) - - original.Resource.Uri = "MUTATED" - - if copied.Resource.Uri != "https://example.com/image.png" { - t.Errorf("deepCopyPart() Resource.Uri affected by mutation: got %q, want %q", - copied.Resource.Uri, "https://example.com/image.png") - } -} - -func TestDeepCopyPart_PartMetadata(t *testing.T) { - t.Parallel() - - original := &ai.Part{ - Kind: ai.PartText, - Text: "test", - Custom: map[string]any{"c": "custom"}, - Metadata: map[string]any{"m": "meta"}, - } - - copied := deepCopyPart(original) - - original.Custom["c"] = "MUTATED" - original.Metadata["m"] = "MUTATED" - - if copied.Custom["c"] != "custom" { - t.Errorf("deepCopyPart() Custom map affected: got %q, want %q", copied.Custom["c"], "custom") - } - if copied.Metadata["m"] != "meta" { - t.Errorf("deepCopyPart() Metadata map affected: got %q, want %q", copied.Metadata["m"], "meta") - } -} - -func TestShallowCopyMap_NilInput(t *testing.T) { - t.Parallel() - got := shallowCopyMap(nil) - if got != nil { - t.Errorf("shallowCopyMap(nil) = %v, want nil", got) - } -} - -func TestShallowCopyMap_IndependentKeys(t *testing.T) { - t.Parallel() - - original := map[string]any{"a": "1", "b": "2"} - copied := shallowCopyMap(original) - - // Add new key to original - original["c"] = "3" - - if _, ok := copied["c"]; ok { - t.Error("shallowCopyMap() new key in original appeared in copy") - } - if len(copied) != 2 { - t.Errorf("shallowCopyMap() copy len = %d, want 2", len(copied)) - } -} - -func TestShallowCopyMap_MutateValue(t *testing.T) { - t.Parallel() - - original := map[string]any{"key": "value"} - copied := shallowCopyMap(original) - - // Overwrite original value - original["key"] = "MUTATED" - - if copied["key"] != "value" { - t.Errorf("shallowCopyMap() value affected by mutation: got %q, want %q", - copied["key"], "value") - } -} diff --git a/internal/agent/chat/flow_test.go b/internal/agent/chat/flow_test.go deleted file mode 100644 index 644288e..0000000 --- a/internal/agent/chat/flow_test.go +++ /dev/null @@ -1,154 +0,0 @@ -package chat - -import ( - "errors" - "testing" - - "github.com/koopa0/koopa/internal/agent" -) - -// TestFlowName tests the FlowName constant -func TestFlowName(t *testing.T) { - t.Parallel() - if FlowName != "koopa/chat" { - t.Errorf("FlowName = %q, want %q", FlowName, "koopa/chat") - } - if FlowName == "" { - t.Error("FlowName is empty, want non-empty") - } -} - -// TestStreamChunk_Structure tests the StreamChunk type -func TestStreamChunk_Structure(t *testing.T) { - t.Parallel() - - t.Run("zero value has empty text", func(t *testing.T) { - t.Parallel() - var chunk StreamChunk - if chunk.Text != "" { - t.Errorf("chunk.Text = %q, want %q", chunk.Text, "") - } - }) - - t.Run("can set text", func(t *testing.T) { - t.Parallel() - chunk := StreamChunk{Text: "Hello, World!"} - if chunk.Text != "Hello, World!" { - t.Errorf("chunk.Text = %q, want %q", chunk.Text, "Hello, World!") - } - }) - - t.Run("can hold unicode text", func(t *testing.T) { - t.Parallel() - chunk := StreamChunk{Text: "你好世界 🌍"} - if chunk.Text != "你好世界 🌍" { - t.Errorf("chunk.Text = %q, want %q", chunk.Text, "你好世界 🌍") - } - }) -} - -// TestInput_Structure tests the Input type -func TestInput_Structure(t *testing.T) { - t.Parallel() - - t.Run("zero value has empty fields", func(t *testing.T) { - t.Parallel() - var input Input - if input.Query != "" { - t.Errorf("input.Query = %q, want %q", input.Query, "") - } - if input.SessionID != "" { - t.Errorf("input.SessionID = %q, want %q", input.SessionID, "") - } - }) - - t.Run("can set all fields", func(t *testing.T) { - t.Parallel() - input := Input{ - Query: "What is the weather?", - SessionID: "test-session-123", - } - if input.Query != "What is the weather?" { - t.Errorf("input.Query = %q, want %q", input.Query, "What is the weather?") - } - if input.SessionID != "test-session-123" { - t.Errorf("input.SessionID = %q, want %q", input.SessionID, "test-session-123") - } - }) -} - -// TestOutput_Structure tests the Output type -func TestOutput_Structure(t *testing.T) { - t.Parallel() - - t.Run("zero value has empty fields", func(t *testing.T) { - t.Parallel() - var output Output - if output.Response != "" { - t.Errorf("output.Response = %q, want %q", output.Response, "") - } - if output.SessionID != "" { - t.Errorf("output.SessionID = %q, want %q", output.SessionID, "") - } - }) - - t.Run("can set response and session", func(t *testing.T) { - t.Parallel() - output := Output{ - Response: "The weather is sunny.", - SessionID: "test-session-123", - } - if output.Response != "The weather is sunny." { - t.Errorf("output.Response = %q, want %q", output.Response, "The weather is sunny.") - } - if output.SessionID != "test-session-123" { - t.Errorf("output.SessionID = %q, want %q", output.SessionID, "test-session-123") - } - }) -} - -// TestSentinelErrors_CanBeChecked tests that sentinel errors work correctly with errors.Is -func TestSentinelErrors_CanBeChecked(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - err error - sentinel error - }{ - {"ErrInvalidSession", agent.ErrInvalidSession, agent.ErrInvalidSession}, - {"ErrExecutionFailed", agent.ErrExecutionFailed, agent.ErrExecutionFailed}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - if !errors.Is(tt.err, tt.sentinel) { - t.Errorf("errors.Is(%v, %v) = false, want true", tt.err, tt.sentinel) - } - }) - } -} - -// TestWrappedErrors_PreserveSentinel tests that wrapped errors preserve sentinel checking -func TestWrappedErrors_PreserveSentinel(t *testing.T) { - t.Parallel() - - t.Run("wrapped invalid session error", func(t *testing.T) { - t.Parallel() - err := errors.New("original error") - wrapped := errors.Join(agent.ErrInvalidSession, err) - if !errors.Is(wrapped, agent.ErrInvalidSession) { - t.Errorf("errors.Is(wrapped, ErrInvalidSession) = false, want true") - } - }) - - t.Run("wrapped execution failed error", func(t *testing.T) { - t.Parallel() - err := errors.New("LLM timeout") - wrapped := errors.Join(agent.ErrExecutionFailed, err) - if !errors.Is(wrapped, agent.ErrExecutionFailed) { - t.Errorf("errors.Is(wrapped, ErrExecutionFailed) = false, want true") - } - }) -} diff --git a/internal/agent/chat/retry_test.go b/internal/agent/chat/retry_test.go deleted file mode 100644 index 94a7036..0000000 --- a/internal/agent/chat/retry_test.go +++ /dev/null @@ -1,196 +0,0 @@ -package chat - -import ( - "errors" - "testing" -) - -func TestDefaultRetryConfig(t *testing.T) { - t.Parallel() - - cfg := DefaultRetryConfig() - - if cfg.MaxRetries <= 0 { - t.Errorf("MaxRetries should be positive, got %d", cfg.MaxRetries) - } - if cfg.InitialInterval <= 0 { - t.Errorf("InitialInterval should be positive, got %v", cfg.InitialInterval) - } - if cfg.MaxInterval <= 0 { - t.Errorf("MaxInterval should be positive, got %v", cfg.MaxInterval) - } - if cfg.MaxInterval < cfg.InitialInterval { - t.Error("MaxInterval should be >= InitialInterval") - } -} - -func TestRetryableError(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - err error - expected bool - }{ - { - name: "nil error", - err: nil, - expected: false, - }, - { - name: "rate limit error", - err: errors.New("rate limit exceeded"), - expected: true, - }, - { - name: "quota exceeded error", - err: errors.New("quota exceeded for project"), - expected: true, - }, - { - name: "429 status code", - err: errors.New("HTTP 429: Too Many Requests"), - expected: true, - }, - { - name: "500 server error", - err: errors.New("HTTP 500 Internal Server Error"), - expected: true, - }, - { - name: "502 bad gateway", - err: errors.New("502 Bad Gateway"), - expected: true, - }, - { - name: "503 unavailable", - err: errors.New("503 Service Unavailable"), - expected: true, - }, - { - name: "504 gateway timeout", - err: errors.New("504 Gateway Timeout"), - expected: true, - }, - { - name: "unavailable keyword", - err: errors.New("service unavailable"), - expected: true, - }, - { - name: "connection reset", - err: errors.New("connection reset by peer"), - expected: true, - }, - { - name: "timeout error", - err: errors.New("request timeout"), - expected: true, - }, - { - name: "temporary error", - err: errors.New("temporary failure"), - expected: true, - }, - { - name: "non-retryable error", - err: errors.New("invalid API key"), - expected: false, - }, - { - name: "non-retryable 400 error", - err: errors.New("HTTP 400 Bad Request"), - expected: false, - }, - { - name: "non-retryable 401 error", - err: errors.New("HTTP 401 Unauthorized"), - expected: false, - }, - { - name: "non-retryable 403 error", - err: errors.New("HTTP 403 Forbidden"), - expected: false, - }, - { - name: "case insensitive rate limit", - err: errors.New("RATE LIMIT reached"), - expected: true, - }, - { - name: "case insensitive timeout", - err: errors.New("TIMEOUT occurred"), - expected: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - got := retryableError(tt.err) - if got != tt.expected { - t.Errorf("retryableError(%v) = %v, want %v", tt.err, got, tt.expected) - } - }) - } -} - -func TestContainsAny(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - s string - substrs []string - expected bool - }{ - { - name: "empty string", - s: "", - substrs: []string{"foo"}, - expected: false, - }, - { - name: "empty substrs", - s: "foo bar", - substrs: []string{}, - expected: false, - }, - { - name: "contains first substr", - s: "foo bar baz", - substrs: []string{"foo", "qux"}, - expected: true, - }, - { - name: "contains last substr", - s: "foo bar baz", - substrs: []string{"qux", "baz"}, - expected: true, - }, - { - name: "case insensitive match", - s: "FOO BAR BAZ", - substrs: []string{"foo"}, - expected: true, - }, - { - name: "no match", - s: "foo bar baz", - substrs: []string{"qux", "quux"}, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - got := containsAny(tt.s, tt.substrs...) - if got != tt.expected { - t.Errorf("containsAny(%q, %v) = %v, want %v", tt.s, tt.substrs, got, tt.expected) - } - }) - } -} diff --git a/internal/agent/doc.go b/internal/agent/doc.go deleted file mode 100644 index eeca0bd..0000000 --- a/internal/agent/doc.go +++ /dev/null @@ -1,33 +0,0 @@ -// Package agent provides sentinel errors for the chat agent. -// -// # Overview -// -// This package provides shared error types for building conversational AI agents. -// The main implementation is in the chat subpackage. -// -// # Errors -// -// The package provides sentinel errors for consistent error handling: -// -// agent.ErrInvalidSession // Invalid session ID format -// agent.ErrExecutionFailed // LLM or tool execution failed -// -// # Usage -// -// The chat subpackage provides the Chat agent implementation: -// -// import "github.com/koopa0/koopa/internal/agent/chat" -// -// chatAgent, err := chat.New(chat.Config{ -// Genkit: g, -// Retriever: retriever, -// SessionStore: sessionStore, -// Logger: logger, -// Tools: tools, -// MaxTurns: 10, -// RAGTopK: 5, -// Language: "auto", -// }) -// -// See the chat subpackage for the complete implementation. -package agent diff --git a/internal/agent/errors.go b/internal/agent/errors.go deleted file mode 100644 index 6b132f1..0000000 --- a/internal/agent/errors.go +++ /dev/null @@ -1,16 +0,0 @@ -// Package agent provides the agent abstraction layer for AI chat functionality. -package agent - -import "errors" - -// Sentinel errors for agent operations. -// Only errors that are checked with errors.Is() are defined here. -var ( - // ErrInvalidSession indicates the session ID is invalid or malformed. - // Used by: web/handlers/chat.go for HTTP status mapping - ErrInvalidSession = errors.New("invalid session") - - // ErrExecutionFailed indicates agent execution failed. - // Used by: web/handlers/chat.go for HTTP status mapping - ErrExecutionFailed = errors.New("execution failed") -) diff --git a/internal/api/chat.go b/internal/api/chat.go index 6bbe80c..1035986 100644 --- a/internal/api/chat.go +++ b/internal/api/chat.go @@ -11,31 +11,16 @@ import ( "strings" "time" - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" "github.com/google/uuid" - "github.com/koopa0/koopa/internal/agent" - "github.com/koopa0/koopa/internal/agent/chat" + "github.com/koopa0/koopa/internal/chat" "github.com/koopa0/koopa/internal/tools" ) // SSE timeout for streaming connections. const sseTimeout = 5 * time.Minute -// Title generation constants. -const ( - titleMaxLength = 50 - titleGenerationTimeout = 5 * time.Second - titleInputMaxRunes = 500 -) - -const titlePrompt = `Generate a concise title (max 50 characters) for a chat session based on this first message. -The title should capture the main topic or intent. -Return ONLY the title text, no quotes, no explanations, no punctuation at the end. - -Message: %s - -Title:` +// titleMaxLength is the maximum rune length for a fallback session title. +const titleMaxLength = 50 // Tool display info for JSON SSE events. type toolDisplayInfo struct { @@ -45,17 +30,20 @@ type toolDisplayInfo struct { } var toolDisplay = map[string]toolDisplayInfo{ - "web_search": {StartMsg: "搜尋網路中...", CompleteMsg: "搜尋完成", ErrorMsg: "搜尋服務暫時無法使用,請稍後再試"}, - "web_fetch": {StartMsg: "讀取網頁中...", CompleteMsg: "已讀取內容", ErrorMsg: "無法讀取網頁內容"}, - "read_file": {StartMsg: "讀取檔案中...", CompleteMsg: "已讀取檔案", ErrorMsg: "無法讀取檔案"}, - "write_file": {StartMsg: "寫入檔案中...", CompleteMsg: "已寫入檔案", ErrorMsg: "寫入檔案失敗"}, - "list_files": {StartMsg: "瀏覽目錄中...", CompleteMsg: "目錄瀏覽完成", ErrorMsg: "無法瀏覽目錄"}, - "delete_file": {StartMsg: "刪除檔案中...", CompleteMsg: "已刪除檔案", ErrorMsg: "刪除檔案失敗"}, - "get_file_info": {StartMsg: "取得檔案資訊中...", CompleteMsg: "已取得檔案資訊", ErrorMsg: "無法取得檔案資訊"}, - "execute_command": {StartMsg: "執行命令中...", CompleteMsg: "命令執行完成", ErrorMsg: "命令執行失敗"}, - "current_time": {StartMsg: "取得時間中...", CompleteMsg: "時間已取得", ErrorMsg: "無法取得時間"}, - "get_env": {StartMsg: "取得環境變數中...", CompleteMsg: "環境變數已取得", ErrorMsg: "無法取得環境變數"}, - "knowledge_search": {StartMsg: "搜尋知識庫中...", CompleteMsg: "知識庫搜尋完成", ErrorMsg: "無法搜尋知識庫"}, + "web_search": {StartMsg: "搜尋網路中...", CompleteMsg: "搜尋完成", ErrorMsg: "搜尋服務暫時無法使用,請稍後再試"}, + "web_fetch": {StartMsg: "讀取網頁中...", CompleteMsg: "已讀取內容", ErrorMsg: "無法讀取網頁內容"}, + "read_file": {StartMsg: "讀取檔案中...", CompleteMsg: "已讀取檔案", ErrorMsg: "無法讀取檔案"}, + "write_file": {StartMsg: "寫入檔案中...", CompleteMsg: "已寫入檔案", ErrorMsg: "寫入檔案失敗"}, + "list_files": {StartMsg: "瀏覽目錄中...", CompleteMsg: "目錄瀏覽完成", ErrorMsg: "無法瀏覽目錄"}, + "delete_file": {StartMsg: "刪除檔案中...", CompleteMsg: "已刪除檔案", ErrorMsg: "刪除檔案失敗"}, + "get_file_info": {StartMsg: "取得檔案資訊中...", CompleteMsg: "已取得檔案資訊", ErrorMsg: "無法取得檔案資訊"}, + "execute_command": {StartMsg: "執行命令中...", CompleteMsg: "命令執行完成", ErrorMsg: "命令執行失敗"}, + "current_time": {StartMsg: "取得時間中...", CompleteMsg: "時間已取得", ErrorMsg: "無法取得時間"}, + "get_env": {StartMsg: "取得環境變數中...", CompleteMsg: "環境變數已取得", ErrorMsg: "無法取得環境變數"}, + "search_history": {StartMsg: "搜尋對話記錄中...", CompleteMsg: "對話記錄搜尋完成", ErrorMsg: "無法搜尋對話記錄"}, + "search_documents": {StartMsg: "搜尋知識庫中...", CompleteMsg: "知識庫搜尋完成", ErrorMsg: "無法搜尋知識庫"}, + "search_system_knowledge": {StartMsg: "搜尋系統知識中...", CompleteMsg: "系統知識搜尋完成", ErrorMsg: "無法搜尋系統知識"}, + "knowledge_store": {StartMsg: "儲存知識中...", CompleteMsg: "知識已儲存", ErrorMsg: "儲存知識失敗"}, } var defaultToolDisplay = toolDisplayInfo{ @@ -73,48 +61,47 @@ func getToolDisplay(name string) toolDisplayInfo { // chatHandler handles chat-related API requests. type chatHandler struct { - logger *slog.Logger - genkit *genkit.Genkit - modelName string // Provider-qualified model name for title generation - flow *chat.Flow - sessions *sessionManager + logger *slog.Logger + agent *chat.Agent // Optional: nil disables AI title generation + flow *chat.Flow + sessions *sessionManager } // send handles POST /api/v1/chat — accepts JSON, sends message to chat flow. -// -//nolint:revive // unused-receiver: method bound to chatHandler for consistent route registration func (h *chatHandler) send(w http.ResponseWriter, r *http.Request) { var req struct { Content string `json:"content"` SessionID string `json:"sessionId"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - WriteError(w, http.StatusBadRequest, "invalid_json", "invalid request body") + WriteError(w, http.StatusBadRequest, "invalid_json", "invalid request body", h.logger) return } content := strings.TrimSpace(req.Content) if content == "" { - WriteError(w, http.StatusBadRequest, "content_required", "content is required") + WriteError(w, http.StatusBadRequest, "content_required", "content is required", h.logger) return } - // Resolve session ID from request body or context - var sessionID uuid.UUID + // Resolve session from context (set by session middleware from cookie) + sessionID, ok := sessionIDFromContext(r.Context()) + if !ok { + WriteError(w, http.StatusBadRequest, "session_required", "session ID required", h.logger) + return + } + + // If body also specifies a session, verify it matches (defense-in-depth) if req.SessionID != "" { parsed, err := uuid.Parse(req.SessionID) if err != nil { - WriteError(w, http.StatusBadRequest, "invalid_session", "invalid session ID") + WriteError(w, http.StatusBadRequest, "invalid_session", "invalid session ID", h.logger) return } - sessionID = parsed - } else { - ctxID, ok := SessionIDFromContext(r.Context()) - if !ok { - WriteError(w, http.StatusBadRequest, "session_required", "session ID required") + if parsed != sessionID { + WriteError(w, http.StatusForbidden, "forbidden", "session access denied", h.logger) return } - sessionID = ctxID } msgID := uuid.New().String() @@ -128,7 +115,7 @@ func (h *chatHandler) send(w http.ResponseWriter, r *http.Request) { "msgId": msgID, "sessionId": sessionID.String(), "streamUrl": "/api/v1/chat/stream?" + params.Encode(), - }) + }, h.logger) } // stream handles GET /api/v1/chat/stream — SSE endpoint with JSON events. @@ -137,12 +124,21 @@ func (h *chatHandler) stream(w http.ResponseWriter, r *http.Request) { sessionID := r.URL.Query().Get("session_id") query := r.URL.Query().Get("query") - if msgID == "" || sessionID == "" { - WriteError(w, http.StatusBadRequest, "missing_params", "msgId and session_id required") + if msgID == "" || sessionID == "" || query == "" { + WriteError(w, http.StatusBadRequest, "missing_params", "msgId, session_id, and query required", h.logger) return } - if query == "" { - query = "Hello" + + // Verify session ownership + parsedID, err := uuid.Parse(sessionID) + if err != nil { + WriteError(w, http.StatusBadRequest, "invalid_session", "invalid session ID", h.logger) + return + } + ctxID, ok := sessionIDFromContext(r.Context()) + if !ok || ctxID != parsedID { + WriteError(w, http.StatusForbidden, "forbidden", "session access denied", h.logger) + return } // Set SSE headers @@ -151,24 +147,23 @@ func (h *chatHandler) stream(w http.ResponseWriter, r *http.Request) { w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") - flusher, ok := w.(http.Flusher) - if !ok { - WriteError(w, http.StatusInternalServerError, "sse_unsupported", "streaming not supported") + if _, ok := w.(http.Flusher); !ok { + WriteError(w, http.StatusInternalServerError, "sse_unsupported", "streaming not supported", h.logger) return } ctx, cancel := context.WithTimeout(r.Context(), sseTimeout) defer cancel() - if h.flow != nil { - h.streamWithFlow(ctx, w, flusher, msgID, sessionID, query) - } else { - h.simulateStreaming(ctx, w, flusher, msgID, sessionID, query) + if h.flow == nil { + _ = sseEvent(w, "error", map[string]string{"error": "chat flow not initialized"}) + return } + h.streamWithFlow(ctx, w, msgID, sessionID, query) } // sseEvent writes a single SSE event. -func sseEvent(w http.ResponseWriter, f http.Flusher, event string, data any) error { +func sseEvent(w http.ResponseWriter, event string, data any) error { jsonData, err := json.Marshal(data) if err != nil { return fmt.Errorf("marshal SSE data: %w", err) @@ -177,19 +172,21 @@ func sseEvent(w http.ResponseWriter, f http.Flusher, event string, data any) err if err != nil { return fmt.Errorf("write SSE event: %w", err) } - f.Flush() + if f, ok := w.(http.Flusher); ok { + f.Flush() + } return nil } // streamWithFlow uses the real chat.Flow for AI responses. -func (h *chatHandler) streamWithFlow(ctx context.Context, w http.ResponseWriter, f http.Flusher, msgID, sessionID, query string) { +func (h *chatHandler) streamWithFlow(ctx context.Context, w http.ResponseWriter, msgID, sessionID, query string) { input := chat.Input{ Query: query, SessionID: sessionID, } // Create JSON tool emitter and inject into context - emitter := &jsonToolEmitter{w: w, f: f, msgID: msgID} + emitter := &jsonToolEmitter{w: w, msgID: msgID} ctx = tools.ContextWithEmitter(ctx, emitter) h.logger.Debug("starting stream", "sessionId", sessionID) @@ -197,7 +194,6 @@ func (h *chatHandler) streamWithFlow(ctx context.Context, w http.ResponseWriter, var ( finalOutput chat.Output streamErr error - buf strings.Builder ) for streamValue, err := range h.flow.Stream(ctx, input) { @@ -219,11 +215,8 @@ func (h *chatHandler) streamWithFlow(ctx context.Context, w http.ResponseWriter, } if streamValue.Stream.Text != "" { - buf.WriteString(streamValue.Stream.Text) - content := buf.String() - buf.Reset() - if err := sseEvent(w, f, "chunk", map[string]string{"msgId": msgID, "text": content}); err != nil { - h.logger.Error("failed to write chunk", "error", err) + if err := sseEvent(w, "chunk", map[string]string{"msgId": msgID, "text": streamValue.Stream.Text}); err != nil { + h.logger.Error("writing chunk", "error", err) return } } @@ -231,16 +224,11 @@ func (h *chatHandler) streamWithFlow(ctx context.Context, w http.ResponseWriter, if streamErr != nil { code, message := classifyError(streamErr) - h.logger.Error("flow execution failed", "error", streamErr, "sessionId", sessionID) - _ = sseEvent(w, f, "error", map[string]string{"msgId": msgID, "code": code, "message": message}) + h.logger.Error("executing flow", "error", streamErr, "sessionId", sessionID) + _ = sseEvent(w, "error", map[string]string{"msgId": msgID, "code": code, "message": message}) // best-effort: client may have disconnected return } - // Flush remaining buffer - if buf.Len() > 0 { - _ = sseEvent(w, f, "chunk", map[string]string{"msgId": msgID, "text": buf.String()}) - } - // Generate title before sending done event title := h.maybeGenerateTitle(ctx, sessionID, query) @@ -253,55 +241,15 @@ func (h *chatHandler) streamWithFlow(ctx context.Context, w http.ResponseWriter, if title != "" { doneData["title"] = title } - _ = sseEvent(w, f, "done", doneData) -} - -// simulateStreaming is a placeholder for testing without real Flow. -func (h *chatHandler) simulateStreaming(ctx context.Context, w http.ResponseWriter, f http.Flusher, msgID, sessionID, query string) { - response := fmt.Sprintf("I received your message: %q. This is a simulated response.", query) - words := strings.Fields(response) - - var full strings.Builder - for i, word := range words { - select { - case <-ctx.Done(): - h.logContextDone(ctx, msgID) - return - default: - } - - if i > 0 { - full.WriteString(" ") - } - full.WriteString(word) - - if err := sseEvent(w, f, "chunk", map[string]string{"msgId": msgID, "text": full.String()}); err != nil { - h.logger.Error("failed to send chunk", "error", err) - return - } - - time.Sleep(50 * time.Millisecond) - } - - title := h.maybeGenerateTitle(ctx, sessionID, query) - - doneData := map[string]string{ - "msgId": msgID, - "sessionId": sessionID, - "response": full.String(), - } - if title != "" { - doneData["title"] = title - } - _ = sseEvent(w, f, "done", doneData) + _ = sseEvent(w, "done", doneData) // best-effort: client may have disconnected } // classifyError returns error code and user message based on error type. func classifyError(err error) (code, message string) { switch { - case errors.Is(err, agent.ErrInvalidSession): + case errors.Is(err, chat.ErrInvalidSession): return "invalid_session", "Invalid session. Please refresh the page." - case errors.Is(err, agent.ErrExecutionFailed): + case errors.Is(err, chat.ErrExecutionFailed): return "execution_failed", err.Error() case errors.Is(err, context.DeadlineExceeded): return "timeout", "Request timed out. Please try again." @@ -311,8 +259,13 @@ func classifyError(err error) (code, message string) { } // maybeGenerateTitle generates a session title if one doesn't exist. +// Uses Agent.GenerateTitle for AI-powered titles, falls back to truncation. // Returns the generated title or empty string. func (h *chatHandler) maybeGenerateTitle(ctx context.Context, sessionID, userMessage string) string { + if h.sessions == nil || h.sessions.store == nil { + return "" + } + sessionUUID, err := uuid.Parse(sessionID) if err != nil { return "" @@ -327,13 +280,16 @@ func (h *chatHandler) maybeGenerateTitle(ctx context.Context, sessionID, userMes return "" } - title := h.generateTitleWithAI(ctx, userMessage) + var title string + if h.agent != nil { + title = h.agent.GenerateTitle(ctx, userMessage) + } if title == "" { title = truncateForTitle(userMessage) } if err := h.sessions.store.UpdateSessionTitle(ctx, sessionUUID, title); err != nil { - h.logger.Error("failed to update session title", "error", err, "session_id", sessionID) + h.logger.Error("updating session title", "error", err, "session_id", sessionID) return "" } @@ -341,42 +297,6 @@ func (h *chatHandler) maybeGenerateTitle(ctx context.Context, sessionID, userMes return title } -// generateTitleWithAI uses Genkit to generate a session title. -func (h *chatHandler) generateTitleWithAI(ctx context.Context, userMessage string) string { - if h.genkit == nil { - return "" - } - - ctx, cancel := context.WithTimeout(ctx, titleGenerationTimeout) - defer cancel() - - inputRunes := []rune(userMessage) - if len(inputRunes) > titleInputMaxRunes { - userMessage = string(inputRunes[:titleInputMaxRunes]) + "..." - } - - response, err := genkit.Generate(ctx, h.genkit, - ai.WithModelName(h.modelName), - ai.WithPrompt(titlePrompt, userMessage), - ) - if err != nil { - h.logger.Debug("AI title generation failed", "error", err) - return "" - } - - title := strings.TrimSpace(response.Text()) - if title == "" { - return "" - } - - titleRunes := []rune(title) - if len(titleRunes) > titleMaxLength { - title = string(titleRunes[:titleMaxLength-3]) + "..." - } - - return title -} - // truncateForTitle truncates a message to create a fallback session title. func truncateForTitle(message string) string { message = strings.TrimSpace(message) @@ -402,16 +322,15 @@ func (h *chatHandler) logContextDone(ctx context.Context, msgID string) { } } -// jsonToolEmitter implements tools.ToolEventEmitter for JSON SSE events. +// jsonToolEmitter implements tools.Emitter for JSON SSE events. type jsonToolEmitter struct { w http.ResponseWriter - f http.Flusher msgID string } func (e *jsonToolEmitter) OnToolStart(name string) { display := getToolDisplay(name) - _ = sseEvent(e.w, e.f, "tool_start", map[string]string{ + _ = sseEvent(e.w, "tool_start", map[string]string{ // best-effort "msgId": e.msgID, "tool": name, "message": display.StartMsg, @@ -420,7 +339,7 @@ func (e *jsonToolEmitter) OnToolStart(name string) { func (e *jsonToolEmitter) OnToolComplete(name string) { display := getToolDisplay(name) - _ = sseEvent(e.w, e.f, "tool_complete", map[string]string{ + _ = sseEvent(e.w, "tool_complete", map[string]string{ // best-effort "msgId": e.msgID, "tool": name, "message": display.CompleteMsg, @@ -429,7 +348,7 @@ func (e *jsonToolEmitter) OnToolComplete(name string) { func (e *jsonToolEmitter) OnToolError(name string) { display := getToolDisplay(name) - _ = sseEvent(e.w, e.f, "tool_error", map[string]string{ + _ = sseEvent(e.w, "tool_error", map[string]string{ // best-effort "msgId": e.msgID, "tool": name, "message": display.ErrorMsg, @@ -437,4 +356,4 @@ func (e *jsonToolEmitter) OnToolError(name string) { } // Compile-time interface verification. -var _ tools.ToolEventEmitter = (*jsonToolEmitter)(nil) +var _ tools.Emitter = (*jsonToolEmitter)(nil) diff --git a/internal/api/chat_test.go b/internal/api/chat_test.go index 4b708db..e08af8a 100644 --- a/internal/api/chat_test.go +++ b/internal/api/chat_test.go @@ -9,10 +9,13 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" + "github.com/firebase/genkit/go/genkit" "github.com/google/uuid" - "github.com/koopa0/koopa/internal/agent" + + "github.com/koopa0/koopa/internal/chat" ) func newTestChatHandler() *chatHandler { @@ -22,8 +25,6 @@ func newTestChatHandler() *chatHandler { } func TestChatSend_URLEncoding(t *testing.T) { - ch := newTestChatHandler() - sessionID := uuid.New() content := "你好 world & foo=bar#hash?query" @@ -34,8 +35,10 @@ func TestChatSend_URLEncoding(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) + ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) + r = r.WithContext(ctx) - ch.send(w, r) + newTestChatHandler().send(w, r) if w.Code != http.StatusOK { t.Fatalf("send() status = %d, want %d", w.Code, http.StatusOK) @@ -67,8 +70,6 @@ func TestChatSend_URLEncoding(t *testing.T) { } func TestChatSend_SessionIDFromBody(t *testing.T) { - ch := newTestChatHandler() - sessionID := uuid.New() body, _ := json.Marshal(map[string]string{ "content": "hello", @@ -77,8 +78,10 @@ func TestChatSend_SessionIDFromBody(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) + ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) + r = r.WithContext(ctx) - ch.send(w, r) + newTestChatHandler().send(w, r) if w.Code != http.StatusOK { t.Fatalf("send() status = %d, want %d", w.Code, http.StatusOK) @@ -97,8 +100,6 @@ func TestChatSend_SessionIDFromBody(t *testing.T) { } func TestChatSend_SessionIDFromContext(t *testing.T) { - ch := newTestChatHandler() - sessionID := uuid.New() body, _ := json.Marshal(map[string]string{ "content": "hello", @@ -112,7 +113,7 @@ func TestChatSend_SessionIDFromContext(t *testing.T) { ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) r = r.WithContext(ctx) - ch.send(w, r) + newTestChatHandler().send(w, r) if w.Code != http.StatusOK { t.Fatalf("send() status = %d, want %d", w.Code, http.StatusOK) @@ -127,8 +128,6 @@ func TestChatSend_SessionIDFromContext(t *testing.T) { } func TestChatSend_MissingContent(t *testing.T) { - ch := newTestChatHandler() - body, _ := json.Marshal(map[string]string{ "sessionId": uuid.New().String(), }) @@ -136,7 +135,7 @@ func TestChatSend_MissingContent(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) - ch.send(w, r) + newTestChatHandler().send(w, r) if w.Code != http.StatusBadRequest { t.Fatalf("send(no content) status = %d, want %d", w.Code, http.StatusBadRequest) @@ -150,8 +149,6 @@ func TestChatSend_MissingContent(t *testing.T) { } func TestChatSend_EmptyContent(t *testing.T) { - ch := newTestChatHandler() - body, _ := json.Marshal(map[string]string{ "content": " ", "sessionId": uuid.New().String(), @@ -160,7 +157,7 @@ func TestChatSend_EmptyContent(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) - ch.send(w, r) + newTestChatHandler().send(w, r) if w.Code != http.StatusBadRequest { t.Fatalf("send(whitespace) status = %d, want %d", w.Code, http.StatusBadRequest) @@ -168,8 +165,6 @@ func TestChatSend_EmptyContent(t *testing.T) { } func TestChatSend_InvalidSessionID(t *testing.T) { - ch := newTestChatHandler() - body, _ := json.Marshal(map[string]string{ "content": "hello", "sessionId": "not-a-uuid", @@ -177,8 +172,11 @@ func TestChatSend_InvalidSessionID(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) + // Inject context session so we reach the body parse check + ctx := context.WithValue(r.Context(), ctxKeySessionID, uuid.New()) + r = r.WithContext(ctx) - ch.send(w, r) + newTestChatHandler().send(w, r) if w.Code != http.StatusBadRequest { t.Fatalf("send(bad uuid) status = %d, want %d", w.Code, http.StatusBadRequest) @@ -192,8 +190,6 @@ func TestChatSend_InvalidSessionID(t *testing.T) { } func TestChatSend_NoSession(t *testing.T) { - ch := newTestChatHandler() - body, _ := json.Marshal(map[string]string{ "content": "hello", // No sessionId in body, no session in context @@ -202,7 +198,7 @@ func TestChatSend_NoSession(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) - ch.send(w, r) + newTestChatHandler().send(w, r) if w.Code != http.StatusBadRequest { t.Fatalf("send(no session) status = %d, want %d", w.Code, http.StatusBadRequest) @@ -216,12 +212,10 @@ func TestChatSend_NoSession(t *testing.T) { } func TestChatSend_InvalidJSON(t *testing.T) { - ch := newTestChatHandler() - w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader([]byte("not json"))) - ch.send(w, r) + newTestChatHandler().send(w, r) if w.Code != http.StatusBadRequest { t.Fatalf("send(invalid json) status = %d, want %d", w.Code, http.StatusBadRequest) @@ -235,11 +229,11 @@ func TestTruncateForTitle(t *testing.T) { wantMax int wantDots bool }{ - {"short", "Hello world", 50, false}, - {"exact_50", "12345678901234567890123456789012345678901234567890", 50, false}, - {"long", "This is a very long message that exceeds the maximum allowed title length of fifty characters", 53, true}, - {"empty", "", 0, false}, - {"whitespace", " hello ", 50, false}, + {name: "short", input: "Hello world", wantMax: 50, wantDots: false}, + {name: "exact_50", input: "12345678901234567890123456789012345678901234567890", wantMax: 50, wantDots: false}, + {name: "long", input: "This is a very long message that exceeds the maximum allowed title length of fifty characters", wantMax: 53, wantDots: true}, + {name: "empty", input: "", wantMax: 0, wantDots: false}, + {name: "whitespace", input: " hello ", wantMax: 50, wantDots: false}, } for _, tt := range tests { @@ -280,16 +274,175 @@ func TestGetToolDisplay(t *testing.T) { } } +func TestStream_MissingParams(t *testing.T) { + ch := newTestChatHandler() + + tests := []struct { + name string + query string + }{ + {name: "missing all", query: ""}, + {name: "missing session_id and query", query: "?msgId=abc"}, + {name: "missing msgId and query", query: "?session_id=abc"}, + {name: "missing query", query: "?msgId=abc&session_id=def"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream"+tt.query, nil) + + ch.stream(w, r) + + if w.Code != http.StatusBadRequest { + t.Fatalf("stream(%s) status = %d, want %d", tt.name, w.Code, http.StatusBadRequest) + } + + errResp := decodeErrorEnvelope(t, w) + if errResp.Code != "missing_params" { + t.Errorf("stream(%s) code = %q, want %q", tt.name, errResp.Code, "missing_params") + } + }) + } +} + +func TestStream_SSEHeaders(t *testing.T) { + ch := newTestChatHandler() // flow is nil → error event (headers still set) + sessionID := uuid.New() + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hi", nil) + ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) + r = r.WithContext(ctx) + + ch.stream(w, r) + + wantHeaders := map[string]string{ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + } + + for header, want := range wantHeaders { + if got := w.Header().Get(header); got != want { + t.Errorf("stream() header %q = %q, want %q", header, got, want) + } + } +} + +func TestStream_NilFlow(t *testing.T) { + ch := newTestChatHandler() // flow is nil → error SSE event + sessionID := uuid.New() + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hello", nil) + ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) + r = r.WithContext(ctx) + + ch.stream(w, r) + + body := w.Body.String() + + // Should emit an error event, not chunk or done events + if !strings.Contains(body, "event: error\n") { + t.Error("stream(nil flow) expected error event in SSE output") + } + if !strings.Contains(body, "chat flow not initialized") { + t.Error("stream(nil flow) expected 'chat flow not initialized' in error event") + } + if strings.Contains(body, "event: chunk\n") { + t.Error("stream(nil flow) should not emit chunk events") + } + if strings.Contains(body, "event: done\n") { + t.Error("stream(nil flow) should not emit done events") + } +} + +func TestSSEEvent_Format(t *testing.T) { + w := httptest.NewRecorder() + + data := map[string]string{"msgId": "abc", "text": "hello"} + err := sseEvent(w, "chunk", data) + + if err != nil { + t.Fatalf("sseEvent() error: %v", err) + } + + body := w.Body.String() + + // Verify SSE format: "event: \ndata: \n\n" + if !strings.HasPrefix(body, "event: chunk\ndata: ") { + t.Errorf("sseEvent() format = %q, want prefix %q", body, "event: chunk\ndata: ") + } + + if !strings.HasSuffix(body, "\n\n") { + t.Errorf("sseEvent() should end with double newline, got %q", body) + } + + // Verify JSON payload is valid + dataLine := strings.TrimPrefix(body, "event: chunk\ndata: ") + dataLine = strings.TrimSuffix(dataLine, "\n\n") + + var decoded map[string]string + if err := json.Unmarshal([]byte(dataLine), &decoded); err != nil { + t.Fatalf("sseEvent() data is not valid JSON: %v", err) + } + + if decoded["msgId"] != "abc" { + t.Errorf("sseEvent() data.msgId = %q, want %q", decoded["msgId"], "abc") + } + if decoded["text"] != "hello" { + t.Errorf("sseEvent() data.text = %q, want %q", decoded["text"], "hello") + } +} + +func TestSSEEvent_MarshalError(t *testing.T) { + w := httptest.NewRecorder() + + // Channels cannot be marshaled to JSON + err := sseEvent(w, "chunk", make(chan int)) + + if err == nil { + t.Fatal("sseEvent(unmarshalable) expected error, got nil") + } +} + +func TestStream_NilFlow_ContextCanceled(t *testing.T) { + ch := newTestChatHandler() // flow is nil + sessionID := uuid.New() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + ctx = context.WithValue(ctx, ctxKeySessionID, sessionID) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hi", nil) + r = r.WithContext(ctx) + + ch.stream(w, r) + + body := w.Body.String() + + // Flow is nil, so error event is emitted regardless of context state. + if !strings.Contains(body, "event: error\n") { + t.Error("stream(nil flow, canceled) expected error event") + } + if strings.Contains(body, "event: done\n") { + t.Error("stream(nil flow, canceled) should not emit done event") + } +} + func TestClassifyError(t *testing.T) { tests := []struct { name string err error wantCode string }{ - {"invalid_session", agent.ErrInvalidSession, "invalid_session"}, - {"execution_failed", agent.ErrExecutionFailed, "execution_failed"}, - {"deadline_exceeded", context.DeadlineExceeded, "timeout"}, - {"generic_error", errors.New("something went wrong"), "flow_error"}, + {name: "invalid_session", err: chat.ErrInvalidSession, wantCode: "invalid_session"}, + {name: "execution_failed", err: chat.ErrExecutionFailed, wantCode: "execution_failed"}, + {name: "deadline_exceeded", err: context.DeadlineExceeded, wantCode: "timeout"}, + {name: "generic_error", err: errors.New("something went wrong"), wantCode: "flow_error"}, } for _, tt := range tests { @@ -304,3 +457,231 @@ func TestClassifyError(t *testing.T) { }) } } + +func TestChatSend_SessionMismatch(t *testing.T) { + body, _ := json.Marshal(map[string]string{ + "content": "hello", + "sessionId": uuid.New().String(), // Different from context + }) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) + ctx := context.WithValue(r.Context(), ctxKeySessionID, uuid.New()) + r = r.WithContext(ctx) + + newTestChatHandler().send(w, r) + + if w.Code != http.StatusForbidden { + t.Fatalf("send(mismatched session) status = %d, want %d", w.Code, http.StatusForbidden) + } + + errResp := decodeErrorEnvelope(t, w) + if errResp.Code != "forbidden" { + t.Errorf("send(mismatched session) code = %q, want %q", errResp.Code, "forbidden") + } +} + +func TestStream_OwnershipDenied(t *testing.T) { + ch := newTestChatHandler() + sessionID := uuid.New() + otherID := uuid.New() + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hi", nil) + ctx := context.WithValue(r.Context(), ctxKeySessionID, otherID) + r = r.WithContext(ctx) + + ch.stream(w, r) + + if w.Code != http.StatusForbidden { + t.Fatalf("stream(wrong session) status = %d, want %d", w.Code, http.StatusForbidden) + } +} + +func TestStream_NoSession(t *testing.T) { + ch := newTestChatHandler() + sessionID := uuid.New() + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hi", nil) + // No session in context + + ch.stream(w, r) + + if w.Code != http.StatusForbidden { + t.Fatalf("stream(no session) status = %d, want %d", w.Code, http.StatusForbidden) + } +} + +// sseTestEvent represents a parsed SSE event for test assertions. +type sseTestEvent struct { + Type string + Data map[string]string +} + +// parseSSEEvents parses an SSE response body into structured events. +func parseSSEEvents(t *testing.T, body string) []sseTestEvent { + t.Helper() + var events []sseTestEvent + for _, block := range strings.Split(body, "\n\n") { + block = strings.TrimSpace(block) + if block == "" { + continue + } + var ev sseTestEvent + for _, line := range strings.Split(block, "\n") { + switch { + case strings.HasPrefix(line, "event: "): + ev.Type = strings.TrimPrefix(line, "event: ") + case strings.HasPrefix(line, "data: "): + raw := strings.TrimPrefix(line, "data: ") + ev.Data = make(map[string]string) + if err := json.Unmarshal([]byte(raw), &ev.Data); err != nil { + t.Fatalf("parseSSEEvents: invalid JSON in data line %q: %v", raw, err) + } + } + } + if ev.Type != "" { + events = append(events, ev) + } + } + return events +} + +// filterSSEEvents returns events matching the given type. +func filterSSEEvents(events []sseTestEvent, eventType string) []sseTestEvent { + var filtered []sseTestEvent + for _, e := range events { + if e.Type == eventType { + filtered = append(filtered, e) + } + } + return filtered +} + +func TestStreamWithFlow(t *testing.T) { + sessionID := uuid.New() + sessionIDStr := sessionID.String() + + tests := []struct { + name string + flowFn func(context.Context, chat.Input, func(context.Context, chat.StreamChunk) error) (chat.Output, error) + wantChunks []string // expected chunk text values in order + wantDone map[string]string // expected fields in done event (nil = no done expected) + wantError map[string]string // expected fields in error event (nil = no error expected) + }{ + { + name: "success with chunks", + flowFn: func(ctx context.Context, input chat.Input, stream func(context.Context, chat.StreamChunk) error) (chat.Output, error) { + if stream != nil { + if err := stream(ctx, chat.StreamChunk{Text: "Hello "}); err != nil { + return chat.Output{}, err + } + if err := stream(ctx, chat.StreamChunk{Text: "World"}); err != nil { + return chat.Output{}, err + } + } + return chat.Output{Response: "Hello World", SessionID: input.SessionID}, nil + }, + wantChunks: []string{"Hello ", "World"}, + wantDone: map[string]string{"response": "Hello World", "sessionId": sessionIDStr}, + }, + { + name: "flow error without chunks", + flowFn: func(_ context.Context, _ chat.Input, _ func(context.Context, chat.StreamChunk) error) (chat.Output, error) { + return chat.Output{}, chat.ErrInvalidSession + }, + wantError: map[string]string{"code": "invalid_session"}, + }, + { + name: "partial chunks then error", + flowFn: func(ctx context.Context, _ chat.Input, stream func(context.Context, chat.StreamChunk) error) (chat.Output, error) { + if stream != nil { + if err := stream(ctx, chat.StreamChunk{Text: "partial"}); err != nil { + return chat.Output{}, err + } + } + return chat.Output{}, chat.ErrExecutionFailed + }, + wantChunks: []string{"partial"}, + wantError: map[string]string{"code": "execution_failed"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + g := genkit.Init(ctx) + flowName := "test/" + strings.ReplaceAll(tt.name, " ", "_") + testFlow := genkit.DefineStreamingFlow(g, flowName, tt.flowFn) + + ch := &chatHandler{ + logger: slog.New(slog.DiscardHandler), + flow: testFlow, + } + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, + "/api/v1/chat/stream?msgId=m1&session_id="+sessionIDStr+"&query=test", nil) + rctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) + r = r.WithContext(rctx) + + ch.stream(w, r) + + if ct := w.Header().Get("Content-Type"); ct != "text/event-stream" { + t.Fatalf("stream(%s) Content-Type = %q, want %q", tt.name, ct, "text/event-stream") + } + + events := parseSSEEvents(t, w.Body.String()) + + // Verify chunk events + chunks := filterSSEEvents(events, "chunk") + if len(chunks) != len(tt.wantChunks) { + t.Fatalf("stream(%s) got %d chunk events, want %d", tt.name, len(chunks), len(tt.wantChunks)) + } + for i, wantText := range tt.wantChunks { + if got := chunks[i].Data["text"]; got != wantText { + t.Errorf("stream(%s) chunk[%d].text = %q, want %q", tt.name, i, got, wantText) + } + if got := chunks[i].Data["msgId"]; got != "m1" { + t.Errorf("stream(%s) chunk[%d].msgId = %q, want %q", tt.name, i, got, "m1") + } + } + + // Verify done event + doneEvents := filterSSEEvents(events, "done") + if tt.wantDone != nil { + if len(doneEvents) != 1 { + t.Fatalf("stream(%s) got %d done events, want 1", tt.name, len(doneEvents)) + } + for k, want := range tt.wantDone { + if got := doneEvents[0].Data[k]; got != want { + t.Errorf("stream(%s) done.%s = %q, want %q", tt.name, k, got, want) + } + } + if got := doneEvents[0].Data["msgId"]; got != "m1" { + t.Errorf("stream(%s) done.msgId = %q, want %q", tt.name, got, "m1") + } + } else if len(doneEvents) != 0 { + t.Errorf("stream(%s) got %d done events, want 0", tt.name, len(doneEvents)) + } + + // Verify error event + errorEvents := filterSSEEvents(events, "error") + if tt.wantError != nil { + if len(errorEvents) != 1 { + t.Fatalf("stream(%s) got %d error events, want 1", tt.name, len(errorEvents)) + } + for k, want := range tt.wantError { + if got := errorEvents[0].Data[k]; got != want { + t.Errorf("stream(%s) error.%s = %q, want %q", tt.name, k, got, want) + } + } + } else if len(errorEvents) != 0 { + t.Errorf("stream(%s) got %d error events, want 0", tt.name, len(errorEvents)) + } + }) + } +} diff --git a/internal/api/doc.go b/internal/api/doc.go new file mode 100644 index 0000000..b6ae844 --- /dev/null +++ b/internal/api/doc.go @@ -0,0 +1,79 @@ +// Package api provides the JSON REST API server for Koopa. +// +// # Architecture +// +// The API server uses Go 1.22+ routing with a layered middleware stack: +// +// Recovery → Logging → RateLimit → CORS → Session → CSRF → Routes +// +// Health probes (/health, /ready) bypass the middleware stack via a +// top-level mux, ensuring they remain fast and unauthenticated. +// +// # Endpoints +// +// Health probes (no middleware): +// - GET /health — returns {"status":"ok"} +// - GET /ready — returns {"status":"ok"} +// +// CSRF provisioning: +// - GET /api/v1/csrf-token — returns pre-session or session-bound token +// +// Session CRUD (ownership-enforced): +// - POST /api/v1/sessions — create new session +// - GET /api/v1/sessions — list caller's sessions +// - GET /api/v1/sessions/{id} — get session by ID +// - GET /api/v1/sessions/{id}/messages — get session messages +// - DELETE /api/v1/sessions/{id} — delete session +// +// Chat (ownership-enforced): +// - POST /api/v1/chat — initiate chat, returns stream URL +// - GET /api/v1/chat/stream — SSE endpoint for streaming responses +// +// # CSRF Token Model +// +// Two token types prevent cross-site request forgery: +// +// - Pre-session tokens ("pre:nonce:timestamp:signature"): issued before +// a session exists, valid for the first POST /sessions call. +// +// - Session-bound tokens ("timestamp:signature"): bound to a specific +// session via HMAC-SHA256, verified with constant-time comparison. +// +// Both expire after 24 hours with 5 minutes of clock skew tolerance. +// +// # Session Ownership +// +// All session-accessing endpoints verify that the requested resource +// matches the caller's session cookie. This prevents session enumeration +// and cross-session data access. +// +// # Error Handling +// +// All responses use an envelope format: +// +// Success: {"data": } +// Error: {"error": {"code": "...", "message": "..."}} +// +// Tool errors during chat are sent as SSE events (event: error), +// not HTTP error responses, since SSE headers are already committed. +// +// # SSE Streaming +// +// Chat responses stream via Server-Sent Events with typed events: +// +// - chunk: incremental text content +// - tool_start: tool execution began +// - tool_complete: tool execution succeeded +// - tool_error: tool execution failed +// - done: final response with session metadata +// - error: flow-level error +// +// # Security +// +// The middleware stack enforces: +// - CSRF protection for state-changing requests +// - Per-IP rate limiting (token bucket, 60 req/min burst) +// - CORS with explicit origin allowlist +// - Security headers (CSP, HSTS, X-Frame-Options, etc.) +// - HttpOnly, Secure, SameSite=Lax session cookies +package api diff --git a/internal/api/e2e_test.go b/internal/api/e2e_test.go new file mode 100644 index 0000000..db2473a --- /dev/null +++ b/internal/api/e2e_test.go @@ -0,0 +1,470 @@ +//go:build integration + +package api + +import ( + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/koopa0/koopa/internal/session" + "github.com/koopa0/koopa/internal/sqlc" + "github.com/koopa0/koopa/internal/testutil" +) + +// e2eServer creates a full Server with all middleware backed by a real PostgreSQL database. +// Returns the server handler and cleanup function. +func e2eServer(t *testing.T) http.Handler { + t.Helper() + + db := testutil.SetupTestDB(t) + + store := session.New(sqlc.New(db.Pool), db.Pool, slog.New(slog.DiscardHandler)) + + srv, err := NewServer(ServerConfig{ + Logger: slog.New(slog.DiscardHandler), + SessionStore: store, + CSRFSecret: []byte("e2e-test-secret-at-least-32-characters!!"), + CORSOrigins: []string{"http://localhost:4200"}, + IsDev: true, + }) + if err != nil { + t.Fatalf("NewServer() error: %v", err) + } + + return srv.Handler() +} + +// e2eCookies extracts cookies from a response for use in subsequent requests. +// When multiple Set-Cookie headers share the same name, only the last one is kept +// (matching browser behavior where later cookies overwrite earlier ones). +func e2eCookies(t *testing.T, w *httptest.ResponseRecorder) []*http.Cookie { + t.Helper() + all := w.Result().Cookies() + seen := make(map[string]int, len(all)) + var deduped []*http.Cookie + for _, c := range all { + if idx, ok := seen[c.Name]; ok { + deduped[idx] = c // overwrite with later cookie + } else { + seen[c.Name] = len(deduped) + deduped = append(deduped, c) + } + } + return deduped +} + +// e2eAddCookies adds cookies to a request. +func e2eAddCookies(r *http.Request, cookies []*http.Cookie) { + for _, c := range cookies { + r.AddCookie(c) + } +} + +func TestE2E_HealthBypassesMiddleware(t *testing.T) { + handler := e2eServer(t) + + for _, path := range []string{"/health", "/ready"} { + t.Run(path, func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, path, nil) + r.RemoteAddr = "10.0.0.1:12345" + + handler.ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("GET %s status = %d, want %d", path, w.Code, http.StatusOK) + } + + // Health probes are on the top-level mux and bypass the middleware stack + // (including security headers). Verify they return valid JSON. + var body map[string]string + decodeData(t, w, &body) + if body["status"] != "ok" { + t.Errorf("GET %s status = %q, want %q", path, body["status"], "ok") + } + }) + } +} + +func TestE2E_FullSessionLifecycle(t *testing.T) { + handler := e2eServer(t) + + // --- Step 1: GET /api/v1/csrf-token → pre-session token --- + w1 := httptest.NewRecorder() + r1 := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil) + r1.RemoteAddr = "10.0.0.1:12345" + + handler.ServeHTTP(w1, r1) + + if w1.Code != http.StatusOK { + t.Fatalf("step 1: GET /csrf-token status = %d, want %d", w1.Code, http.StatusOK) + } + + var csrfResp map[string]string + decodeData(t, w1, &csrfResp) + + preSessionToken := csrfResp["csrfToken"] + if preSessionToken == "" { + t.Fatal("step 1: expected csrfToken in response") + } + if !strings.HasPrefix(preSessionToken, "pre:") { + t.Fatalf("step 1: token = %q, want pre: prefix", preSessionToken) + } + + // --- Step 2: POST /api/v1/sessions with pre-session CSRF → 201 --- + w2 := httptest.NewRecorder() + r2 := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil) + r2.RemoteAddr = "10.0.0.1:12345" + r2.Header.Set("X-CSRF-Token", preSessionToken) + + handler.ServeHTTP(w2, r2) + + if w2.Code != http.StatusCreated { + t.Fatalf("step 2: POST /sessions status = %d, want %d\nbody: %s", w2.Code, http.StatusCreated, w2.Body.String()) + } + + var createResp map[string]string + decodeData(t, w2, &createResp) + + sessionID := createResp["id"] + sessionCSRF := createResp["csrfToken"] + + if sessionID == "" { + t.Fatal("step 2: expected id in response") + } + if sessionCSRF == "" { + t.Fatal("step 2: expected csrfToken in response") + } + if strings.HasPrefix(sessionCSRF, "pre:") { + t.Fatal("step 2: session-bound token should not have pre: prefix") + } + + // Extract cookies (should have sid cookie) + cookies := e2eCookies(t, w2) + var sidCookie *http.Cookie + for _, c := range cookies { + if c.Name == "sid" { + sidCookie = c + } + } + if sidCookie == nil { + t.Fatal("step 2: expected sid cookie") + } + + // --- Step 3: GET /api/v1/sessions/{id} with cookie → 200 --- + w3 := httptest.NewRecorder() + r3 := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sessionID, nil) + r3.RemoteAddr = "10.0.0.1:12345" + e2eAddCookies(r3, cookies) + + handler.ServeHTTP(w3, r3) + + if w3.Code != http.StatusOK { + t.Fatalf("step 3: GET /sessions/%s status = %d, want %d\nbody: %s", sessionID, w3.Code, http.StatusOK, w3.Body.String()) + } + + var getResp map[string]string + decodeData(t, w3, &getResp) + + if getResp["id"] != sessionID { + t.Errorf("step 3: id = %q, want %q", getResp["id"], sessionID) + } + + // --- Step 4: GET /api/v1/sessions/{id}/messages → 200 (empty) --- + w4 := httptest.NewRecorder() + r4 := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sessionID+"/messages", nil) + r4.RemoteAddr = "10.0.0.1:12345" + e2eAddCookies(r4, cookies) + + handler.ServeHTTP(w4, r4) + + if w4.Code != http.StatusOK { + t.Fatalf("step 4: GET /sessions/%s/messages status = %d, want %d\nbody: %s", sessionID, w4.Code, http.StatusOK, w4.Body.String()) + } + + // --- Step 5: GET /api/v1/sessions with cookie → 200 (list contains session) --- + w5 := httptest.NewRecorder() + r5 := httptest.NewRequest(http.MethodGet, "/api/v1/sessions", nil) + r5.RemoteAddr = "10.0.0.1:12345" + e2eAddCookies(r5, cookies) + + handler.ServeHTTP(w5, r5) + + if w5.Code != http.StatusOK { + t.Fatalf("step 5: GET /sessions status = %d, want %d", w5.Code, http.StatusOK) + } + + // --- Step 6: DELETE /api/v1/sessions/{id} with cookie + CSRF → 200 --- + w6 := httptest.NewRecorder() + r6 := httptest.NewRequest(http.MethodDelete, "/api/v1/sessions/"+sessionID, nil) + r6.RemoteAddr = "10.0.0.1:12345" + r6.Header.Set("X-CSRF-Token", sessionCSRF) + e2eAddCookies(r6, cookies) + + handler.ServeHTTP(w6, r6) + + if w6.Code != http.StatusOK { + t.Fatalf("step 6: DELETE /sessions/%s status = %d, want %d\nbody: %s", sessionID, w6.Code, http.StatusOK, w6.Body.String()) + } + + // --- Step 7: GET deleted session → 404 --- + w7 := httptest.NewRecorder() + r7 := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sessionID, nil) + r7.RemoteAddr = "10.0.0.1:12345" + e2eAddCookies(r7, cookies) + + handler.ServeHTTP(w7, r7) + + if w7.Code != http.StatusNotFound { + t.Fatalf("step 7: GET deleted session status = %d, want %d", w7.Code, http.StatusNotFound) + } +} + +func TestE2E_MissingCSRF_Rejected(t *testing.T) { + handler := e2eServer(t) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil) + r.RemoteAddr = "10.0.0.1:12345" + // No X-CSRF-Token header + + handler.ServeHTTP(w, r) + + if w.Code != http.StatusForbidden { + t.Fatalf("POST /sessions (no CSRF) status = %d, want %d\nbody: %s", w.Code, http.StatusForbidden, w.Body.String()) + } + + errResp := decodeErrorEnvelope(t, w) + if errResp.Code != "csrf_invalid" && errResp.Code != "session_required" { + t.Errorf("POST /sessions (no CSRF) code = %q, want csrf_invalid or session_required", errResp.Code) + } +} + +func TestE2E_InvalidCSRF_Rejected(t *testing.T) { + handler := e2eServer(t) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil) + r.RemoteAddr = "10.0.0.1:12345" + r.Header.Set("X-CSRF-Token", "totally-fake-token") + + handler.ServeHTTP(w, r) + + if w.Code != http.StatusForbidden { + t.Fatalf("POST /sessions (bad CSRF) status = %d, want %d", w.Code, http.StatusForbidden) + } +} + +func TestE2E_CrossSessionAccess_Denied(t *testing.T) { + handler := e2eServer(t) + + // Create session A + w1 := httptest.NewRecorder() + r1 := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil) + r1.RemoteAddr = "10.0.0.1:12345" + handler.ServeHTTP(w1, r1) + + var csrf1 map[string]string + decodeData(t, w1, &csrf1) + + w2 := httptest.NewRecorder() + r2 := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil) + r2.RemoteAddr = "10.0.0.1:12345" + r2.Header.Set("X-CSRF-Token", csrf1["csrfToken"]) + handler.ServeHTTP(w2, r2) + + var sessA map[string]string + decodeData(t, w2, &sessA) + cookiesA := e2eCookies(t, w2) + + // Create session B (different "client") + w3 := httptest.NewRecorder() + r3 := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil) + r3.RemoteAddr = "10.0.0.2:12345" + handler.ServeHTTP(w3, r3) + + var csrf2 map[string]string + decodeData(t, w3, &csrf2) + + w4 := httptest.NewRecorder() + r4 := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil) + r4.RemoteAddr = "10.0.0.2:12345" + r4.Header.Set("X-CSRF-Token", csrf2["csrfToken"]) + handler.ServeHTTP(w4, r4) + + var sessB map[string]string + decodeData(t, w4, &sessB) + + // Client A tries to access session B → 403 + w5 := httptest.NewRecorder() + r5 := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sessB["id"], nil) + r5.RemoteAddr = "10.0.0.1:12345" + e2eAddCookies(r5, cookiesA) // Cookie has session A + + handler.ServeHTTP(w5, r5) + + if w5.Code != http.StatusForbidden { + t.Fatalf("cross-session GET status = %d, want %d", w5.Code, http.StatusForbidden) + } + + errResp := decodeErrorEnvelope(t, w5) + if errResp.Code != "forbidden" { + t.Errorf("cross-session GET code = %q, want %q", errResp.Code, "forbidden") + } +} + +func TestE2E_RateLimiting(t *testing.T) { + handler := e2eServer(t) + + // Rate limiter is configured at 1 token/sec, burst 60. + // /health bypasses the middleware stack, so use an API path that goes through rate limiting. + var lastCode int + for i := range 65 { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil) + r.RemoteAddr = "10.99.99.99:12345" // Unique IP so no interference + + handler.ServeHTTP(w, r) + lastCode = w.Code + + if w.Code == http.StatusTooManyRequests { + // Verify Retry-After header + if got := w.Header().Get("Retry-After"); got != "1" { + t.Errorf("rate limited Retry-After = %q, want %q", got, "1") + } + t.Logf("rate limited at request %d", i+1) + return + } + } + + t.Fatalf("rate limiter: no 429 within 65 requests, last status = %d", lastCode) +} + +func TestE2E_SecurityHeaders(t *testing.T) { + handler := e2eServer(t) + + // Only check API paths — /health and /ready are on the top-level mux + // and intentionally bypass the middleware stack (including security headers). + paths := []struct { + method string + path string + }{ + {http.MethodGet, "/api/v1/csrf-token"}, + {http.MethodGet, "/api/v1/sessions"}, + } + + for _, p := range paths { + t.Run(p.method+" "+p.path, func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(p.method, p.path, nil) + r.RemoteAddr = "10.0.0.1:12345" + + handler.ServeHTTP(w, r) + + wantHeaders := map[string]string{ + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "Referrer-Policy": "strict-origin-when-cross-origin", + "Content-Security-Policy": "default-src 'none'", + } + + for header, want := range wantHeaders { + if got := w.Header().Get(header); got != want { + t.Errorf("%s %s header %q = %q, want %q", p.method, p.path, header, got, want) + } + } + }) + } +} + +func TestE2E_CORSPreflight(t *testing.T) { + handler := e2eServer(t) + + t.Run("allowed origin", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodOptions, "/api/v1/sessions", nil) + r.RemoteAddr = "10.0.0.1:12345" + r.Header.Set("Origin", "http://localhost:4200") + + handler.ServeHTTP(w, r) + + if w.Code != http.StatusNoContent { + t.Fatalf("CORS preflight status = %d, want %d", w.Code, http.StatusNoContent) + } + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://localhost:4200" { + t.Errorf("Allow-Origin = %q, want %q", got, "http://localhost:4200") + } + if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "true" { + t.Errorf("Allow-Credentials = %q, want %q", got, "true") + } + }) + + t.Run("disallowed origin", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodOptions, "/api/v1/sessions", nil) + r.RemoteAddr = "10.0.0.1:12345" + r.Header.Set("Origin", "http://evil.com") + + handler.ServeHTTP(w, r) + + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" { + t.Errorf("disallowed origin Allow-Origin = %q, want empty", got) + } + }) +} + +func TestE2E_SSEStream(t *testing.T) { + handler := e2eServer(t) + + // --- Create a session first (ownership check requires valid cookie) --- + w1 := httptest.NewRecorder() + r1 := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil) + r1.RemoteAddr = "10.0.0.1:12345" + handler.ServeHTTP(w1, r1) + + var csrfResp map[string]string + decodeData(t, w1, &csrfResp) + + w2 := httptest.NewRecorder() + r2 := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil) + r2.RemoteAddr = "10.0.0.1:12345" + r2.Header.Set("X-CSRF-Token", csrfResp["csrfToken"]) + handler.ServeHTTP(w2, r2) + + if w2.Code != http.StatusCreated { + t.Fatalf("create session status = %d, want %d\nbody: %s", w2.Code, http.StatusCreated, w2.Body.String()) + } + + var sessResp map[string]string + decodeData(t, w2, &sessResp) + sessionID := sessResp["id"] + cookies := e2eCookies(t, w2) + + // --- SSE stream with valid session cookie --- + // e2eServer has no ChatFlow configured (nil), so the handler returns + // an error event instead of chunk/done events. + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=e2e-1&session_id="+sessionID+"&query=hello", nil) + r.RemoteAddr = "10.0.0.1:12345" + e2eAddCookies(r, cookies) + + handler.ServeHTTP(w, r) + + // Should get SSE response (200 implicit from streaming) + if ct := w.Header().Get("Content-Type"); ct != "text/event-stream" { + t.Fatalf("SSE Content-Type = %q, want %q", ct, "text/event-stream") + } + + body := w.Body.String() + + // With nil ChatFlow, handler sends error event + if !strings.Contains(body, "event: error\n") { + t.Errorf("SSE response missing error event, body:\n%s", body) + } + if !strings.Contains(body, "chat flow not initialized") { + t.Errorf("SSE error missing expected message, body:\n%s", body) + } +} diff --git a/internal/api/health.go b/internal/api/health.go index dd46070..b9a093f 100644 --- a/internal/api/health.go +++ b/internal/api/health.go @@ -5,5 +5,5 @@ import "net/http" // health is a simple health check endpoint for Docker/Kubernetes probes. // Returns 200 OK with {"status":"ok"}. func health(w http.ResponseWriter, _ *http.Request) { - WriteJSON(w, http.StatusOK, map[string]string{"status": "ok"}) + WriteJSON(w, http.StatusOK, map[string]string{"status": "ok"}, nil) } diff --git a/internal/api/integration_test.go b/internal/api/integration_test.go new file mode 100644 index 0000000..94dedc1 --- /dev/null +++ b/internal/api/integration_test.go @@ -0,0 +1,382 @@ +//go:build integration + +package api + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/google/uuid" + + "github.com/koopa0/koopa/internal/session" + "github.com/koopa0/koopa/internal/sqlc" + "github.com/koopa0/koopa/internal/testutil" +) + +// setupIntegrationSessionManager creates a sessionManager backed by a real PostgreSQL database. +func setupIntegrationSessionManager(t *testing.T) *sessionManager { + t.Helper() + + db := testutil.SetupTestDB(t) + + store := session.New(sqlc.New(db.Pool), db.Pool, slog.New(slog.DiscardHandler)) + + return &sessionManager{ + store: store, + hmacSecret: []byte("test-secret-at-least-32-characters!!"), + isDev: true, + logger: slog.New(slog.DiscardHandler), + } +} + +func TestCreateSession_Success(t *testing.T) { + sm := setupIntegrationSessionManager(t) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil) + + sm.createSession(w, r) + + if w.Code != http.StatusCreated { + t.Fatalf("createSession() status = %d, want %d\nbody: %s", w.Code, http.StatusCreated, w.Body.String()) + } + + var resp map[string]string + decodeData(t, w, &resp) + + // Should return a valid UUID + if _, err := uuid.Parse(resp["id"]); err != nil { + t.Errorf("createSession() id = %q, want valid UUID", resp["id"]) + } + + // Should return a CSRF token bound to the new session + if resp["csrfToken"] == "" { + t.Error("createSession() expected csrfToken in response") + } + + // Should set a session cookie + cookies := w.Result().Cookies() + var found bool + for _, c := range cookies { + if c.Name == "sid" { + found = true + if c.Value != resp["id"] { + t.Errorf("createSession() cookie sid = %q, want %q", c.Value, resp["id"]) + } + if !c.HttpOnly { + t.Error("createSession() cookie should be HttpOnly") + } + } + } + if !found { + t.Error("createSession() expected sid cookie to be set") + } +} + +func TestGetSession_Success(t *testing.T) { + sm := setupIntegrationSessionManager(t) + ctx := context.Background() + + // Create a session first + sess, err := sm.store.CreateSession(ctx, "Test Session") + if err != nil { + t.Fatalf("setup: CreateSession() error: %v", err) + } + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sess.ID.String(), nil) + r.SetPathValue("id", sess.ID.String()) + + // Inject session ownership (same session ID in context) + rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID) + r = r.WithContext(rctx) + + sm.getSession(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("getSession(%s) status = %d, want %d\nbody: %s", sess.ID, w.Code, http.StatusOK, w.Body.String()) + } + + var resp map[string]string + decodeData(t, w, &resp) + + if resp["id"] != sess.ID.String() { + t.Errorf("getSession() id = %q, want %q", resp["id"], sess.ID.String()) + } + if resp["title"] != "Test Session" { + t.Errorf("getSession() title = %q, want %q", resp["title"], "Test Session") + } + if resp["createdAt"] == "" { + t.Error("getSession() expected createdAt in response") + } + if resp["updatedAt"] == "" { + t.Error("getSession() expected updatedAt in response") + } +} + +func TestGetSession_NotFound(t *testing.T) { + sm := setupIntegrationSessionManager(t) + + missingID := uuid.New() + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+missingID.String(), nil) + r.SetPathValue("id", missingID.String()) + + // Set ownership to match (bypasses ownership check, tests store-level 404) + rctx := context.WithValue(r.Context(), ctxKeySessionID, missingID) + r = r.WithContext(rctx) + + sm.getSession(w, r) + + if w.Code != http.StatusNotFound { + t.Fatalf("getSession(missing) status = %d, want %d", w.Code, http.StatusNotFound) + } + + errResp := decodeErrorEnvelope(t, w) + if errResp.Code != "not_found" { + t.Errorf("getSession(missing) code = %q, want %q", errResp.Code, "not_found") + } +} + +func TestListSessions_WithSession(t *testing.T) { + sm := setupIntegrationSessionManager(t) + ctx := context.Background() + + // Create a session + sess, err := sm.store.CreateSession(ctx, "My Chat") + if err != nil { + t.Fatalf("setup: CreateSession() error: %v", err) + } + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions", nil) + + // Inject session ownership + rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID) + r = r.WithContext(rctx) + + sm.listSessions(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("listSessions() status = %d, want %d\nbody: %s", w.Code, http.StatusOK, w.Body.String()) + } + + type sessionItem struct { + ID string `json:"id"` + Title string `json:"title"` + UpdatedAt string `json:"updatedAt"` + } + var items []sessionItem + decodeData(t, w, &items) + + if len(items) != 1 { + t.Fatalf("listSessions() returned %d items, want 1", len(items)) + } + if items[0].ID != sess.ID.String() { + t.Errorf("listSessions() items[0].id = %q, want %q", items[0].ID, sess.ID.String()) + } + if items[0].Title != "My Chat" { + t.Errorf("listSessions() items[0].title = %q, want %q", items[0].Title, "My Chat") + } + if items[0].UpdatedAt == "" { + t.Error("listSessions() expected updatedAt in item") + } +} + +func TestGetSessionMessages_Empty(t *testing.T) { + sm := setupIntegrationSessionManager(t) + ctx := context.Background() + + // Create a session with no messages + sess, err := sm.store.CreateSession(ctx, "") + if err != nil { + t.Fatalf("setup: CreateSession() error: %v", err) + } + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sess.ID.String()+"/messages", nil) + r.SetPathValue("id", sess.ID.String()) + + rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID) + r = r.WithContext(rctx) + + sm.getSessionMessages(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("getSessionMessages(empty) status = %d, want %d\nbody: %s", w.Code, http.StatusOK, w.Body.String()) + } + + // Decode the raw envelope to check the data type + var env struct { + Data json.RawMessage `json:"data"` + } + if err := json.NewDecoder(w.Body).Decode(&env); err != nil { + t.Fatalf("decoding envelope: %v", err) + } + + var items []json.RawMessage + if err := json.Unmarshal(env.Data, &items); err != nil { + t.Fatalf("decoding data as array: %v", err) + } + + if len(items) != 0 { + t.Errorf("getSessionMessages(empty) returned %d items, want 0", len(items)) + } +} + +func TestDeleteSession_Success(t *testing.T) { + sm := setupIntegrationSessionManager(t) + ctx := context.Background() + + // Create a session + sess, err := sm.store.CreateSession(ctx, "To Delete") + if err != nil { + t.Fatalf("setup: CreateSession() error: %v", err) + } + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodDelete, "/api/v1/sessions/"+sess.ID.String(), nil) + r.SetPathValue("id", sess.ID.String()) + + // Inject ownership + rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID) + r = r.WithContext(rctx) + + sm.deleteSession(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("deleteSession() status = %d, want %d\nbody: %s", w.Code, http.StatusOK, w.Body.String()) + } + + var resp map[string]string + decodeData(t, w, &resp) + + if resp["status"] != "deleted" { + t.Errorf("deleteSession() status = %q, want %q", resp["status"], "deleted") + } + + // Verify session is actually gone + _, err = sm.store.Session(ctx, sess.ID) + if err == nil { + t.Error("deleteSession() session still exists after deletion") + } +} + +func TestCSRFTokenEndpoint_WithSession(t *testing.T) { + sm := setupIntegrationSessionManager(t) + ctx := context.Background() + + // Create a real session + sess, err := sm.store.CreateSession(ctx, "") + if err != nil { + t.Fatalf("setup: CreateSession() error: %v", err) + } + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil) + r.AddCookie(&http.Cookie{ + Name: "sid", + Value: sess.ID.String(), + }) + + sm.csrfToken(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("csrfToken(with session) status = %d, want %d", w.Code, http.StatusOK) + } + + var body map[string]string + decodeData(t, w, &body) + + token := body["csrfToken"] + if token == "" { + t.Fatal("csrfToken(with session) expected csrfToken in response") + } + + // Session-bound tokens should NOT have pre: prefix + if isPreSessionToken(token) { + t.Error("csrfToken(with session) should return session-bound token, not pre-session") + } + + // Token should validate against the session + if err := sm.CheckCSRF(sess.ID, token); err != nil { + t.Fatalf("csrfToken(with session) returned invalid token: %v", err) + } +} + +func TestGetSessionMessages_WithMessages(t *testing.T) { + sm := setupIntegrationSessionManager(t) + ctx := context.Background() + + // Create a session with messages + sess, err := sm.store.CreateSession(ctx, "Test Chat") + if err != nil { + t.Fatalf("setup: CreateSession() error: %v", err) + } + + // Add user and model messages + msgs := []*ai.Message{ + ai.NewUserMessage(ai.NewTextPart("What is Go?")), + ai.NewModelMessage(ai.NewTextPart("Go is a programming language.")), + } + if err := sm.store.AppendMessages(ctx, sess.ID, msgs); err != nil { + t.Fatalf("setup: AppendMessages() error: %v", err) + } + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sess.ID.String()+"/messages", nil) + r.SetPathValue("id", sess.ID.String()) + + rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID) + r = r.WithContext(rctx) + + sm.getSessionMessages(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("getSessionMessages() status = %d, want %d\nbody: %s", w.Code, http.StatusOK, w.Body.String()) + } + + type messageItem struct { + ID string `json:"id"` + Role string `json:"role"` + Content string `json:"content"` + CreatedAt string `json:"createdAt"` + } + var items []messageItem + decodeData(t, w, &items) + + if len(items) != 2 { + t.Fatalf("getSessionMessages() returned %d items, want 2", len(items)) + } + + // First message: user + if items[0].Role != "user" { + t.Errorf("getSessionMessages() items[0].role = %q, want %q", items[0].Role, "user") + } + if items[0].Content != "What is Go?" { + t.Errorf("getSessionMessages() items[0].content = %q, want %q", items[0].Content, "What is Go?") + } + if items[0].ID == "" { + t.Error("getSessionMessages() items[0].id is empty") + } + if items[0].CreatedAt == "" { + t.Error("getSessionMessages() items[0].createdAt is empty") + } + + // Second message: model (normalizeRole converts "model" → "assistant" in DB) + if items[1].Role != "assistant" { + t.Errorf("getSessionMessages() items[1].role = %q, want %q", items[1].Role, "assistant") + } + if items[1].Content != "Go is a programming language." { + t.Errorf("getSessionMessages() items[1].content = %q, want %q", items[1].Content, "Go is a programming language.") + } + if items[1].ID == "" { + t.Error("getSessionMessages() items[1].id is empty") + } +} diff --git a/internal/api/middleware.go b/internal/api/middleware.go index a9a02d0..821eaef 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -3,10 +3,8 @@ package api import ( "context" "log/slog" - "net" "net/http" "strings" - "sync" "time" "github.com/google/uuid" @@ -17,9 +15,9 @@ type sessionIDKey struct{} var ctxKeySessionID = sessionIDKey{} -// SessionIDFromContext retrieves the session ID from the request context. +// sessionIDFromContext retrieves the session ID from the request context. // Returns uuid.Nil and false if not found. -func SessionIDFromContext(ctx context.Context) (uuid.UUID, bool) { +func sessionIDFromContext(ctx context.Context) (uuid.UUID, bool) { sessionID, ok := ctx.Value(ctxKeySessionID).(uuid.UUID) return sessionID, ok } @@ -27,46 +25,47 @@ func SessionIDFromContext(ctx context.Context) (uuid.UUID, bool) { // loggingWriter wraps http.ResponseWriter to capture metrics. // Implements Flusher for SSE streaming and Unwrap for ResponseController. type loggingWriter struct { - http.ResponseWriter + w http.ResponseWriter statusCode int bytesWritten int64 } -func (w *loggingWriter) WriteHeader(code int) { - w.statusCode = code - w.ResponseWriter.WriteHeader(code) +func (lw *loggingWriter) Header() http.Header { + return lw.w.Header() +} + +func (lw *loggingWriter) WriteHeader(code int) { + lw.statusCode = code + lw.w.WriteHeader(code) } //nolint:wrapcheck // http.ResponseWriter wrapper must return unwrapped errors -func (w *loggingWriter) Write(b []byte) (int, error) { - if w.statusCode == 0 { - w.statusCode = http.StatusOK +func (lw *loggingWriter) Write(b []byte) (int, error) { + if lw.statusCode == 0 { + lw.statusCode = http.StatusOK } - n, err := w.ResponseWriter.Write(b) - w.bytesWritten += int64(n) + n, err := lw.w.Write(b) + lw.bytesWritten += int64(n) return n, err } // Flush implements http.Flusher for SSE streaming support. -func (w *loggingWriter) Flush() { - if f, ok := w.ResponseWriter.(http.Flusher); ok { +func (lw *loggingWriter) Flush() { + if f, ok := lw.w.(http.Flusher); ok { f.Flush() } } // Unwrap returns the underlying ResponseWriter for http.ResponseController. -func (w *loggingWriter) Unwrap() http.ResponseWriter { - return w.ResponseWriter +func (lw *loggingWriter) Unwrap() http.ResponseWriter { + return lw.w } // recoveryMiddleware recovers from panics to prevent server crashes. func recoveryMiddleware(logger *slog.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - wrapper := &loggingWriter{ - ResponseWriter: w, - statusCode: 0, - } + wrapper := &loggingWriter{w: w} defer func() { if err := recover(); err != nil { @@ -77,7 +76,7 @@ func recoveryMiddleware(logger *slog.Logger) func(http.Handler) http.Handler { ) if wrapper.statusCode == 0 { - WriteError(w, http.StatusInternalServerError, "internal_error", "internal server error") + WriteError(w, http.StatusInternalServerError, "internal_error", "internal server error", logger) } else { logger.Warn("cannot send error response, headers already sent", "path", r.URL.Path, @@ -92,22 +91,29 @@ func recoveryMiddleware(logger *slog.Logger) func(http.Handler) http.Handler { } // loggingMiddleware logs request details including latency, status, and response size. +// Reuses an existing *loggingWriter from outer middleware (e.g., recoveryMiddleware) +// to avoid double-wrapping the ResponseWriter. func loggingMiddleware(logger *slog.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() - wrapper := &loggingWriter{ - ResponseWriter: w, - statusCode: http.StatusOK, + wrapper, ok := w.(*loggingWriter) + if !ok { + wrapper = &loggingWriter{w: w} } next.ServeHTTP(wrapper, r) + status := wrapper.statusCode + if status == 0 { + status = http.StatusOK + } + logger.Debug("http request", "method", r.Method, "path", r.URL.Path, - "status", wrapper.statusCode, + "status", status, "bytes", wrapper.bytesWritten, "duration", time.Since(start), "ip", r.RemoteAddr, @@ -146,39 +152,21 @@ func corsMiddleware(allowedOrigins []string) func(http.Handler) http.Handler { } } -// sessionMiddleware ensures a valid session exists before processing the request. -// GET/HEAD/OPTIONS: read-only, don't create session. -// POST/PUT/DELETE: create session if needed. -func sessionMiddleware(sm *sessionManager, logger *slog.Logger) func(http.Handler) http.Handler { +// sessionMiddleware extracts the session ID from the cookie and adds it to the +// request context. If no valid session cookie is present, the request continues +// without a session ID in context. Individual handlers are responsible for +// creating sessions when needed (e.g., createSession). +func sessionMiddleware(sm *sessionManager) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions { - sessionID, err := sm.ID(r) - if err == nil { - ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) - next.ServeHTTP(w, r.WithContext(ctx)) - return - } - // No session — pre-session state, continue without session ID - next.ServeHTTP(w, r) - return - } - - // State-changing request: create session if needed - sessionID, err := sm.GetOrCreate(w, r) - if err != nil { - logger.Error("session creation failed", - "error", err, - "path", r.URL.Path, - "method", r.Method, - "remote_addr", r.RemoteAddr, - ) - WriteError(w, http.StatusInternalServerError, "session_error", "session creation failed") + sessionID, err := sm.ID(r) + if err == nil { + ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) + next.ServeHTTP(w, r.WithContext(ctx)) return } - - ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) - next.ServeHTTP(w, r.WithContext(ctx)) + // No valid session cookie — continue without session in context + next.ServeHTTP(w, r) }) } } @@ -205,7 +193,7 @@ func csrfMiddleware(sm *sessionManager, logger *slog.Logger) func(http.Handler) "path", r.URL.Path, "method", r.Method, ) - WriteError(w, http.StatusForbidden, "csrf_invalid", "CSRF validation failed") + WriteError(w, http.StatusForbidden, "csrf_invalid", "CSRF validation failed", logger) return } next.ServeHTTP(w, r) @@ -213,24 +201,24 @@ func csrfMiddleware(sm *sessionManager, logger *slog.Logger) func(http.Handler) } // Session-bound token - sessionID, ok := SessionIDFromContext(r.Context()) + sessionID, ok := sessionIDFromContext(r.Context()) if !ok { - logger.Error("CSRF validation failed: session ID not in context", + logger.Error("validating CSRF: session ID not in context", "path", r.URL.Path, "method", r.Method, ) - WriteError(w, http.StatusForbidden, "session_required", "session required") + WriteError(w, http.StatusForbidden, "session_required", "session required", logger) return } if err := sm.CheckCSRF(sessionID, csrfToken); err != nil { - logger.Warn("CSRF validation failed", + logger.Warn("validating CSRF", "error", err, "session", sessionID, "path", r.URL.Path, "method", r.Method, ) - WriteError(w, http.StatusForbidden, "csrf_invalid", "CSRF validation failed") + WriteError(w, http.StatusForbidden, "csrf_invalid", "CSRF validation failed", logger) return } @@ -255,140 +243,3 @@ func setSecurityHeaders(w http.ResponseWriter, isDev bool) { func isPreSessionToken(token string) bool { return strings.HasPrefix(token, preSessionPrefix) } - -// ============================================================================ -// Rate Limiting -// ============================================================================ - -const ( - rateLimiterCleanupInterval = 5 * time.Minute - rateLimiterStaleThreshold = 10 * time.Minute -) - -// rateLimiter implements per-IP token bucket rate limiting. -// Cleanup of stale entries happens inline during allow() calls. -type rateLimiter struct { - mu sync.Mutex - visitors map[string]*visitor - rate float64 // tokens per second - burst int // max tokens (also initial tokens) - lastCleanup time.Time -} - -// visitor tracks the token bucket state for a single IP. -type visitor struct { - tokens float64 - lastSeen time.Time -} - -// newRateLimiter creates a rate limiter. -// rate: tokens refilled per second. burst: maximum tokens (and initial allowance). -func newRateLimiter(rate float64, burst int) *rateLimiter { - return &rateLimiter{ - visitors: make(map[string]*visitor), - rate: rate, - burst: burst, - lastCleanup: time.Now(), - } -} - -// allow checks if a request from the given IP is allowed. -// Returns false if the IP has exhausted its tokens. -func (rl *rateLimiter) allow(ip string) bool { - rl.mu.Lock() - defer rl.mu.Unlock() - - now := time.Now() - - // Periodic cleanup of stale entries - if now.Sub(rl.lastCleanup) > rateLimiterCleanupInterval { - for k, v := range rl.visitors { - if now.Sub(v.lastSeen) > rateLimiterStaleThreshold { - delete(rl.visitors, k) - } - } - rl.lastCleanup = now - } - - v, exists := rl.visitors[ip] - if !exists { - rl.visitors[ip] = &visitor{ - tokens: float64(rl.burst) - 1, - lastSeen: now, - } - return true - } - - // Refill tokens based on elapsed time - elapsed := now.Sub(v.lastSeen).Seconds() - v.tokens += elapsed * rl.rate - if v.tokens > float64(rl.burst) { - v.tokens = float64(rl.burst) - } - v.lastSeen = now - - if v.tokens < 1 { - return false - } - - v.tokens-- - return true -} - -// rateLimitMiddleware returns middleware that limits requests per IP. -// Uses token bucket algorithm: each IP gets `burst` initial tokens, -// refilling at `rate` tokens per second. -func rateLimitMiddleware(rl *rateLimiter, trustProxy bool, logger *slog.Logger) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ip := clientIP(r, trustProxy) - if !rl.allow(ip) { - logger.Warn("rate limit exceeded", - "ip", ip, - "path", r.URL.Path, - "method", r.Method, - ) - w.Header().Set("Retry-After", "1") - WriteError(w, http.StatusTooManyRequests, "rate_limited", "too many requests") - return - } - next.ServeHTTP(w, r) - }) - } -} - -// clientIP extracts the client IP from the request. -// -// When trustProxy is true, checks X-Real-IP first (set by nginx/HAProxy), -// then X-Forwarded-For (first IP). Header values are validated with net.ParseIP -// to prevent injection of non-IP strings into rate limiter keys. -// -// When trustProxy is false, only uses RemoteAddr (safe default for direct exposure). -func clientIP(r *http.Request, trustProxy bool) string { - if trustProxy { - // Prefer X-Real-IP (single value, set by reverse proxy) - if xri := r.Header.Get("X-Real-IP"); xri != "" { - if ip := net.ParseIP(strings.TrimSpace(xri)); ip != nil { - return ip.String() - } - } - - // Fall back to X-Forwarded-For (first IP is the client) - if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - raw := xff - if first, _, ok := strings.Cut(xff, ","); ok { - raw = first - } - if ip := net.ParseIP(strings.TrimSpace(raw)); ip != nil { - return ip.String() - } - } - } - - // Fall back to RemoteAddr (strip port) - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return r.RemoteAddr - } - return ip -} diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index c108e77..c27e60f 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -6,7 +6,6 @@ import ( "net/http" "net/http/httptest" "testing" - "time" "github.com/google/uuid" ) @@ -44,7 +43,7 @@ func TestRecoveryMiddleware_NoPanic(t *testing.T) { logger := discardLogger() okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - WriteJSON(w, http.StatusOK, map[string]string{"ok": "true"}) + WriteJSON(w, http.StatusOK, map[string]string{"ok": "true"}, nil) }) handler := recoveryMiddleware(logger)(okHandler) @@ -266,7 +265,7 @@ func TestSecurityHeaders(t *testing.T) { w := httptest.NewRecorder() setSecurityHeaders(w, false) - expected := map[string]string{ + wantHeaders := map[string]string{ "X-Content-Type-Options": "nosniff", "X-Frame-Options": "DENY", "Referrer-Policy": "strict-origin-when-cross-origin", @@ -274,7 +273,7 @@ func TestSecurityHeaders(t *testing.T) { "Strict-Transport-Security": "max-age=63072000; includeSubDomains", } - for header, want := range expected { + for header, want := range wantHeaders { if got := w.Header().Get(header); got != want { t.Errorf("setSecurityHeaders(isDev=false) %q = %q, want %q", header, got, want) } @@ -296,208 +295,133 @@ func TestSecurityHeaders(t *testing.T) { }) } -// ============================================================================ -// Rate Limiting Tests -// ============================================================================ +func Test_sessionIDFromContext(t *testing.T) { + t.Run("present", func(t *testing.T) { + id := uuid.New() + ctx := context.WithValue(context.Background(), ctxKeySessionID, id) -func TestRateLimiter_AllowsWithinBurst(t *testing.T) { - rl := newRateLimiter(1.0, 5) + got, ok := sessionIDFromContext(ctx) + if !ok { + t.Fatal("sessionIDFromContext() ok = false, want true") + } + if got != id { + t.Errorf("sessionIDFromContext() = %s, want %s", got, id) + } + }) - for i := range 5 { - if !rl.allow("1.2.3.4") { - t.Fatalf("allow() returned false on request %d (within burst of 5)", i+1) + t.Run("absent", func(t *testing.T) { + _, ok := sessionIDFromContext(context.Background()) + if ok { + t.Error("sessionIDFromContext(empty) ok = true, want false") } - } + }) } -func TestRateLimiter_BlocksAfterBurst(t *testing.T) { - rl := newRateLimiter(1.0, 3) - - // Exhaust the burst - for range 3 { - rl.allow("1.2.3.4") - } - - if rl.allow("1.2.3.4") { - t.Error("allow() should return false after burst exhausted") +func TestSessionMiddleware_GET_WithCookie(t *testing.T) { + logger := discardLogger() + sm := &sessionManager{ + hmacSecret: []byte("test-secret-at-least-32-characters!!"), + logger: logger, } -} - -func TestRateLimiter_SeparateIPs(t *testing.T) { - rl := newRateLimiter(1.0, 2) - // Exhaust IP 1 - rl.allow("1.1.1.1") - rl.allow("1.1.1.1") + sessionID := uuid.New() - // IP 2 should still be allowed - if !rl.allow("2.2.2.2") { - t.Error("allow() should allow a different IP") - } -} + var gotID uuid.UUID + var gotOK bool + handler := sessionMiddleware(sm)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + gotID, gotOK = sessionIDFromContext(r.Context()) + })) -func TestRateLimiter_RefillsOverTime(t *testing.T) { - rl := newRateLimiter(100.0, 1) // 100 tokens/sec so we can test quickly + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions", nil) + r.AddCookie(&http.Cookie{Name: "sid", Value: sessionID.String()}) - // Use the single token - rl.allow("1.2.3.4") + handler.ServeHTTP(w, r) - if rl.allow("1.2.3.4") { - t.Error("allow() should be blocked immediately after burst exhausted") + if !gotOK { + t.Fatal("sessionMiddleware(GET, valid cookie) expected session ID in context") } - - // Wait enough time for a token to refill - time.Sleep(20 * time.Millisecond) - - if !rl.allow("1.2.3.4") { - t.Error("allow() should be allowed after token refill") + if gotID != sessionID { + t.Errorf("sessionMiddleware(GET, valid cookie) session ID = %s, want %s", gotID, sessionID) } } -func TestRateLimitMiddleware_Returns429(t *testing.T) { - rl := newRateLimiter(0.001, 1) // Very low rate +func TestSessionMiddleware_GET_WithoutCookie(t *testing.T) { logger := discardLogger() + sm := &sessionManager{ + hmacSecret: []byte("test-secret-at-least-32-characters!!"), + logger: logger, + } - handler := rateLimitMiddleware(rl, false, logger)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusOK) + var gotOK bool + called := false + handler := sessionMiddleware(sm)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + called = true + _, gotOK = sessionIDFromContext(r.Context()) })) - // First request should succeed w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.RemoteAddr = "10.0.0.1:12345" - handler.ServeHTTP(w, r) - - if w.Code != http.StatusOK { - t.Fatalf("first request status = %d, want %d", w.Code, http.StatusOK) - } + r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions", nil) - // Second request should be rate limited - w = httptest.NewRecorder() - r = httptest.NewRequest(http.MethodGet, "/", nil) - r.RemoteAddr = "10.0.0.1:12345" handler.ServeHTTP(w, r) - if w.Code != http.StatusTooManyRequests { - t.Fatalf("rate limited request status = %d, want %d", w.Code, http.StatusTooManyRequests) + if !called { + t.Fatal("sessionMiddleware(GET, no cookie) did not call next handler") } - - if got := w.Header().Get("Retry-After"); got != "1" { - t.Errorf("Retry-After = %q, want %q", got, "1") + if gotOK { + t.Error("sessionMiddleware(GET, no cookie) should not have session ID in context") } } -func TestClientIP(t *testing.T) { - tests := []struct { - name string - trustProxy bool - remoteAddr string - xff string - xri string - want string - }{ - { - name: "remote addr with port", - trustProxy: true, - remoteAddr: "10.0.0.1:12345", - want: "10.0.0.1", - }, - { - name: "X-Forwarded-For single when trusted", - trustProxy: true, - remoteAddr: "127.0.0.1:80", - xff: "203.0.113.50", - want: "203.0.113.50", - }, - { - name: "X-Forwarded-For multiple when trusted", - trustProxy: true, - remoteAddr: "127.0.0.1:80", - xff: "203.0.113.50, 70.41.3.18, 150.172.238.178", - want: "203.0.113.50", - }, - { - name: "X-Real-IP when trusted", - trustProxy: true, - remoteAddr: "127.0.0.1:80", - xri: "203.0.113.50", - want: "203.0.113.50", - }, - { - name: "X-Real-IP takes precedence over X-Forwarded-For when trusted", - trustProxy: true, - remoteAddr: "127.0.0.1:80", - xff: "203.0.113.50", - xri: "198.51.100.1", - want: "198.51.100.1", - }, - { - name: "untrusted ignores X-Forwarded-For", - trustProxy: false, - remoteAddr: "10.0.0.1:12345", - xff: "203.0.113.50", - want: "10.0.0.1", - }, - { - name: "untrusted ignores X-Real-IP", - trustProxy: false, - remoteAddr: "10.0.0.1:12345", - xri: "203.0.113.50", - want: "10.0.0.1", - }, - { - name: "invalid X-Real-IP falls through to XFF", - trustProxy: true, - remoteAddr: "127.0.0.1:80", - xri: "not-an-ip", - xff: "203.0.113.50", - want: "203.0.113.50", - }, - { - name: "invalid XFF falls through to RemoteAddr", - trustProxy: true, - remoteAddr: "127.0.0.1:80", - xff: "not-an-ip", - want: "127.0.0.1", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.RemoteAddr = tt.remoteAddr - if tt.xff != "" { - r.Header.Set("X-Forwarded-For", tt.xff) - } - if tt.xri != "" { - r.Header.Set("X-Real-IP", tt.xri) - } +func TestSessionMiddleware_GET_InvalidCookie(t *testing.T) { + logger := discardLogger() + sm := &sessionManager{ + hmacSecret: []byte("test-secret-at-least-32-characters!!"), + logger: logger, + } - if got := clientIP(r, tt.trustProxy); got != tt.want { - t.Errorf("clientIP(r, %v) = %q, want %q", tt.trustProxy, got, tt.want) - } - }) + var gotOK bool + handler := sessionMiddleware(sm)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + _, gotOK = sessionIDFromContext(r.Context()) + })) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions", nil) + r.AddCookie(&http.Cookie{Name: "sid", Value: "not-a-uuid"}) + + handler.ServeHTTP(w, r) + + if gotOK { + t.Error("sessionMiddleware(GET, invalid cookie) should not have session ID in context") } } -func TestSessionIDFromContext(t *testing.T) { - t.Run("present", func(t *testing.T) { - id := uuid.New() - ctx := context.WithValue(context.Background(), ctxKeySessionID, id) +func TestLoggingMiddleware(t *testing.T) { + logger := discardLogger() - got, ok := SessionIDFromContext(ctx) - if !ok { - t.Fatal("expected session ID to be present") - } - if got != id { - t.Errorf("SessionIDFromContext() = %s, want %s", got, id) + called := false + inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called = true + w.WriteHeader(http.StatusCreated) + if _, err := w.Write([]byte("hello")); err != nil { + t.Errorf("Write() error: %v", err) } }) - t.Run("absent", func(t *testing.T) { - _, ok := SessionIDFromContext(context.Background()) - if ok { - t.Error("expected session ID to be absent") - } - }) + handler := loggingMiddleware(logger)(inner) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/test", nil) + + handler.ServeHTTP(w, r) + + if !called { + t.Fatal("loggingMiddleware did not call next handler") + } + if w.Code != http.StatusCreated { + t.Errorf("loggingMiddleware status = %d, want %d", w.Code, http.StatusCreated) + } + if w.Body.String() != "hello" { + t.Errorf("loggingMiddleware body = %q, want %q", w.Body.String(), "hello") + } } diff --git a/internal/api/ratelimit.go b/internal/api/ratelimit.go new file mode 100644 index 0000000..c937e49 --- /dev/null +++ b/internal/api/ratelimit.go @@ -0,0 +1,135 @@ +package api + +import ( + "log/slog" + "net" + "net/http" + "strings" + "sync" + "time" + + "golang.org/x/time/rate" +) + +const ( + rateLimiterCleanupInterval = 5 * time.Minute + rateLimiterStaleThreshold = 10 * time.Minute +) + +// rateLimiter implements per-IP rate limiting using golang.org/x/time/rate. +// Cleanup of stale entries happens inline during allow() calls. +type rateLimiter struct { + mu sync.Mutex + visitors map[string]*visitor + limit rate.Limit + burst int + lastCleanup time.Time +} + +// visitor holds a rate limiter and last-seen time for a single IP. +type visitor struct { + limiter *rate.Limiter + lastSeen time.Time +} + +// newRateLimiter creates a rate limiter. +// r: tokens refilled per second. burst: maximum tokens (and initial allowance). +func newRateLimiter(r float64, burst int) *rateLimiter { + return &rateLimiter{ + visitors: make(map[string]*visitor), + limit: rate.Limit(r), + burst: burst, + lastCleanup: time.Now(), + } +} + +// allow checks if a request from the given IP is allowed. +// Returns false if the IP has exhausted its tokens. +func (rl *rateLimiter) allow(ip string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + + // Periodic cleanup of stale entries + if now.Sub(rl.lastCleanup) > rateLimiterCleanupInterval { + for k, v := range rl.visitors { + if now.Sub(v.lastSeen) > rateLimiterStaleThreshold { + delete(rl.visitors, k) + } + } + rl.lastCleanup = now + } + + v, exists := rl.visitors[ip] + if !exists { + limiter := rate.NewLimiter(rl.limit, rl.burst) + rl.visitors[ip] = &visitor{ + limiter: limiter, + lastSeen: now, + } + limiter.Allow() + return true + } + + v.lastSeen = now + return v.limiter.Allow() +} + +// rateLimitMiddleware returns middleware that limits requests per IP. +// Uses token bucket algorithm: each IP gets `burst` initial tokens, +// refilling at `rate` tokens per second. +func rateLimitMiddleware(rl *rateLimiter, trustProxy bool, logger *slog.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := clientIP(r, trustProxy) + if !rl.allow(ip) { + logger.Warn("rate limit exceeded", + "ip", ip, + "path", r.URL.Path, + "method", r.Method, + ) + w.Header().Set("Retry-After", "1") + WriteError(w, http.StatusTooManyRequests, "rate_limited", "too many requests", logger) + return + } + next.ServeHTTP(w, r) + }) + } +} + +// clientIP extracts the client IP from the request. +// +// When trustProxy is true, checks X-Real-IP first (set by nginx/HAProxy), +// then X-Forwarded-For (first IP). Header values are validated with net.ParseIP +// to prevent injection of non-IP strings into rate limiter keys. +// +// When trustProxy is false, only uses RemoteAddr (safe default for direct exposure). +func clientIP(r *http.Request, trustProxy bool) string { + if trustProxy { + // Prefer X-Real-IP (single value, set by reverse proxy) + if xri := r.Header.Get("X-Real-IP"); xri != "" { + if ip := net.ParseIP(strings.TrimSpace(xri)); ip != nil { + return ip.String() + } + } + + // Fall back to X-Forwarded-For (first IP is the client) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + raw := xff + if first, _, ok := strings.Cut(xff, ","); ok { + raw = first + } + if ip := net.ParseIP(strings.TrimSpace(raw)); ip != nil { + return ip.String() + } + } + } + + // Fall back to RemoteAddr (strip port) + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return ip +} diff --git a/internal/api/ratelimit_test.go b/internal/api/ratelimit_test.go new file mode 100644 index 0000000..ff2d26b --- /dev/null +++ b/internal/api/ratelimit_test.go @@ -0,0 +1,204 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestRateLimiter_AllowsWithinBurst(t *testing.T) { + rl := newRateLimiter(1.0, 5) + + for i := range 5 { + if !rl.allow("1.2.3.4") { + t.Fatalf("allow() returned false on request %d (within burst of 5)", i+1) + } + } +} + +func TestRateLimiter_BlocksAfterBurst(t *testing.T) { + rl := newRateLimiter(1.0, 3) + + // Exhaust the burst + for range 3 { + rl.allow("1.2.3.4") + } + + if rl.allow("1.2.3.4") { + t.Error("allow() should return false after burst exhausted") + } +} + +func TestRateLimiter_SeparateIPs(t *testing.T) { + rl := newRateLimiter(1.0, 2) + + // Exhaust IP 1 + rl.allow("1.1.1.1") + rl.allow("1.1.1.1") + + // IP 2 should still be allowed + if !rl.allow("2.2.2.2") { + t.Error("allow() should allow a different IP") + } +} + +func TestRateLimiter_RefillsOverTime(t *testing.T) { + rl := newRateLimiter(100.0, 1) // 100 tokens/sec so we can test quickly + + // Use the single token + rl.allow("1.2.3.4") + + if rl.allow("1.2.3.4") { + t.Error("allow() should be blocked immediately after burst exhausted") + } + + // Wait enough time for a token to refill + time.Sleep(20 * time.Millisecond) + + if !rl.allow("1.2.3.4") { + t.Error("allow() should be allowed after token refill") + } +} + +func TestRateLimitMiddleware_Returns429(t *testing.T) { + rl := newRateLimiter(0.001, 1) // Very low rate + logger := discardLogger() + + handler := rateLimitMiddleware(rl, false, logger)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request should succeed + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.RemoteAddr = "10.0.0.1:12345" + handler.ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("first request status = %d, want %d", w.Code, http.StatusOK) + } + + // Second request should be rate limited + w = httptest.NewRecorder() + r = httptest.NewRequest(http.MethodGet, "/", nil) + r.RemoteAddr = "10.0.0.1:12345" + handler.ServeHTTP(w, r) + + if w.Code != http.StatusTooManyRequests { + t.Fatalf("rate limited request status = %d, want %d", w.Code, http.StatusTooManyRequests) + } + + if got := w.Header().Get("Retry-After"); got != "1" { + t.Errorf("Retry-After = %q, want %q", got, "1") + } +} + +func TestClientIP(t *testing.T) { + tests := []struct { + name string + trustProxy bool + remoteAddr string + xff string + xri string + want string + }{ + { + name: "remote addr with port", + trustProxy: true, + remoteAddr: "10.0.0.1:12345", + want: "10.0.0.1", + }, + { + name: "X-Forwarded-For single when trusted", + trustProxy: true, + remoteAddr: "127.0.0.1:80", + xff: "203.0.113.50", + want: "203.0.113.50", + }, + { + name: "X-Forwarded-For multiple when trusted", + trustProxy: true, + remoteAddr: "127.0.0.1:80", + xff: "203.0.113.50, 70.41.3.18, 150.172.238.178", + want: "203.0.113.50", + }, + { + name: "X-Real-IP when trusted", + trustProxy: true, + remoteAddr: "127.0.0.1:80", + xri: "203.0.113.50", + want: "203.0.113.50", + }, + { + name: "X-Real-IP takes precedence over X-Forwarded-For when trusted", + trustProxy: true, + remoteAddr: "127.0.0.1:80", + xff: "203.0.113.50", + xri: "198.51.100.1", + want: "198.51.100.1", + }, + { + name: "untrusted ignores X-Forwarded-For", + trustProxy: false, + remoteAddr: "10.0.0.1:12345", + xff: "203.0.113.50", + want: "10.0.0.1", + }, + { + name: "untrusted ignores X-Real-IP", + trustProxy: false, + remoteAddr: "10.0.0.1:12345", + xri: "203.0.113.50", + want: "10.0.0.1", + }, + { + name: "invalid X-Real-IP falls through to XFF", + trustProxy: true, + remoteAddr: "127.0.0.1:80", + xri: "not-an-ip", + xff: "203.0.113.50", + want: "203.0.113.50", + }, + { + name: "invalid XFF falls through to RemoteAddr", + trustProxy: true, + remoteAddr: "127.0.0.1:80", + xff: "not-an-ip", + want: "127.0.0.1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.RemoteAddr = tt.remoteAddr + if tt.xff != "" { + r.Header.Set("X-Forwarded-For", tt.xff) + } + if tt.xri != "" { + r.Header.Set("X-Real-IP", tt.xri) + } + + if got := clientIP(r, tt.trustProxy); got != tt.want { + t.Errorf("clientIP(r, %v) = %q, want %q", tt.trustProxy, got, tt.want) + } + }) + } +} + +func BenchmarkRateLimiterAllow(b *testing.B) { + rl := newRateLimiter(1e9, 1<<30) // effectively unlimited + for b.Loop() { + rl.allow("1.2.3.4") + } +} + +func BenchmarkClientIP(b *testing.B) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.RemoteAddr = "10.0.0.1:12345" + r.Header.Set("X-Real-IP", "203.0.113.50") + for b.Loop() { + clientIP(r, true) + } +} diff --git a/internal/api/response.go b/internal/api/response.go index a040b74..15dd1ff 100644 --- a/internal/api/response.go +++ b/internal/api/response.go @@ -1,4 +1,3 @@ -// Package api provides the JSON REST API server for Koopa. package api import ( @@ -23,12 +22,16 @@ type envelope struct { // WriteJSON writes data wrapped in an envelope as JSON. // For nil data, writes no body (use with 204 No Content). -func WriteJSON(w http.ResponseWriter, status int, data any) { +// If logger is nil, falls back to slog.Default(). +func WriteJSON(w http.ResponseWriter, status int, data any, logger *slog.Logger) { if data != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) if err := json.NewEncoder(w).Encode(envelope{Data: data}); err != nil { - slog.Error("failed to encode JSON response", "error", err) + if logger == nil { + logger = slog.Default() + } + logger.Error("encoding JSON response", "error", err) } } else { w.WriteHeader(status) @@ -36,10 +39,14 @@ func WriteJSON(w http.ResponseWriter, status int, data any) { } // WriteError writes a JSON error response wrapped in an envelope. -func WriteError(w http.ResponseWriter, status int, code, message string) { +// If logger is nil, falls back to slog.Default(). +func WriteError(w http.ResponseWriter, status int, code, message string, logger *slog.Logger) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) if err := json.NewEncoder(w).Encode(envelope{Error: &Error{Code: code, Message: message}}); err != nil { - slog.Error("failed to encode JSON error response", "error", err) + if logger == nil { + logger = slog.Default() + } + logger.Error("encoding JSON error response", "error", err) } } diff --git a/internal/api/response_test.go b/internal/api/response_test.go index 37243cd..139fb4d 100644 --- a/internal/api/response_test.go +++ b/internal/api/response_test.go @@ -14,10 +14,10 @@ func decodeData(t *testing.T, w *httptest.ResponseRecorder, target any) { Data json.RawMessage `json:"data"` } if err := json.NewDecoder(w.Body).Decode(&env); err != nil { - t.Fatalf("failed to decode envelope: %v", err) + t.Fatalf("decoding envelope: %v", err) } if err := json.Unmarshal(env.Data, target); err != nil { - t.Fatalf("failed to decode envelope data: %v", err) + t.Fatalf("decoding envelope data: %v", err) } } @@ -28,7 +28,7 @@ func decodeErrorEnvelope(t *testing.T, w *httptest.ResponseRecorder) Error { Error *Error `json:"error"` } if err := json.NewDecoder(w.Body).Decode(&env); err != nil { - t.Fatalf("failed to decode envelope: %v", err) + t.Fatalf("decoding envelope: %v", err) } if env.Error == nil { t.Fatal("expected error in envelope, got nil") @@ -40,7 +40,7 @@ func TestWriteJSON(t *testing.T) { w := httptest.NewRecorder() data := map[string]string{"key": "value"} - WriteJSON(w, http.StatusOK, data) + WriteJSON(w, http.StatusOK, data, nil) if ct := w.Header().Get("Content-Type"); ct != "application/json" { t.Errorf("WriteJSON() Content-Type = %q, want %q", ct, "application/json") @@ -62,7 +62,7 @@ func TestWriteJSON_Envelope(t *testing.T) { w := httptest.NewRecorder() items := []string{"a", "b"} - WriteJSON(w, http.StatusOK, items) + WriteJSON(w, http.StatusOK, items, nil) var body []string decodeData(t, w, &body) @@ -75,7 +75,7 @@ func TestWriteJSON_Envelope(t *testing.T) { func TestWriteJSON_NilData(t *testing.T) { w := httptest.NewRecorder() - WriteJSON(w, http.StatusNoContent, nil) + WriteJSON(w, http.StatusNoContent, nil, nil) if w.Code != http.StatusNoContent { t.Errorf("WriteJSON(nil) status = %d, want %d", w.Code, http.StatusNoContent) @@ -89,7 +89,7 @@ func TestWriteJSON_NilData(t *testing.T) { func TestWriteJSON_CustomStatus(t *testing.T) { w := httptest.NewRecorder() - WriteJSON(w, http.StatusCreated, map[string]string{"id": "123"}) + WriteJSON(w, http.StatusCreated, map[string]string{"id": "123"}, nil) if w.Code != http.StatusCreated { t.Errorf("WriteJSON() status = %d, want %d", w.Code, http.StatusCreated) @@ -99,7 +99,7 @@ func TestWriteJSON_CustomStatus(t *testing.T) { func TestWriteError(t *testing.T) { w := httptest.NewRecorder() - WriteError(w, http.StatusBadRequest, "invalid_input", "name is required") + WriteError(w, http.StatusBadRequest, "invalid_input", "name is required", nil) if w.Code != http.StatusBadRequest { t.Errorf("WriteError() status = %d, want %d", w.Code, http.StatusBadRequest) diff --git a/internal/api/server.go b/internal/api/server.go index c75ac2f..df0f20c 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -5,16 +5,14 @@ import ( "log/slog" "net/http" - "github.com/firebase/genkit/go/genkit" - "github.com/koopa0/koopa/internal/agent/chat" + "github.com/koopa0/koopa/internal/chat" "github.com/koopa0/koopa/internal/session" ) // ServerConfig contains configuration for creating the API server. type ServerConfig struct { Logger *slog.Logger - Genkit *genkit.Genkit // Optional: nil disables AI title generation - ModelName string // Provider-qualified model name (e.g., "googleai/gemini-2.5-flash") + ChatAgent *chat.Agent // Optional: nil disables AI title generation ChatFlow *chat.Flow // Optional: nil enables simulation mode SessionStore *session.Store // Required CSRFSecret []byte // Required: 32+ bytes @@ -50,11 +48,10 @@ func NewServer(cfg ServerConfig) (*Server, error) { } ch := &chatHandler{ - logger: logger, - genkit: cfg.Genkit, - modelName: cfg.ModelName, - flow: cfg.ChatFlow, - sessions: sm, + logger: logger, + agent: cfg.ChatAgent, + flow: cfg.ChatFlow, + sessions: sm, } mux := http.NewServeMux() @@ -79,7 +76,7 @@ func NewServer(cfg ServerConfig) (*Server, error) { // Build middleware stack: Recovery → Logging → RateLimit → CORS → Session → CSRF → Routes var handler http.Handler = mux handler = csrfMiddleware(sm, logger)(handler) - handler = sessionMiddleware(sm, logger)(handler) + handler = sessionMiddleware(sm)(handler) handler = corsMiddleware(cfg.CORSOrigins)(handler) handler = rateLimitMiddleware(rl, cfg.TrustProxy, logger)(handler) handler = loggingMiddleware(logger)(handler) diff --git a/internal/api/session.go b/internal/api/session.go index 4f542ba..d062316 100644 --- a/internal/api/session.go +++ b/internal/api/session.go @@ -19,12 +19,18 @@ import ( // Sentinel errors for session/CSRF operations. var ( + // ErrSessionCookieNotFound is returned when the session cookie is absent from the request. ErrSessionCookieNotFound = errors.New("session cookie not found") - ErrSessionInvalid = errors.New("session ID invalid") - ErrCSRFRequired = errors.New("csrf token required") - ErrCSRFInvalid = errors.New("csrf token invalid") - ErrCSRFExpired = errors.New("csrf token expired") - ErrCSRFMalformed = errors.New("csrf token malformed") + // ErrSessionInvalid is returned when the session cookie value is not a valid UUID. + ErrSessionInvalid = errors.New("session ID invalid") + // ErrCSRFRequired is returned when a state-changing request has no CSRF token. + ErrCSRFRequired = errors.New("csrf token required") + // ErrCSRFInvalid is returned when the CSRF token signature does not match. + ErrCSRFInvalid = errors.New("csrf token invalid") + // ErrCSRFExpired is returned when the CSRF token timestamp exceeds csrfTokenTTL. + ErrCSRFExpired = errors.New("csrf token expired") + // ErrCSRFMalformed is returned when the CSRF token format cannot be parsed. + ErrCSRFMalformed = errors.New("csrf token malformed") ) // Pre-session CSRF token prefix to distinguish from session-bound tokens. @@ -32,10 +38,11 @@ const preSessionPrefix = "pre:" // Cookie and CSRF configuration. const ( - sessionCookieName = "sid" - csrfTokenTTL = 24 * time.Hour - sessionMaxAge = 30 * 24 * 3600 // 30 days in seconds - csrfClockSkew = 5 * time.Minute + sessionCookieName = "sid" + csrfTokenTTL = 24 * time.Hour + sessionMaxAge = 30 * 24 * 3600 // 30 days in seconds + csrfClockSkew = 5 * time.Minute + messagesDefaultLimit = 100 ) // sessionManager handles session cookies and CSRF token operations. @@ -46,31 +53,6 @@ type sessionManager struct { logger *slog.Logger } -// GetOrCreate retrieves session ID from cookie or creates a new session. -// On success, sets/refreshes the session cookie and returns the session UUID. -func (sm *sessionManager) GetOrCreate(w http.ResponseWriter, r *http.Request) (uuid.UUID, error) { - // Try to get existing session from cookie - cookie, err := r.Cookie(sessionCookieName) - if err == nil && cookie.Value != "" { - sessionID, parseErr := uuid.Parse(cookie.Value) - if parseErr == nil { - if _, getErr := sm.store.Session(r.Context(), sessionID); getErr == nil { - sm.setCookie(w, sessionID) - return sessionID, nil - } - } - } - - // Create new session - sess, err := sm.store.CreateSession(r.Context(), "", "", "") - if err != nil { - return uuid.Nil, fmt.Errorf("create session: %w", err) - } - - sm.setCookie(w, sess.ID) - return sess.ID, nil -} - // ID extracts session ID from cookie without creating a new session. func (*sessionManager) ID(r *http.Request) (uuid.UUID, error) { cookie, err := r.Cookie(sessionCookieName) @@ -197,24 +179,24 @@ func (sm *sessionManager) CheckPreSessionCSRF(token string) error { func (sm *sessionManager) requireOwnership(w http.ResponseWriter, r *http.Request) (uuid.UUID, bool) { idStr := r.PathValue("id") if idStr == "" { - WriteError(w, http.StatusBadRequest, "missing_id", "session ID required") + WriteError(w, http.StatusBadRequest, "missing_id", "session ID required", sm.logger) return uuid.Nil, false } targetID, err := uuid.Parse(idStr) if err != nil { - WriteError(w, http.StatusBadRequest, "invalid_id", "invalid session ID") + WriteError(w, http.StatusBadRequest, "invalid_id", "invalid session ID", sm.logger) return uuid.Nil, false } - ownerID, ok := SessionIDFromContext(r.Context()) + ownerID, ok := sessionIDFromContext(r.Context()) if !ok || ownerID != targetID { sm.logger.Warn("session ownership check failed", "target", targetID, "path", r.URL.Path, "remote_addr", r.RemoteAddr, ) - WriteError(w, http.StatusForbidden, "forbidden", "session access denied") + WriteError(w, http.StatusForbidden, "forbidden", "session access denied", sm.logger) return uuid.Nil, false } @@ -240,13 +222,13 @@ func (sm *sessionManager) csrfToken(w http.ResponseWriter, r *http.Request) { if err == nil { WriteJSON(w, http.StatusOK, map[string]string{ "csrfToken": sm.NewCSRFToken(sessionID), - }) + }, sm.logger) return } WriteJSON(w, http.StatusOK, map[string]string{ "csrfToken": sm.NewPreSessionCSRFToken(), - }) + }, sm.logger) } // listSessions handles GET /api/v1/sessions — returns sessions owned by the caller. @@ -258,21 +240,21 @@ func (sm *sessionManager) listSessions(w http.ResponseWriter, r *http.Request) { UpdatedAt string `json:"updatedAt"` } - sessionID, ok := SessionIDFromContext(r.Context()) + sessionID, ok := sessionIDFromContext(r.Context()) if !ok { // No session cookie — return empty list - WriteJSON(w, http.StatusOK, []sessionItem{}) + WriteJSON(w, http.StatusOK, []sessionItem{}, sm.logger) return } sess, err := sm.store.Session(r.Context(), sessionID) if err != nil { - if errors.Is(err, session.ErrSessionNotFound) { - WriteJSON(w, http.StatusOK, []sessionItem{}) + if errors.Is(err, session.ErrNotFound) { + WriteJSON(w, http.StatusOK, []sessionItem{}, sm.logger) return } - sm.logger.Error("failed to get session", "error", err, "session_id", sessionID) - WriteError(w, http.StatusInternalServerError, "list_failed", "failed to list sessions") + sm.logger.Error("getting session", "error", err, "session_id", sessionID) + WriteError(w, http.StatusInternalServerError, "list_failed", "failed to list sessions", sm.logger) return } @@ -282,15 +264,15 @@ func (sm *sessionManager) listSessions(w http.ResponseWriter, r *http.Request) { Title: sess.Title, UpdatedAt: sess.UpdatedAt.Format(time.RFC3339), }, - }) + }, sm.logger) } // createSession handles POST /api/v1/sessions — creates a new session. func (sm *sessionManager) createSession(w http.ResponseWriter, r *http.Request) { - sess, err := sm.store.CreateSession(r.Context(), "", "", "") + sess, err := sm.store.CreateSession(r.Context(), "") if err != nil { - sm.logger.Error("failed to create session", "error", err) - WriteError(w, http.StatusInternalServerError, "create_failed", "failed to create session") + sm.logger.Error("creating session", "error", err) + WriteError(w, http.StatusInternalServerError, "create_failed", "failed to create session", sm.logger) return } @@ -299,7 +281,7 @@ func (sm *sessionManager) createSession(w http.ResponseWriter, r *http.Request) WriteJSON(w, http.StatusCreated, map[string]string{ "id": sess.ID.String(), "csrfToken": sm.NewCSRFToken(sess.ID), - }) + }, sm.logger) } // getSession handles GET /api/v1/sessions/{id} — returns a single session. @@ -312,12 +294,12 @@ func (sm *sessionManager) getSession(w http.ResponseWriter, r *http.Request) { sess, err := sm.store.Session(r.Context(), id) if err != nil { - if errors.Is(err, session.ErrSessionNotFound) { - WriteError(w, http.StatusNotFound, "not_found", "session not found") + if errors.Is(err, session.ErrNotFound) { + WriteError(w, http.StatusNotFound, "not_found", "session not found", sm.logger) return } - sm.logger.Error("failed to get session", "error", err, "session_id", id) - WriteError(w, http.StatusInternalServerError, "get_failed", "failed to get session") + sm.logger.Error("getting session", "error", err, "session_id", id) + WriteError(w, http.StatusInternalServerError, "get_failed", "failed to get session", sm.logger) return } @@ -326,7 +308,7 @@ func (sm *sessionManager) getSession(w http.ResponseWriter, r *http.Request) { "title": sess.Title, "createdAt": sess.CreatedAt.Format(time.RFC3339), "updatedAt": sess.UpdatedAt.Format(time.RFC3339), - }) + }, sm.logger) } // getSessionMessages handles GET /api/v1/sessions/{id}/messages — returns messages for a session. @@ -337,10 +319,10 @@ func (sm *sessionManager) getSessionMessages(w http.ResponseWriter, r *http.Requ return } - messages, err := sm.store.Messages(r.Context(), id, 100, 0) + messages, err := sm.store.Messages(r.Context(), id, messagesDefaultLimit, 0) if err != nil { - sm.logger.Error("failed to get messages", "error", err, "session_id", id) - WriteError(w, http.StatusInternalServerError, "get_failed", "failed to get messages") + sm.logger.Error("getting messages", "error", err, "session_id", id) + WriteError(w, http.StatusInternalServerError, "get_failed", "failed to get messages", sm.logger) return } @@ -369,7 +351,7 @@ func (sm *sessionManager) getSessionMessages(w http.ResponseWriter, r *http.Requ } } - WriteJSON(w, http.StatusOK, items) + WriteJSON(w, http.StatusOK, items, sm.logger) } // deleteSession handles DELETE /api/v1/sessions/{id} — deletes a session. @@ -381,10 +363,10 @@ func (sm *sessionManager) deleteSession(w http.ResponseWriter, r *http.Request) } if err := sm.store.DeleteSession(r.Context(), id); err != nil { - sm.logger.Error("failed to delete session", "error", err, "session_id", id) - WriteError(w, http.StatusInternalServerError, "delete_failed", "failed to delete session") + sm.logger.Error("deleting session", "error", err, "session_id", id) + WriteError(w, http.StatusInternalServerError, "delete_failed", "failed to delete session", sm.logger) return } - WriteJSON(w, http.StatusOK, map[string]string{"status": "deleted"}) + WriteJSON(w, http.StatusOK, map[string]string{"status": "deleted"}, sm.logger) } diff --git a/internal/api/session_test.go b/internal/api/session_test.go index 118ba4c..efdf2b5 100644 --- a/internal/api/session_test.go +++ b/internal/api/session_test.go @@ -81,9 +81,9 @@ func TestCSRFToken_Malformed(t *testing.T) { name string token string }{ - {"empty", ""}, - {"no_colon", "justtext"}, - {"bad_timestamp", "notanumber:signature"}, + {name: "empty", token: ""}, + {name: "no_colon", token: "justtext"}, + {name: "bad_timestamp", token: "notanumber:signature"}, } for _, tt := range tests { @@ -245,10 +245,6 @@ func TestGetSessionMessages_InvalidUUID(t *testing.T) { } } -// ============================================================================ -// Session Ownership Tests -// ============================================================================ - func TestRequireOwnership_NoSession(t *testing.T) { sm := newTestSessionManager() targetID := uuid.New() @@ -348,3 +344,59 @@ func TestListSessions_NoSession(t *testing.T) { t.Errorf("listSessions(no session) returned %d items, want 0", len(items)) } } + +func FuzzCheckCSRF(f *testing.F) { + sm := newTestSessionManager() + sessionID := uuid.New() + validToken := sm.NewCSRFToken(sessionID) + + f.Add(sessionID.String(), validToken) + f.Add(sessionID.String(), "") + f.Add(sessionID.String(), "notanumber:signature") + f.Add(sessionID.String(), "12345:badsig") + f.Add(uuid.New().String(), validToken) + f.Add("", "") + f.Add("not-a-uuid", "1234:sig") + + f.Fuzz(func(t *testing.T, sessionIDStr, token string) { + id, err := uuid.Parse(sessionIDStr) + if err != nil { + return + } + _ = sm.CheckCSRF(id, token) // must not panic + }) +} + +func FuzzCheckPreSessionCSRF(f *testing.F) { + sm := newTestSessionManager() + validToken := sm.NewPreSessionCSRFToken() + + f.Add(validToken) + f.Add("") + f.Add("pre:") + f.Add("pre:nonce:notanumber:sig") + f.Add("pre:abc:12345:sig") + f.Add("notpre:abc:123:sig") + f.Add("pre:abc:12345:sig:extra") + + f.Fuzz(func(t *testing.T, token string) { + _ = sm.CheckPreSessionCSRF(token) // must not panic + }) +} + +func BenchmarkNewCSRFToken(b *testing.B) { + sm := newTestSessionManager() + sessionID := uuid.New() + for b.Loop() { + sm.NewCSRFToken(sessionID) + } +} + +func BenchmarkCheckCSRF(b *testing.B) { + sm := newTestSessionManager() + sessionID := uuid.New() + token := sm.NewCSRFToken(sessionID) + for b.Loop() { + _ = sm.CheckCSRF(sessionID, token) + } +} diff --git a/internal/app/app.go b/internal/app/app.go index a13ef60..ad26490 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -1,108 +1,97 @@ -// Package app provides application initialization and dependency injection. +// Package app provides application initialization and lifecycle management. // -// App is the core container that orchestrates all application components with struct-based DI. -// It initializes Genkit, database connection, DocStore (via Genkit PostgreSQL Plugin), -// and creates the agent with all necessary dependencies. +// App is the core container that holds all application components. +// Created by Setup, released by Close. package app import ( - "context" "fmt" "log/slog" + "sync" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/plugins/postgresql" "github.com/jackc/pgx/v5/pgxpool" - "github.com/koopa0/koopa/internal/agent/chat" + + "github.com/koopa0/koopa/internal/chat" "github.com/koopa0/koopa/internal/config" "github.com/koopa0/koopa/internal/security" "github.com/koopa0/koopa/internal/session" - "golang.org/x/sync/errgroup" + "github.com/koopa0/koopa/internal/tools" ) -// App is the core application container. +// App is the application instance. +// Created by Setup, closed by Close. All entry points (CLI, HTTP, MCP) +// use this struct to access shared resources. type App struct { - // Configuration - Config *config.Config - + Config *config.Config Genkit *genkit.Genkit Embedder ai.Embedder - DBPool *pgxpool.Pool // Database connection pool - DocStore *postgresql.DocStore // Genkit PostgreSQL DocStore for indexing - Retriever ai.Retriever // Genkit Retriever for searching - SessionStore *session.Store // Session persistence (concrete type, not interface) - PathValidator *security.Path // Path validator for security - Tools []ai.Tool // Pre-registered tools + DBPool *pgxpool.Pool + DocStore *postgresql.DocStore + Retriever ai.Retriever + SessionStore *session.Store + PathValidator *security.Path + Tools []ai.Tool // Pre-registered Genkit tools (for chat agent) - // Lifecycle management - ctx context.Context - cancel context.CancelFunc + // Concrete toolsets shared by CLI and MCP entry points. + File *tools.File + System *tools.System + Network *tools.Network + Knowledge *tools.Knowledge // nil if retriever unavailable - // errgroup for background goroutine lifecycle management - eg *errgroup.Group + // Lifecycle management (unexported) + cancel func() + dbCleanup func() + otelCleanup func() + closeOnce sync.Once } -// Close gracefully shuts down App-managed resources. -// Cleanup function handles DB pool and OTel (single owner principle). +// Close gracefully shuts down all resources. Safe for concurrent and +// repeated calls — cleanup runs exactly once via sync.Once. // // Shutdown order: -// 1. Cancel context (signals background tasks to stop) -// 2. Wait for background goroutines (errgroup) +// 1. Cancel context (signals background tasks to stop) +// 2. Close DB pool +// 3. Flush OTel spans func (a *App) Close() error { - slog.Info("shutting down application") - - // 1. Cancel context (signals all background tasks to stop) - if a.cancel != nil { - a.cancel() - } + a.closeOnce.Do(func() { + slog.Info("shutting down application") - // 2. Wait for background goroutines - if a.eg != nil { - if err := a.eg.Wait(); err != nil { - return fmt.Errorf("background task error: %w", err) + // 1. Cancel context (signals all background tasks to stop) + if a.cancel != nil { + a.cancel() } - slog.Debug("background tasks completed") - } - // Pool is closed by cleanup function, NOT here (single owner principle) - return nil -} + // 2. Close DB pool + if a.dbCleanup != nil { + a.dbCleanup() + } -// Wait blocks until all background goroutines complete. -// This is useful for waiting on background tasks without closing resources. -func (a *App) Wait() error { - if a.eg == nil { - return nil - } - if err := a.eg.Wait(); err != nil { - return fmt.Errorf("errgroup failed: %w", err) - } + // 3. Flush OTel spans + if a.otelCleanup != nil { + a.otelCleanup() + } + }) return nil } -// Go starts a new background goroutine tracked by the app's errgroup. -// Use this for any background tasks that should be waited on during shutdown. -func (a *App) Go(f func() error) { - if a.eg != nil { - a.eg.Go(f) - } -} - // CreateAgent creates a Chat Agent using pre-registered tools. -// Tools are registered once at App construction (not lazily). -// InitializeApp guarantees all dependencies are non-nil. -func (a *App) CreateAgent(_ context.Context) (*chat.Chat, error) { - // No nil checks - InitializeApp guarantees injection - return chat.New(chat.Config{ +// Tools are registered once at Setup (not lazily). +// Setup guarantees all dependencies are non-nil. +func (a *App) CreateAgent() (*chat.Agent, error) { + agent, err := chat.New(chat.Config{ Genkit: a.Genkit, - Retriever: a.Retriever, SessionStore: a.SessionStore, Logger: slog.Default(), Tools: a.Tools, ModelName: a.Config.FullModelName(), MaxTurns: a.Config.MaxTurns, - RAGTopK: a.Config.RAGTopK, Language: a.Config.Language, }) + if err != nil { + return nil, fmt.Errorf("creating chat agent: %w", err) + } + return agent, nil } diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 8247a40..dd8049b 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -2,115 +2,81 @@ package app import ( "context" - "fmt" "os" "path/filepath" - "runtime" "testing" "time" "github.com/firebase/genkit/go/genkit" + "github.com/koopa0/koopa/internal/config" "github.com/koopa0/koopa/internal/security" - "golang.org/x/sync/errgroup" + "github.com/koopa0/koopa/internal/testutil" ) -// ============================================================================ -// App.Close() Tests -// ============================================================================ - func TestApp_Close(t *testing.T) { tests := []struct { name string - setupApp func() *App + setup func() *App expectError bool }{ { name: "close with cancel function", - setupApp: func() *App { - ctx, cancel := context.WithCancel(context.Background()) - return &App{ - ctx: ctx, - cancel: cancel, - DBPool: nil, // Don't mock pgxpool as it causes panic on close - } + setup: func() *App { + _, cancel := context.WithCancel(context.Background()) + return &App{cancel: cancel} }, - expectError: false, }, { - name: "close with nil DBPool", - setupApp: func() *App { - ctx, cancel := context.WithCancel(context.Background()) - return &App{ - ctx: ctx, - cancel: cancel, - DBPool: nil, - } + name: "close with nil cancel", + setup: func() *App { + return &App{cancel: nil} }, - expectError: false, }, { - name: "close with nil cancel function", - setupApp: func() *App { + name: "close with cleanup functions", + setup: func() *App { return &App{ - ctx: context.Background(), - cancel: nil, - DBPool: nil, + dbCleanup: func() {}, + otelCleanup: func() {}, } }, - expectError: false, }, { name: "close minimal app", - setupApp: func() *App { + setup: func() *App { return &App{} }, - expectError: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - app := tt.setupApp() - err := app.Close() + a := tt.setup() + err := a.Close() if tt.expectError && err == nil { t.Error("expected error but got none") } - if !tt.expectError && err != nil { - t.Errorf("unexpected error: %v", err) - } - - // Verify context was canceled if cancel function existed - if app.cancel != nil && app.ctx != nil { - select { - case <-app.ctx.Done(): - // Context was properly canceled - default: - t.Error("context was not canceled") - } + t.Errorf("Close() unexpected error: %v", err) } }) } } -// ============================================================================ -// App Struct Field Tests -// ============================================================================ - func TestApp_Fields(t *testing.T) { - t.Run("app with all fields set", func(t *testing.T) { + t.Run("all fields set", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() g := genkit.Init(ctx) pathValidator, err := security.NewPath([]string{"."}) if err != nil { - t.Fatalf("failed to create path validator: %v", err) + t.Fatalf("security.NewPath() error: %v", err) } - app := &App{ + a := &App{ Config: &config.Config{ ModelName: "gemini-2.0-flash-exp", }, @@ -120,90 +86,111 @@ func TestApp_Fields(t *testing.T) { DocStore: nil, Retriever: nil, PathValidator: pathValidator, - ctx: ctx, cancel: cancel, } - // Verify fields are set - if app.Config == nil { - t.Error("expected Config to be set") + if a.Config == nil { + t.Error("App.Config = nil, want non-nil") } - if app.Genkit == nil { - t.Error("expected Genkit to be set") + if a.Genkit == nil { + t.Error("App.Genkit = nil, want non-nil") } - if app.PathValidator == nil { - t.Error("expected PathValidator to be set") + if a.PathValidator == nil { + t.Error("App.PathValidator = nil, want non-nil") } - if app.ctx == nil { - t.Error("expected ctx to be set") - } - if app.cancel == nil { - t.Error("expected cancel to be set") + if a.cancel == nil { + t.Error("App.cancel = nil, want non-nil") } }) } -// ============================================================================ -// Nil Safety Tests -// ============================================================================ - func TestApp_NilSafety(t *testing.T) { tests := []struct { name string - app *App + a *App }{ { - name: "close nil app fields", - app: &App{}, + name: "close nil fields", + a: &App{}, }, { - name: "close with only ctx", - app: &App{ - ctx: context.Background(), - }, + name: "close with only cancel", + a: &App{cancel: func() {}}, }, { - name: "close with only cancel", - app: &App{ - cancel: func() {}, - }, + name: "close with only cleanup", + a: &App{dbCleanup: func() {}, otelCleanup: func() {}}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Should not panic - err := tt.app.Close() + err := tt.a.Close() if err != nil { - t.Errorf("unexpected error: %v", err) + t.Errorf("Close() unexpected error: %v", err) } }) } } -// ============================================================================ -// InitializeApp Integration Tests -// ============================================================================ +func TestApp_Close_ShutdownOrder(t *testing.T) { + t.Run("cleanup called after cancel", func(t *testing.T) { + var order []string + + a := &App{ + cancel: func() { order = append(order, "cancel") }, + dbCleanup: func() { order = append(order, "db") }, + otelCleanup: func() { order = append(order, "otel") }, + } + + _ = a.Close() + + if len(order) != 3 { + t.Fatalf("Close() operations = %d, want 3: %v", len(order), order) + } + if order[0] != "cancel" { + t.Errorf("order[0] = %q, want %q", order[0], "cancel") + } + if order[1] != "db" { + t.Errorf("order[1] = %q, want %q", order[1], "db") + } + if order[2] != "otel" { + t.Errorf("order[2] = %q, want %q", order[2], "otel") + } + }) + + t.Run("idempotent close", func(t *testing.T) { + callCount := 0 + a := &App{ + cancel: func() { callCount++ }, + dbCleanup: func() { callCount++ }, + otelCleanup: func() { callCount++ }, + } + + _ = a.Close() + _ = a.Close() // second call should be no-op -func TestInitializeApp_Success(t *testing.T) { + if callCount != 3 { + t.Errorf("Close() twice: call count = %d, want 3 (second Close should be no-op)", callCount) + } + }) +} + +func TestSetup_Success(t *testing.T) { if testing.Short() { - t.Skip("Skipping integration test in short mode") + t.Skip("skipping integration test in short mode") } - - // Check required environment variables if os.Getenv("GEMINI_API_KEY") == "" { - t.Skip("GEMINI_API_KEY not set - skipping integration test") + t.Skip("GEMINI_API_KEY not set") } - - // Skip if database is not available (PostgreSQL required) if os.Getenv("DATABASE_URL") == "" { - t.Skip("DATABASE_URL not set - skipping integration test") + t.Skip("DATABASE_URL not set") } ctx := context.Background() cfg := &config.Config{ ModelName: "gemini-2.0-flash-exp", - EmbedderModel: "text-embedding-004", + EmbedderModel: "gemini-embedding-001", Temperature: 0.7, MaxTokens: 8192, PostgresHost: "localhost", @@ -215,68 +202,58 @@ func TestInitializeApp_Success(t *testing.T) { PromptDir: getPromptsDir(t), } - // Test: InitializeApp should successfully create all components - app, cleanup, err := InitializeApp(ctx, cfg) + a, err := Setup(ctx, cfg) if err != nil { - t.Fatalf("InitializeApp failed: %v", err) + t.Fatalf("Setup() error: %v", err) } - defer cleanup() - defer func() { _ = app.Close() }() + defer func() { _ = a.Close() }() - // Verify all components are initialized - if app == nil { - t.Fatal("expected non-nil app") - return + if a.Config == nil { + t.Error("Setup().Config = nil, want non-nil") } - if app.Config == nil { - t.Error("expected Config to be set") + if a.Genkit == nil { + t.Error("Setup().Genkit = nil, want non-nil") } - if app.Genkit == nil { - t.Error("expected Genkit to be set") + if a.Embedder == nil { + t.Error("Setup().Embedder = nil, want non-nil") } - if app.Embedder == nil { - t.Error("expected Embedder to be set") + if a.DBPool == nil { + t.Error("Setup().DBPool = nil, want non-nil") } - if app.DBPool == nil { - t.Error("expected DBPool to be set") + if a.DocStore == nil { + t.Error("Setup().DocStore = nil, want non-nil") } - if app.DocStore == nil { - t.Error("expected DocStore to be set") + if a.Retriever == nil { + t.Error("Setup().Retriever = nil, want non-nil") } - if app.Retriever == nil { - t.Error("expected Retriever to be set") + if a.SessionStore == nil { + t.Error("Setup().SessionStore = nil, want non-nil") } - if app.SessionStore == nil { - t.Error("expected SessionStore to be set") + if a.PathValidator == nil { + t.Error("Setup().PathValidator = nil, want non-nil") } - if app.PathValidator == nil { - t.Error("expected PathValidator to be set") - } - // Note: SystemIndexer was removed - system knowledge indexing is now - // done via rag.IndexSystemKnowledge() function in newApp() // Verify database connection is functional - if err := app.DBPool.Ping(ctx); err != nil { - t.Errorf("database ping failed: %v", err) + if err := a.DBPool.Ping(ctx); err != nil { + t.Errorf("pinging database: %v", err) } } -func TestInitializeApp_CleanupFunction(t *testing.T) { +func TestSetup_CleanupOnClose(t *testing.T) { if testing.Short() { - t.Skip("Skipping integration test in short mode") + t.Skip("skipping integration test in short mode") } - if os.Getenv("GEMINI_API_KEY") == "" { - t.Skip("GEMINI_API_KEY not set - skipping integration test") + t.Skip("GEMINI_API_KEY not set") } if os.Getenv("DATABASE_URL") == "" { - t.Skip("DATABASE_URL not set - skipping integration test") + t.Skip("DATABASE_URL not set") } ctx := context.Background() cfg := &config.Config{ ModelName: "gemini-2.0-flash-exp", - EmbedderModel: "text-embedding-004", + EmbedderModel: "gemini-embedding-001", PostgresHost: "localhost", PostgresPort: 5432, PostgresUser: "postgres", @@ -286,235 +263,59 @@ func TestInitializeApp_CleanupFunction(t *testing.T) { PromptDir: getPromptsDir(t), } - app, cleanup, err := InitializeApp(ctx, cfg) + a, err := Setup(ctx, cfg) if err != nil { - t.Fatalf("InitializeApp failed: %v", err) + t.Fatalf("Setup() error: %v", err) } - // Test: cleanup function should close database pool - cleanup() + pool := a.DBPool + + // Close should release DB pool + if err := a.Close(); err != nil { + t.Fatalf("Close() error: %v", err) + } // Verify pool is closed (ping should fail) - if err := app.DBPool.Ping(ctx); err == nil { - t.Error("expected database ping to fail after cleanup") + if err := pool.Ping(ctx); err == nil { + t.Error("pool.Ping() after Close() = nil, want error") } } func TestProvidePathValidator_Success(t *testing.T) { validator, err := providePathValidator() - if err != nil { - t.Fatalf("providePathValidator failed: %v", err) + t.Fatalf("providePathValidator() error: %v", err) } if validator == nil { - t.Fatal("expected non-nil path validator") + t.Fatal("providePathValidator() returned nil") } } -// ============================================================================ -// Shutdown and Lifecycle Tests -// ============================================================================ - -// TestApp_ShutdownTimeout tests that shutdown completes within reasonable time. -// Safety: Prevents hang during shutdown if background tasks don't respond to cancellation. func TestApp_ShutdownTimeout(t *testing.T) { t.Run("graceful shutdown completes quickly", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - app := &App{ - ctx: ctx, - cancel: cancel, - DBPool: nil, - } + _, cancel := context.WithCancel(context.Background()) + a := &App{cancel: cancel} - // Shutdown should complete quickly (no background tasks) done := make(chan struct{}) go func() { - _ = app.Close() + _ = a.Close() close(done) }() select { case <-done: - // Success: shutdown completed + // Success case <-time.After(5 * time.Second): - t.Fatal("shutdown timed out - potential deadlock") + t.Fatal("shutdown timed out") } }) - - t.Run("shutdown with background goroutine", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - eg, egCtx := errgroup.WithContext(ctx) - - app := &App{ - ctx: ctx, - cancel: cancel, - DBPool: nil, - eg: eg, - } - - // Start a background task that respects context cancellation - taskDone := make(chan struct{}) - app.Go(func() error { - defer close(taskDone) - <-egCtx.Done() - return nil - }) - - // Shutdown should complete after background task exits - done := make(chan struct{}) - go func() { - _ = app.Close() - close(done) - }() - - select { - case <-done: - // Success: shutdown completed - case <-time.After(5 * time.Second): - t.Fatal("shutdown timed out with background goroutine") - } - - // Verify background task was properly terminated - select { - case <-taskDone: - // Task completed - default: - t.Error("background task was not properly terminated") - } - }) - - t.Run("shutdown timeout safety", func(t *testing.T) { - // This test documents the expected behavior: - // If a background task doesn't respond to context cancellation, - // shutdown will block. This is intentional to prevent data loss. - // - // In production, consider adding a hard timeout: - // - Use context.WithTimeout for background tasks - // - Or implement a watchdog timer in Close() - - ctx, cancel := context.WithCancel(context.Background()) - app := &App{ - ctx: ctx, - cancel: cancel, - } - - // Verify cancel is called during Close - _ = app.Close() - - select { - case <-ctx.Done(): - // Context was properly canceled - default: - t.Error("context was not canceled during shutdown") - } - }) -} - -// TestApp_Wait tests the Wait() method for background task completion. -func TestApp_Wait(t *testing.T) { - t.Run("wait with nil errgroup returns nil", func(t *testing.T) { - app := &App{eg: nil} - err := app.Wait() - if err != nil { - t.Errorf("expected nil error, got: %v", err) - } - }) - - t.Run("wait blocks until tasks complete", func(t *testing.T) { - ctx := context.Background() - eg, _ := errgroup.WithContext(ctx) - - app := &App{eg: eg} - - taskStarted := make(chan struct{}) - taskDone := make(chan struct{}) - - app.Go(func() error { - close(taskStarted) - time.Sleep(50 * time.Millisecond) - close(taskDone) - return nil - }) - - <-taskStarted // Wait for task to start - - // Wait should block until task completes - err := app.Wait() - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - select { - case <-taskDone: - // Task completed before Wait returned - default: - t.Error("Wait returned before task completed") - } - }) -} - -// TestApp_Go tests the Go() method for starting background tasks. -func TestApp_Go(t *testing.T) { - t.Run("go with nil errgroup does not panic", func(t *testing.T) { - app := &App{eg: nil} - - // Should not panic - app.Go(func() error { - return nil - }) - }) - - t.Run("go tracks task in errgroup", func(t *testing.T) { - ctx := context.Background() - eg, _ := errgroup.WithContext(ctx) - - app := &App{eg: eg} - - executed := false - app.Go(func() error { - executed = true - return nil - }) - - // Wait for task - _ = app.Wait() - - if !executed { - t.Error("task was not executed") - } - }) -} - -// ============================================================================ -// Test Helpers -// ============================================================================ - -// findProjectRoot finds the project root directory by looking for go.mod. -func findProjectRoot() (string, error) { - _, filename, _, ok := runtime.Caller(0) - if !ok { - return "", fmt.Errorf("could not determine caller filename") - } - - dir := filepath.Dir(filename) - for { - if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { - return dir, nil - } - parent := filepath.Dir(dir) - if parent == dir { - return "", fmt.Errorf("project root (go.mod) not found") - } - dir = parent - } } -// getPromptsDir returns the absolute path to the prompts directory. func getPromptsDir(t *testing.T) string { t.Helper() - root, err := findProjectRoot() + root, err := testutil.FindProjectRoot() if err != nil || root == "" { - t.Skip("Could not find project root") + t.Skip("could not find project root") } return filepath.Join(root, "prompts") } diff --git a/internal/app/runtime.go b/internal/app/runtime.go deleted file mode 100644 index c5fff33..0000000 --- a/internal/app/runtime.go +++ /dev/null @@ -1,87 +0,0 @@ -package app - -import ( - "context" - "errors" - "fmt" - "log/slog" - - "github.com/koopa0/koopa/internal/agent/chat" - "github.com/koopa0/koopa/internal/config" -) - -// ChatRuntime provides a fully initialized application runtime with all components ready to use. -// It encapsulates the common initialization logic used by CLI and HTTP server entry points. -// MCP mode uses InitializeApp directly (no chat flow needed). -// Implements io.Closer for resource cleanup. -type ChatRuntime struct { - App *App - Flow *chat.Flow - cleanup func() // cleanup (unexported) - handles DB pool, OTel -} - -// Close releases all resources. Implements io.Closer. -// Shutdown order: App.Close (goroutines) → cleanup (DB pool, OTel). -func (r *ChatRuntime) Close() error { - var errs []error - - // 1. App shutdown (cancel context, wait for goroutines) - if r.App != nil { - if err := r.App.Close(); err != nil { - errs = append(errs, fmt.Errorf("app close: %w", err)) - } - } - - // 2. Cleanup (DB pool, OTel) - if r.cleanup != nil { - r.cleanup() - } - - return errors.Join(errs...) -} - -// NewChatRuntime creates a fully initialized runtime with all components ready for use. -// This is the recommended way to initialize Koopa for CLI and HTTP entry points. -// -// Usage: -// -// runtime, err := app.NewChatRuntime(ctx, cfg) -// if err != nil { ... } -// defer runtime.Close() // Single cleanup method (implements io.Closer) -// // Use runtime.Flow for agent interactions -func NewChatRuntime(ctx context.Context, cfg *config.Config) (*ChatRuntime, error) { - // Initialize application - application, cleanup, err := InitializeApp(ctx, cfg) - if err != nil { - return nil, fmt.Errorf("failed to initialize application: %w", err) - } - - // Create Chat Agent (uses pre-registered tools) - chatAgent, err := application.CreateAgent(ctx) - if err != nil { - // Must close application first (stops background goroutines) - // then cleanup (closes DB pool, OTel) - if closeErr := application.Close(); closeErr != nil { - slog.Warn("app close failed during CreateAgent recovery", "error", closeErr) - } - cleanup() - return nil, fmt.Errorf("failed to create agent: %w", err) - } - - // Initialize Chat Flow (singleton pattern with explicit lifecycle) - chatFlow, err := chat.InitFlow(application.Genkit, chatAgent) - if err != nil { - // InitFlow failed (likely called twice) - cleanup and return error - if closeErr := application.Close(); closeErr != nil { - slog.Warn("app close failed during InitFlow recovery", "error", closeErr) - } - cleanup() - return nil, fmt.Errorf("failed to init flow: %w", err) - } - - return &ChatRuntime{ - App: application, - Flow: chatFlow, - cleanup: cleanup, - }, nil -} diff --git a/internal/app/runtime_test.go b/internal/app/runtime_test.go deleted file mode 100644 index e34c138..0000000 --- a/internal/app/runtime_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package app - -import ( - "errors" - "testing" -) - -// ============================================================================ -// ChatChatRuntime.Close() Tests -// ============================================================================ - -func TestChatRuntime_Close(t *testing.T) { - t.Run("close with nil app", func(t *testing.T) { - cleanupCalled := false - r := &ChatRuntime{ - App: nil, - cleanup: func() { cleanupCalled = true }, - } - - err := r.Close() - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if !cleanupCalled { - t.Error("cleanup function should be called") - } - }) - - t.Run("close with nil cleanup", func(t *testing.T) { - r := &ChatRuntime{ - App: nil, - cleanup: nil, - } - - err := r.Close() - if err != nil { - t.Errorf("unexpected error: %v", err) - } - }) - - t.Run("close propagates app error", func(t *testing.T) { - // Create an app that will return an error on Close - app := &App{ - eg: nil, // nil errgroup means Close returns nil - } - - cleanupCalled := false - r := &ChatRuntime{ - App: app, - cleanup: func() { cleanupCalled = true }, - } - - err := r.Close() - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if !cleanupCalled { - t.Error("cleanup should be called even after app.Close succeeds") - } - }) - - t.Run("cleanup called after app close", func(t *testing.T) { - // This test verifies the shutdown order: - // 1. App.Close() (cancel context, wait for goroutines) - // 2. Cleanup (DB pool, OTel) - var order []string - - app := &App{ - cancel: func() { order = append(order, "cancel") }, - } - - r := &ChatRuntime{ - App: app, - cleanup: func() { order = append(order, "cleanup") }, - } - - _ = r.Close() - - if len(order) != 2 { - t.Fatalf("expected 2 operations, got %d", len(order)) - } - if order[0] != "cancel" { - t.Errorf("expected cancel first, got %s", order[0]) - } - if order[1] != "cleanup" { - t.Errorf("expected cleanup second, got %s", order[1]) - } - }) -} - -// TestChatRuntime_Close_ErrorAggregation tests that errors are properly joined. -func TestChatRuntime_Close_ErrorAggregation(t *testing.T) { - t.Run("errors from app close are returned", func(t *testing.T) { - // We can't easily make App.Close() return an error without - // setting up an errgroup that returns an error. - // This documents the expected behavior. - - // When App.Close() returns an error, ChatRuntime.Close() should: - // 1. Still call cleanup() - // 2. Return the error - - cleanupCalled := false - r := &ChatRuntime{ - App: &App{ - cancel: nil, - eg: nil, - }, - cleanup: func() { cleanupCalled = true }, - } - - err := r.Close() - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if !cleanupCalled { - t.Error("cleanup should always be called") - } - }) -} - -// TestNewChatRuntime_CleanupOnFailure documents the expected behavior when -// CreateAgent fails during NewChatRuntime initialization. -// -// The fix in runtime.go:60-66 ensures: -// 1. application.Close() is called to stop background goroutines -// 2. cleanup() is called to release DB pool and OTel resources -// 3. The original error is returned -// -// This test cannot easily verify the behavior without mocking, -// but documents the contract for future maintainers. -func TestNewChatRuntime_CleanupOnFailure(t *testing.T) { - t.Run("documented behavior on CreateAgent failure", func(t *testing.T) { - // When CreateAgent fails: - // 1. application is already created (with background goroutine) - // 2. We must call application.Close() to: - // - Cancel context - // - Wait for errgroup (background IndexSystemKnowledge) - // 3. Then call cleanup() to close DB pool - // - // Without this order: - // - Background goroutine may use closed DB pool - // - Goroutine leak if context not canceled - // - // Integration testing is required to fully verify this behavior. - t.Log("See runtime.go:60-66 for implementation") - }) -} - -// TestErrors_Join verifies errors.Join behavior used in Close(). -func TestErrors_Join(t *testing.T) { - t.Run("nil errors return nil", func(t *testing.T) { - err := errors.Join(nil, nil) - if err != nil { - t.Errorf("expected nil, got %v", err) - } - }) - - t.Run("empty slice returns nil", func(t *testing.T) { - var errs []error - err := errors.Join(errs...) - if err != nil { - t.Errorf("expected nil, got %v", err) - } - }) - - t.Run("single error preserved", func(t *testing.T) { - original := errors.New("test error") - errs := []error{original} - err := errors.Join(errs...) - if err == nil { - t.Fatal("expected error") - } - if !errors.Is(err, original) { - t.Error("error should wrap original") - } - }) -} diff --git a/internal/app/setup.go b/internal/app/setup.go index d172497..fc1e071 100644 --- a/internal/app/setup.go +++ b/internal/app/setup.go @@ -2,24 +2,26 @@ package app import ( "context" + "errors" "fmt" "log/slog" + "os" "time" - "golang.org/x/sync/errgroup" - "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/plugins/compat_oai/openai" "github.com/firebase/genkit/go/plugins/googlegenai" "github.com/firebase/genkit/go/plugins/ollama" "github.com/firebase/genkit/go/plugins/postgresql" "github.com/jackc/pgx/v5/pgxpool" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" + sdktrace "go.opentelemetry.io/otel/sdk/trace" "github.com/koopa0/koopa/db" "github.com/koopa0/koopa/internal/config" - "github.com/koopa0/koopa/internal/observability" "github.com/koopa0/koopa/internal/rag" "github.com/koopa0/koopa/internal/security" "github.com/koopa0/koopa/internal/session" @@ -27,96 +29,126 @@ import ( "github.com/koopa0/koopa/internal/tools" ) -// InitializeApp creates and initializes all application dependencies. -// Returns the App, a cleanup function for infrastructure resources (DB pool, OTel), -// and any initialization error. -func InitializeApp(ctx context.Context, cfg *config.Config) (*App, func(), error) { - otelCleanup, err := provideOtelShutdown(ctx, cfg) - if err != nil { - return nil, nil, err - } +// Setup creates and initializes the application. +// Returns an App with embedded cleanup — call Close() to release. +func Setup(ctx context.Context, cfg *config.Config) (_ *App, retErr error) { + a := &App{Config: cfg} + + // On error, clean up everything already initialized + defer func() { + if retErr != nil { + if err := a.Close(); err != nil { + slog.Warn("cleanup during setup failure", "error", err) + } + } + }() + + a.otelCleanup = provideOtelShutdown(ctx, cfg) + pool, dbCleanup, err := provideDBPool(ctx, cfg) if err != nil { - otelCleanup() - return nil, nil, err + return nil, err } + a.dbCleanup = dbCleanup + a.DBPool = pool + postgres, err := providePostgresPlugin(ctx, pool, cfg) if err != nil { - dbCleanup() - otelCleanup() - return nil, nil, err + return nil, err } + g, err := provideGenkit(ctx, cfg, postgres) if err != nil { - dbCleanup() - otelCleanup() - return nil, nil, err + return nil, err } + a.Genkit = g + embedder := provideEmbedder(g, cfg) + if embedder == nil { + return nil, fmt.Errorf("embedder %q not found for provider %q", cfg.EmbedderModel, cfg.Provider) + } + a.Embedder = embedder + docStore, retriever, err := provideRAGComponents(ctx, g, postgres, embedder) if err != nil { - dbCleanup() - otelCleanup() - return nil, nil, err + return nil, err } - store := provideSessionStore(pool) + a.DocStore = docStore + a.Retriever = retriever + + a.SessionStore = provideSessionStore(pool) + path, err := providePathValidator() if err != nil { - dbCleanup() - otelCleanup() - return nil, nil, err + return nil, err } - v, err := provideTools(g, path, retriever, docStore, cfg) - if err != nil { - dbCleanup() - otelCleanup() - return nil, nil, err - } - application := newApp(ctx, cfg, g, embedder, pool, docStore, retriever, store, path, v) - - // Start background system knowledge indexing. - // Launched here (not in newApp) to keep the constructor side-effect free. - //nolint:contextcheck // Independent context: indexing must complete even if parent is canceled - application.Go(func() error { - indexCtx, indexCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer indexCancel() - - count, err := rag.IndexSystemKnowledge(indexCtx, docStore, pool) - if err != nil { - slog.Debug("system knowledge indexing failed (non-critical)", "error", err) - return nil - } - slog.Debug("system knowledge indexed successfully", "count", count) - return nil - }) - - return application, func() { - dbCleanup() - otelCleanup() - }, nil + a.PathValidator = path + + if err := provideTools(a); err != nil { + return nil, err + } + + // Set up lifecycle management + _, cancel := context.WithCancel(ctx) + a.cancel = cancel + + return a, nil } // provideOtelShutdown sets up Datadog tracing before Genkit initialization. // Must be called before provideGenkit to ensure TracerProvider is ready. -func provideOtelShutdown(ctx context.Context, cfg *config.Config) (func(), error) { - shutdown, err := observability.SetupDatadog(ctx, observability.Config{ - AgentHost: cfg.Datadog.AgentHost, - Environment: cfg.Datadog.Environment, - ServiceName: cfg.Datadog.ServiceName, - }) +// +// Traces are exported to a local Datadog Agent via OTLP HTTP (localhost:4318). +// The Agent handles authentication, buffering, and forwarding to Datadog backend. +func provideOtelShutdown(ctx context.Context, cfg *config.Config) func() { + dd := cfg.Datadog + + agentHost := dd.AgentHost + if agentHost == "" { + agentHost = "localhost:4318" + } + + // Set OTEL env vars for Genkit's TracerProvider to pick up. + // SAFETY: os.Setenv is not concurrent-safe, but this function is called + // exactly once during startup in Setup, before goroutines are spawned. + if dd.ServiceName != "" { + _ = os.Setenv("OTEL_SERVICE_NAME", dd.ServiceName) + } + if dd.Environment != "" { + _ = os.Setenv("OTEL_RESOURCE_ATTRIBUTES", "deployment.environment="+dd.Environment) + } + + // Create OTLP HTTP exporter pointing to local Datadog Agent. + // Agent handles authentication and forwarding to Datadog backend. + exporter, err := otlptracehttp.New(ctx, + otlptracehttp.WithEndpoint(agentHost), + otlptracehttp.WithInsecure(), // localhost doesn't need TLS + ) if err != nil { - return nil, err + slog.Warn("creating datadog exporter, tracing disabled", "error", err) + return func() {} } + // Register BatchSpanProcessor with Genkit's TracerProvider. + processor := sdktrace.NewBatchSpanProcessor(exporter) + tracing.TracerProvider().RegisterSpanProcessor(processor) + + slog.Debug("datadog tracing enabled", + "agent", agentHost, + "service", dd.ServiceName, + "environment", dd.Environment, + ) + + shutdown := tracing.TracerProvider().Shutdown + //nolint:contextcheck // Independent context: shutdown runs during teardown when parent is canceled - cleanupFn := func() { + return func() { shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := shutdown(shutdownCtx); err != nil { - slog.Warn("failed to shutdown tracer provider", "error", err) + slog.Warn("shutting down tracer provider", "error", err) } } - return cleanupFn, nil } // providePostgresPlugin creates the Genkit PostgreSQL plugin. @@ -124,7 +156,7 @@ func provideOtelShutdown(ctx context.Context, cfg *config.Config) (func(), error func providePostgresPlugin(ctx context.Context, pool *pgxpool.Pool, cfg *config.Config) (*postgresql.Postgres, error) { pEngine, err := postgresql.NewPostgresEngine(ctx, postgresql.WithPool(pool), postgresql.WithDatabase(cfg.PostgresDBName)) if err != nil { - return nil, fmt.Errorf("failed to create postgres engine: %w", err) + return nil, fmt.Errorf("creating postgres engine: %w", err) } return &postgresql.Postgres{Engine: pEngine}, nil @@ -132,7 +164,7 @@ func providePostgresPlugin(ctx context.Context, pool *pgxpool.Pool, cfg *config. // provideGenkit initializes Genkit with the configured AI provider and PostgreSQL plugins. // Supports gemini (default), ollama, and openai providers. -// Call ordering in InitializeApp ensures tracing is set up first. +// Call ordering in Setup ensures tracing is set up first. func provideGenkit(ctx context.Context, cfg *config.Config, postgres *postgresql.Postgres) (*genkit.Genkit, error) { promptDir := cfg.PromptDir if promptDir == "" { @@ -154,7 +186,7 @@ func provideGenkit(ctx context.Context, cfg *config.Config, postgres *postgresql genkit.WithPromptDir(promptDir), ) if g == nil { - return nil, fmt.Errorf("failed to initialize Genkit with ollama provider") + return nil, errors.New("initializing genkit with ollama provider") } // Ollama requires explicit model registration (no auto-discovery) ollamaPlugin.DefineModel(g, ollama.ModelDefinition{ @@ -172,7 +204,7 @@ func provideGenkit(ctx context.Context, cfg *config.Config, postgres *postgresql genkit.WithPromptDir(promptDir), ) if g == nil { - return nil, fmt.Errorf("failed to initialize Genkit with openai provider") + return nil, errors.New("initializing genkit with openai provider") } slog.Info("initialized Genkit with openai provider", "model", cfg.ModelName) @@ -182,7 +214,7 @@ func provideGenkit(ctx context.Context, cfg *config.Config, postgres *postgresql genkit.WithPromptDir(promptDir), ) if g == nil { - return nil, fmt.Errorf("failed to initialize Genkit with gemini provider") + return nil, errors.New("initializing genkit with gemini provider") } slog.Info("initialized Genkit with gemini provider", "model", cfg.ModelName) } @@ -217,12 +249,12 @@ func provideEmbedder(g *genkit.Genkit, cfg *config.Config) ai.Embedder { // Pool is configured with sensible defaults for connection management. func provideDBPool(ctx context.Context, cfg *config.Config) (*pgxpool.Pool, func(), error) { if err := db.Migrate(cfg.PostgresURL()); err != nil { - return nil, nil, fmt.Errorf("failed to run migrations: %w", err) + return nil, nil, fmt.Errorf("running migrations: %w", err) } poolCfg, err := pgxpool.ParseConfig(cfg.PostgresConnectionString()) if err != nil { - return nil, nil, fmt.Errorf("failed to parse connection config: %w", err) + return nil, nil, fmt.Errorf("parsing connection config: %w", err) } poolCfg.MaxConns = 10 @@ -233,14 +265,14 @@ func provideDBPool(ctx context.Context, cfg *config.Config) (*pgxpool.Pool, func pool, err := pgxpool.NewWithConfig(ctx, poolCfg) if err != nil { - return nil, nil, fmt.Errorf("failed to create connection pool: %w", err) + return nil, nil, fmt.Errorf("creating connection pool: %w", err) } pingCtx, pingCancel := context.WithTimeout(ctx, 5*time.Second) defer pingCancel() if err := pool.Ping(pingCtx); err != nil { pool.Close() - return nil, nil, fmt.Errorf("failed to ping database: %w", err) + return nil, nil, fmt.Errorf("pinging database: %w", err) } cleanup := func() { @@ -256,7 +288,7 @@ func provideRAGComponents(ctx context.Context, g *genkit.Genkit, postgres *postg cfg := rag.NewDocStoreConfig(embedder) docStore, retriever, err := postgresql.DefineRetriever(ctx, g, postgres, cfg) if err != nil { - return nil, nil, fmt.Errorf("failed to define retriever: %w", err) + return nil, nil, fmt.Errorf("defining retriever: %w", err) } return docStore, retriever, nil @@ -272,96 +304,65 @@ func providePathValidator() (*security.Path, error) { return security.NewPath([]string{"."}) } -// provideTools registers all tools at construction time. -// Tools are registered once here, not lazily in CreateAgent. -func provideTools(g *genkit.Genkit, pathValidator *security.Path, retriever ai.Retriever, docStore *postgresql.DocStore, cfg *config.Config) ([]ai.Tool, error) { +// provideTools creates toolsets, registers them with Genkit, and stores both +// the concrete toolsets and the Genkit-wrapped references in a. +func provideTools(a *App) error { logger := slog.Default() + cfg := a.Config var allTools []ai.Tool - ft, err := tools.NewFileTools(pathValidator, logger) + ft, err := tools.NewFile(a.PathValidator, logger) if err != nil { - return nil, fmt.Errorf("creating file tools: %w", err) + return fmt.Errorf("creating file tools: %w", err) } - fileTools, err := tools.RegisterFileTools(g, ft) + a.File = ft + fileTools, err := tools.RegisterFile(a.Genkit, ft) if err != nil { - return nil, fmt.Errorf("registering file tools: %w", err) + return fmt.Errorf("registering file tools: %w", err) } allTools = append(allTools, fileTools...) cmdValidator := security.NewCommand() envValidator := security.NewEnv() - st, err := tools.NewSystemTools(cmdValidator, envValidator, logger) + st, err := tools.NewSystem(cmdValidator, envValidator, logger) if err != nil { - return nil, fmt.Errorf("creating system tools: %w", err) + return fmt.Errorf("creating system tools: %w", err) } - systemTools, err := tools.RegisterSystemTools(g, st) + a.System = st + systemTools, err := tools.RegisterSystem(a.Genkit, st) if err != nil { - return nil, fmt.Errorf("registering system tools: %w", err) + return fmt.Errorf("registering system tools: %w", err) } allTools = append(allTools, systemTools...) - nt, err := tools.NewNetworkTools(tools.NetworkConfig{ + nt, err := tools.NewNetwork(tools.NetConfig{ SearchBaseURL: cfg.SearXNG.BaseURL, FetchParallelism: cfg.WebScraper.Parallelism, FetchDelay: time.Duration(cfg.WebScraper.DelayMs) * time.Millisecond, FetchTimeout: time.Duration(cfg.WebScraper.TimeoutMs) * time.Millisecond, }, logger) if err != nil { - return nil, fmt.Errorf("creating network tools: %w", err) + return fmt.Errorf("creating network tools: %w", err) } - networkTools, err := tools.RegisterNetworkTools(g, nt) + a.Network = nt + networkTools, err := tools.RegisterNetwork(a.Genkit, nt) if err != nil { - return nil, fmt.Errorf("registering network tools: %w", err) + return fmt.Errorf("registering network tools: %w", err) } allTools = append(allTools, networkTools...) - kt, err := tools.NewKnowledgeTools(retriever, docStore, logger) + kt, err := tools.NewKnowledge(a.Retriever, a.DocStore, logger) if err != nil { - return nil, fmt.Errorf("creating knowledge tools: %w", err) + return fmt.Errorf("creating knowledge tools: %w", err) } - knowledgeTools, err := tools.RegisterKnowledgeTools(g, kt) + a.Knowledge = kt + knowledgeTools, err := tools.RegisterKnowledge(a.Genkit, kt) if err != nil { - return nil, fmt.Errorf("registering knowledge tools: %w", err) + return fmt.Errorf("registering knowledge tools: %w", err) } allTools = append(allTools, knowledgeTools...) - slog.Info("tools registered at construction", "count", len(allTools)) - return allTools, nil -} -// newApp constructs an App instance with all dependencies. -// All dependencies are injected by InitializeApp. -// Tools are pre-registered by provideTools. -// NOTE: This constructor has no side effects. Background tasks are started by the caller. -func newApp( - ctx context.Context, - cfg *config.Config, - g *genkit.Genkit, - embedder ai.Embedder, - pool *pgxpool.Pool, - docStore *postgresql.DocStore, - retriever ai.Retriever, - sessionStore *session.Store, - pathValidator *security.Path, registeredTools []ai.Tool, -) *App { - - appCtx, cancel := context.WithCancel(ctx) - - eg, _ := errgroup.WithContext(appCtx) - - app := &App{ - Config: cfg, - ctx: appCtx, - cancel: cancel, - eg: eg, - Genkit: g, - Embedder: embedder, - DBPool: pool, - DocStore: docStore, - Retriever: retriever, - SessionStore: sessionStore, - PathValidator: pathValidator, - Tools: registeredTools, - } - - return app + a.Tools = allTools + slog.Info("tools registered at construction", "count", len(allTools)) + return nil } diff --git a/internal/agent/chat/chat.go b/internal/chat/chat.go similarity index 69% rename from internal/agent/chat/chat.go rename to internal/chat/chat.go index de4163d..680b3bc 100644 --- a/internal/agent/chat/chat.go +++ b/internal/chat/chat.go @@ -4,17 +4,15 @@ import ( "context" "errors" "fmt" + "log/slog" "strings" "time" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/postgresql" "github.com/google/uuid" "golang.org/x/time/rate" - "github.com/koopa0/koopa/internal/log" - "github.com/koopa0/koopa/internal/rag" "github.com/koopa0/koopa/internal/session" ) @@ -31,8 +29,17 @@ const ( // NOTE: The LLM model is configured in the Dotprompt file, not via Config. KoopaPromptName = "koopa" - // FallbackResponseMessage is the message returned when the model produces an empty response. - FallbackResponseMessage = "I apologize, but I couldn't generate a response. Please try rephrasing your question." + // fallbackResponseMessage is the message returned when the model produces an empty response. + fallbackResponseMessage = "I apologize, but I couldn't generate a response. Please try rephrasing your question." +) + +// Sentinel errors for agent operations. +var ( + // ErrInvalidSession indicates the session ID is invalid or malformed. + ErrInvalidSession = errors.New("invalid session") + + // ErrExecutionFailed indicates agent execution failed. + ErrExecutionFailed = errors.New("execution failed") ) // Response represents the complete result of an agent execution. @@ -49,15 +56,13 @@ type StreamCallback func(ctx context.Context, chunk *ai.ModelResponseChunk) erro // Config contains all required parameters for Chat agent. type Config struct { Genkit *genkit.Genkit - Retriever ai.Retriever // Genkit Retriever for RAG context SessionStore *session.Store - Logger log.Logger + Logger *slog.Logger Tools []ai.Tool // Pre-registered tools from RegisterXxxTools() // Configuration values ModelName string // Provider-qualified model name (e.g., "googleai/gemini-2.5-flash", "ollama/llama3.3") MaxTurns int // Maximum agentic loop turns - RAGTopK int // Number of RAG documents to retrieve Language string // Response language preference // Resilience configuration @@ -74,9 +79,6 @@ func (cfg Config) validate() error { if cfg.Genkit == nil { return errors.New("genkit instance is required") } - if cfg.Retriever == nil { - return errors.New("retriever is required") - } if cfg.SessionStore == nil { return errors.New("session store is required") } @@ -89,20 +91,19 @@ func (cfg Config) validate() error { return nil } -// Chat is Koopa's main conversational agent. +// Agent is Koopa's main conversational agent. // It provides LLM-powered conversations with tool calling and knowledge base integration. // -// Chat is stateless and uses dependency injection. +// Agent is stateless and uses dependency injection. // Required parameters are provided via Config struct. // // All configuration values are captured immutably at construction time // to ensure thread-safe concurrent access. -type Chat struct { +type Agent struct { // Immutable configuration (captured at construction) modelName string // Provider-qualified model name (overrides Dotprompt model) languagePrompt string // Resolved language for prompt template maxTurns int - ragTopK int // Resilience (captured at construction) retryConfig RetryConfig @@ -114,32 +115,32 @@ type Chat struct { // Dependencies (read-only after construction) g *genkit.Genkit - retriever ai.Retriever // Genkit Retriever for RAG context sessions *session.Store - logger log.Logger + logger *slog.Logger tools []ai.Tool // Pre-registered tools (passed in via Config) toolRefs []ai.ToolRef // Cached at construction (ai.Tool implements ai.ToolRef) toolNames string // Cached as comma-separated for logging prompt ai.Prompt // Cached Dotprompt instance (model configured in prompt file) } -// New creates a new Chat agent with required configuration. +// New creates a new Agent with required configuration. +// +// RAG context is provided by knowledge tools (search_documents, search_history, +// search_system_knowledge) which the LLM calls when it determines context is needed. // // NOTE: The LLM model is configured in prompts/koopa.prompt, not via Config. // // Example: // -// chat, err := chat.New(chat.Config{ +// agent, err := chat.New(chat.Config{ // Genkit: g, -// Retriever: retriever, // Genkit Retriever from postgresql.DefineRetriever // SessionStore: sessionStore, // Logger: logger, // Tools: tools, // Pre-registered via RegisterXxxTools() // MaxTurns: cfg.MaxTurns, -// RAGTopK: cfg.RAGTopK, // Language: cfg.Language, // }) -func New(cfg Config) (*Chat, error) { +func New(cfg Config) (*Agent, error) { if err := cfg.validate(); err != nil { return nil, err } @@ -187,12 +188,11 @@ func New(cfg Config) (*Chat, error) { names[i] = t.Name() } - c := &Chat{ + a := &Agent{ // Immutable configuration modelName: cfg.ModelName, languagePrompt: languagePrompt, maxTurns: maxTurns, - ragTopK: cfg.RAGTopK, // Resilience retryConfig: retryConfig, @@ -204,7 +204,6 @@ func New(cfg Config) (*Chat, error) { // Dependencies g: cfg.Genkit, - retriever: cfg.Retriever, sessions: cfg.SessionStore, logger: cfg.Logger, tools: cfg.Tools, // Already registered with Genkit @@ -214,44 +213,44 @@ func New(cfg Config) (*Chat, error) { // Load Dotprompt (koopa.prompt) - REQUIRED // NOTE: Model is configured in the prompt file, not via Config - c.prompt = genkit.LookupPrompt(c.g, KoopaPromptName) - if c.prompt == nil { + a.prompt = genkit.LookupPrompt(a.g, KoopaPromptName) + if a.prompt == nil { return nil, fmt.Errorf("dotprompt '%s' not found: ensure prompts directory is configured correctly", KoopaPromptName) } - c.logger.Debug("loaded dotprompt successfully", "prompt_name", KoopaPromptName) + a.logger.Debug("loaded dotprompt successfully", "prompt_name", KoopaPromptName) - c.logger.Info("chat agent initialized", - "totalTools", len(c.tools), - "maxTurns", c.maxTurns, + a.logger.Info("chat agent initialized", + "totalTools", len(a.tools), + "maxTurns", a.maxTurns, ) - return c, nil + return a, nil } // Execute runs the chat agent with the given input (non-streaming). // This is a convenience wrapper around ExecuteStream with nil callback. -func (c *Chat) Execute(ctx context.Context, sessionID uuid.UUID, input string) (*Response, error) { - return c.ExecuteStream(ctx, sessionID, input, nil) +func (a *Agent) Execute(ctx context.Context, sessionID uuid.UUID, input string) (*Response, error) { + return a.ExecuteStream(ctx, sessionID, input, nil) } // ExecuteStream runs the chat agent with optional streaming output. // If callback is non-nil, it is called for each chunk of the response as it's generated. // If callback is nil, the response is generated without streaming (equivalent to Execute). // The final response is always returned after generation completes. -func (c *Chat) ExecuteStream(ctx context.Context, sessionID uuid.UUID, input string, callback StreamCallback) (*Response, error) { +func (a *Agent) ExecuteStream(ctx context.Context, sessionID uuid.UUID, input string, callback StreamCallback) (*Response, error) { streaming := callback != nil - c.logger.Debug("executing chat agent", + a.logger.Debug("executing chat agent", "session_id", sessionID, "streaming", streaming) // Load session history - history, err := c.sessions.History(ctx, sessionID) + historyMessages, err := a.sessions.History(ctx, sessionID) if err != nil { - return nil, fmt.Errorf("failed to get history: %w", err) + return nil, fmt.Errorf("getting history: %w", err) } // Generate response using unified core logic - resp, err := c.generateResponse(ctx, input, history.Messages(), callback) + resp, err := a.generateResponse(ctx, input, historyMessages, callback) if err != nil { return nil, err } @@ -261,22 +260,18 @@ func (c *Chat) ExecuteStream(ctx context.Context, sessionID uuid.UUID, input str // Only apply fallback when truly empty (no text AND no tool requests) // When LLM returns empty text but has tool requests, this is valid agentic behavior if strings.TrimSpace(responseText) == "" && len(resp.ToolRequests()) == 0 { - c.logger.Warn("model returned empty response with no tool requests", + a.logger.Warn("model returned empty response with no tool requests", "session_id", sessionID) - responseText = FallbackResponseMessage + responseText = fallbackResponseMessage } - // Update history with user input and response - history.Add(input, responseText) - - // Save updated history to session store using AppendMessages (preferred) + // Save new messages to session store newMessages := []*ai.Message{ ai.NewUserMessage(ai.NewTextPart(input)), ai.NewModelMessage(ai.NewTextPart(responseText)), } - if err := c.sessions.AppendMessages(ctx, sessionID, newMessages); err != nil { - c.logger.Error("failed to append messages to history", "error", err) - // Don't fail the request, just log the error + if err := a.sessions.AppendMessages(ctx, sessionID, newMessages); err != nil { + a.logger.Error("appending messages to history", "error", err) // best-effort: don't fail the request } // Return formatted response @@ -288,7 +283,7 @@ func (c *Chat) ExecuteStream(ctx context.Context, sessionID uuid.UUID, input str // generateResponse is the unified response generation logic for both streaming and non-streaming modes. // If callback is non-nil, streaming is enabled; otherwise, standard generation is used. -func (c *Chat) generateResponse(ctx context.Context, input string, historyMessages []*ai.Message, callback StreamCallback) (*ai.ModelResponse, error) { +func (a *Agent) generateResponse(ctx context.Context, input string, historyMessages []*ai.Message, callback StreamCallback) (*ai.ModelResponse, error) { // Build messages: deep copy history and append current user input // CRITICAL: Deep copy is required to prevent DATA RACE in Genkit's renderMessages() // Genkit modifies msg.Content in-place, so concurrent executions sharing the same @@ -297,33 +292,25 @@ func (c *Chat) generateResponse(ctx context.Context, input string, historyMessag // Apply token budget before adding new message // This ensures we don't exceed context window limits - messages = c.truncateHistory(messages, c.tokenBudget.MaxHistoryTokens) + messages = a.truncateHistory(messages, a.tokenBudget.MaxHistoryTokens) messages = append(messages, ai.NewUserMessage(ai.NewTextPart(input))) - // Retrieve relevant documents for RAG context (graceful fallback on error) - ragDocs := c.retrieveRAGContext(ctx, input) - // Build execute options (using cached toolRefs and languagePrompt) opts := []ai.PromptExecuteOption{ ai.WithInput(map[string]any{ - "language": c.languagePrompt, + "language": a.languagePrompt, }), ai.WithMessagesFn(func(_ context.Context, _ any) ([]*ai.Message, error) { return messages, nil }), - ai.WithTools(c.toolRefs...), - ai.WithMaxTurns(c.maxTurns), + ai.WithTools(a.toolRefs...), + ai.WithMaxTurns(a.maxTurns), } // Override model from Dotprompt if configured (supports multi-provider) - if c.modelName != "" { - opts = append(opts, ai.WithModelName(c.modelName)) - } - - // Add RAG documents if available - if len(ragDocs) > 0 { - opts = append(opts, ai.WithDocs(ragDocs...)) + if a.modelName != "" { + opts = append(opts, ai.WithModelName(a.modelName)) } // Add streaming callback if provided @@ -332,84 +319,31 @@ func (c *Chat) generateResponse(ctx context.Context, input string, historyMessag } // Diagnostic logging (using cached toolNames - zero allocation) - c.logger.Debug("executing prompt", - "toolCount", len(c.tools), - "tools", c.toolNames, - "maxTurns", c.maxTurns, + a.logger.Debug("executing prompt", + "toolCount", len(a.tools), + "tools", a.toolNames, + "maxTurns", a.maxTurns, "queryLength", len(input), ) // Check circuit breaker before attempting request - if err := c.circuitBreaker.Allow(); err != nil { - c.logger.Warn("circuit breaker is open, rejecting request", - "state", c.circuitBreaker.State().String()) + if err := a.circuitBreaker.Allow(); err != nil { + a.logger.Warn("circuit breaker is open, rejecting request", + "state", a.circuitBreaker.State().String()) return nil, fmt.Errorf("service unavailable: %w", err) } // Execute prompt with retry mechanism - resp, err := c.executeWithRetry(ctx, opts) + resp, err := a.executeWithRetry(ctx, opts) if err != nil { - c.circuitBreaker.Failure() + a.circuitBreaker.Failure() return nil, err } - c.circuitBreaker.Success() + a.circuitBreaker.Success() return resp, nil } -// ragRetrievalTimeout is the maximum time allowed for RAG document retrieval. -// This prevents slow queries from blocking the entire chat request. -const ragRetrievalTimeout = 5 * time.Second - -// retrieveRAGContext retrieves relevant documents from the knowledge base. -// Returns empty slice on error (graceful degradation). -func (c *Chat) retrieveRAGContext(ctx context.Context, query string) []*ai.Document { - // Skip RAG if topK is not configured or zero - if c.ragTopK <= 0 { - return nil - } - - // Add dedicated timeout for RAG retrieval to prevent slow queries - // from blocking the entire chat request - ragCtx, cancel := context.WithTimeout(ctx, ragRetrievalTimeout) - defer cancel() - - // Build retriever request with source_type filter for files (documents) - req := &ai.RetrieverRequest{ - Query: ai.DocumentFromText(query, nil), - Options: &postgresql.RetrieverOptions{ - Filter: "source_type = '" + rag.SourceTypeFile + "'", - K: c.ragTopK, - }, - } - - // Retrieve documents with timeout - resp, err := c.retriever.Retrieve(ragCtx, req) - if err != nil { - // Use Debug for expected errors (timeout, cancellation) - // Use Warn for unexpected errors (DB issues, etc.) that ops should know about - if ctx.Err() != nil || ragCtx.Err() != nil { - c.logger.Debug("RAG retrieval canceled or timed out (continuing without context)", - "error", err, - "timeout", ragRetrievalTimeout, - "query_length", len(query)) - } else { - c.logger.Warn("RAG retrieval failed (continuing without context)", - "error", err, - "query_length", len(query)) - } - return nil - } - - if len(resp.Documents) > 0 { - c.logger.Debug("retrieved RAG context", - "document_count", len(resp.Documents), - "query_length", len(query)) - } - - return resp.Documents -} - // deepCopyMessages creates independent copies of Message and Part structs. // // WORKAROUND: Genkit's renderMessages() modifies msg.Content in-place, @@ -420,7 +354,7 @@ func (c *Chat) retrieveRAGContext(ctx context.Context, query string) []*ai.Docum // // To remove this workaround: // 1. Upgrade Genkit: go get -u github.com/firebase/genkit/go@latest -// 2. Run: go test -race ./internal/agent/chat/... +// 2. Run: go test -race ./internal/chat/... // 3. If race detector passes, remove deepCopyMessages() calls // 4. If race still fails, update version in this comment func deepCopyMessages(msgs []*ai.Message) []*ai.Message { @@ -492,3 +426,56 @@ func shallowCopyMap(m map[string]any) map[string]any { } return cp } + +// Title generation constants. +const ( + titleMaxLength = 50 + titleGenerationTimeout = 5 * time.Second + titleInputMaxRunes = 500 +) + +const titlePrompt = `Generate a concise title (max 50 characters) for a chat session based on this first message. +The title should capture the main topic or intent. +Return ONLY the title text, no quotes, no explanations, no punctuation at the end. + +Message: %s + +Title:` + +// GenerateTitle generates a concise session title from the user's first message. +// Uses AI generation with fallback to simple truncation. +// Returns empty string on failure (best-effort). +func (a *Agent) GenerateTitle(ctx context.Context, userMessage string) string { + ctx, cancel := context.WithTimeout(ctx, titleGenerationTimeout) + defer cancel() + + inputRunes := []rune(userMessage) + if len(inputRunes) > titleInputMaxRunes { + userMessage = string(inputRunes[:titleInputMaxRunes]) + "..." + } + + opts := []ai.GenerateOption{ + ai.WithPrompt(titlePrompt, userMessage), + } + if a.modelName != "" { + opts = append(opts, ai.WithModelName(a.modelName)) + } + + response, err := genkit.Generate(ctx, a.g, opts...) + if err != nil { + a.logger.Debug("AI title generation failed", "error", err) + return "" + } + + title := strings.TrimSpace(response.Text()) + if title == "" { + return "" + } + + titleRunes := []rune(title) + if len(titleRunes) > titleMaxLength { + title = string(titleRunes[:titleMaxLength-3]) + "..." + } + + return title +} diff --git a/internal/chat/chat_test.go b/internal/chat/chat_test.go new file mode 100644 index 0000000..f0b4c70 --- /dev/null +++ b/internal/chat/chat_test.go @@ -0,0 +1,310 @@ +package chat + +import ( + "log/slog" + "strings" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + + "github.com/koopa0/koopa/internal/session" +) + +// TestConfig_validate tests that each validation check in Config.validate() +// fires independently. Each case provides enough deps to pass prior checks. +func TestConfig_validate(t *testing.T) { + t.Parallel() + + // Minimal non-nil stubs — validate() only checks nil, never dereferences. + stubG := new(genkit.Genkit) + stubS := new(session.Store) + stubL := slog.New(slog.DiscardHandler) + + tests := []struct { + name string + cfg Config + errContains string + }{ + { + name: "nil genkit", + cfg: Config{}, + errContains: "genkit instance is required", + }, + { + name: "nil session store", + cfg: Config{ + Genkit: stubG, + }, + errContains: "session store is required", + }, + { + name: "nil logger", + cfg: Config{ + Genkit: stubG, + SessionStore: stubS, + }, + errContains: "logger is required", + }, + { + name: "empty tools", + cfg: Config{ + Genkit: stubG, + SessionStore: stubS, + Logger: stubL, + Tools: []ai.Tool{}, + }, + errContains: "at least one tool is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := tt.cfg.validate() + if err == nil { + t.Fatal("validate() expected error, got nil") + } + if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("validate() error = %q, want to contain %q", err.Error(), tt.errContains) + } + }) + } +} + +func TestDeepCopyMessages_NilInput(t *testing.T) { + t.Parallel() + got := deepCopyMessages(nil) + if got != nil { + t.Errorf("deepCopyMessages(nil) = %v, want nil", got) + } +} + +func TestDeepCopyMessages_EmptySlice(t *testing.T) { + t.Parallel() + got := deepCopyMessages([]*ai.Message{}) + if got == nil { + t.Fatal("deepCopyMessages(empty) = nil, want non-nil empty slice") + } + if len(got) != 0 { + t.Errorf("deepCopyMessages(empty) len = %d, want 0", len(got)) + } +} + +func TestDeepCopyMessages_MutateOriginalText(t *testing.T) { + t.Parallel() + + original := []*ai.Message{ + ai.NewUserMessage(ai.NewTextPart("hello world")), + } + + copied := deepCopyMessages(original) + + // Mutate the original message's content slice + original[0].Content[0].Text = "MUTATED" + + if copied[0].Content[0].Text != "hello world" { + t.Errorf("deepCopyMessages() copy was affected by original mutation: got %q, want %q", + copied[0].Content[0].Text, "hello world") + } +} + +func TestDeepCopyMessages_MutateOriginalContentSlice(t *testing.T) { + t.Parallel() + + original := []*ai.Message{ + ai.NewUserMessage(ai.NewTextPart("first"), ai.NewTextPart("second")), + } + + copied := deepCopyMessages(original) + + // Append to original's content slice — should not affect copy + original[0].Content = append(original[0].Content, ai.NewTextPart("third")) + + if len(copied[0].Content) != 2 { + t.Errorf("deepCopyMessages() copy content len = %d, want 2", len(copied[0].Content)) + } +} + +func TestDeepCopyMessages_PreservesRole(t *testing.T) { + t.Parallel() + + original := []*ai.Message{ + ai.NewUserMessage(ai.NewTextPart("q")), + ai.NewModelMessage(ai.NewTextPart("a")), + } + + copied := deepCopyMessages(original) + + if copied[0].Role != ai.RoleUser { + t.Errorf("deepCopyMessages()[0].Role = %q, want %q", copied[0].Role, ai.RoleUser) + } + if copied[1].Role != ai.RoleModel { + t.Errorf("deepCopyMessages()[1].Role = %q, want %q", copied[1].Role, ai.RoleModel) + } +} + +func TestDeepCopyMessages_Metadata(t *testing.T) { + t.Parallel() + + original := []*ai.Message{{ + Role: ai.RoleUser, + Content: []*ai.Part{ai.NewTextPart("test")}, + Metadata: map[string]any{"key": "value"}, + }} + + copied := deepCopyMessages(original) + + // Mutate original metadata + original[0].Metadata["key"] = "MUTATED" + + if copied[0].Metadata["key"] != "value" { + t.Errorf("deepCopyMessages() metadata was affected by mutation: got %q, want %q", + copied[0].Metadata["key"], "value") + } +} + +func TestDeepCopyPart_NilInput(t *testing.T) { + t.Parallel() + got := deepCopyPart(nil) + if got != nil { + t.Errorf("deepCopyPart(nil) = %v, want nil", got) + } +} + +func TestDeepCopyPart_TextPart(t *testing.T) { + t.Parallel() + + original := ai.NewTextPart("hello") + copied := deepCopyPart(original) + + original.Text = "MUTATED" + + if copied.Text != "hello" { + t.Errorf("deepCopyPart() text affected by mutation: got %q, want %q", copied.Text, "hello") + } +} + +func TestDeepCopyPart_ToolRequest(t *testing.T) { + t.Parallel() + + original := &ai.Part{ + Kind: ai.PartToolRequest, + ToolRequest: &ai.ToolRequest{ + Name: "read_file", + Input: map[string]any{"path": "/tmp/test"}, + }, + } + + copied := deepCopyPart(original) + + // Mutate original ToolRequest name + original.ToolRequest.Name = "MUTATED" + + if copied.ToolRequest.Name != "read_file" { + t.Errorf("deepCopyPart() ToolRequest.Name affected by mutation: got %q, want %q", + copied.ToolRequest.Name, "read_file") + } +} + +func TestDeepCopyPart_ToolResponse(t *testing.T) { + t.Parallel() + + original := &ai.Part{ + Kind: ai.PartToolResponse, + ToolResponse: &ai.ToolResponse{ + Name: "read_file", + Output: "file contents", + }, + } + + copied := deepCopyPart(original) + + original.ToolResponse.Name = "MUTATED" + + if copied.ToolResponse.Name != "read_file" { + t.Errorf("deepCopyPart() ToolResponse.Name affected by mutation: got %q, want %q", + copied.ToolResponse.Name, "read_file") + } +} + +func TestDeepCopyPart_Resource(t *testing.T) { + t.Parallel() + + original := &ai.Part{ + Kind: ai.PartMedia, + Resource: &ai.ResourcePart{Uri: "https://example.com/image.png"}, + } + + copied := deepCopyPart(original) + + original.Resource.Uri = "MUTATED" + + if copied.Resource.Uri != "https://example.com/image.png" { + t.Errorf("deepCopyPart() Resource.Uri affected by mutation: got %q, want %q", + copied.Resource.Uri, "https://example.com/image.png") + } +} + +func TestDeepCopyPart_PartMetadata(t *testing.T) { + t.Parallel() + + original := &ai.Part{ + Kind: ai.PartText, + Text: "test", + Custom: map[string]any{"c": "custom"}, + Metadata: map[string]any{"m": "meta"}, + } + + copied := deepCopyPart(original) + + original.Custom["c"] = "MUTATED" + original.Metadata["m"] = "MUTATED" + + if copied.Custom["c"] != "custom" { + t.Errorf("deepCopyPart() Custom map affected: got %q, want %q", copied.Custom["c"], "custom") + } + if copied.Metadata["m"] != "meta" { + t.Errorf("deepCopyPart() Metadata map affected: got %q, want %q", copied.Metadata["m"], "meta") + } +} + +func TestShallowCopyMap_NilInput(t *testing.T) { + t.Parallel() + got := shallowCopyMap(nil) + if got != nil { + t.Errorf("shallowCopyMap(nil) = %v, want nil", got) + } +} + +func TestShallowCopyMap_IndependentKeys(t *testing.T) { + t.Parallel() + + original := map[string]any{"a": "1", "b": "2"} + copied := shallowCopyMap(original) + + // Add new key to original + original["c"] = "3" + + if _, ok := copied["c"]; ok { + t.Error("shallowCopyMap() new key in original appeared in copy") + } + if len(copied) != 2 { + t.Errorf("shallowCopyMap() copy len = %d, want 2", len(copied)) + } +} + +func TestShallowCopyMap_MutateValue(t *testing.T) { + t.Parallel() + + original := map[string]any{"key": "value"} + copied := shallowCopyMap(original) + + // Overwrite original value + original["key"] = "MUTATED" + + if copied["key"] != "value" { + t.Errorf("shallowCopyMap() value affected by mutation: got %q, want %q", + copied["key"], "value") + } +} diff --git a/internal/agent/chat/circuit.go b/internal/chat/circuit.go similarity index 100% rename from internal/agent/chat/circuit.go rename to internal/chat/circuit.go diff --git a/internal/agent/chat/circuit_test.go b/internal/chat/circuit_test.go similarity index 94% rename from internal/agent/chat/circuit_test.go rename to internal/chat/circuit_test.go index c9a470f..4fad3b7 100644 --- a/internal/agent/chat/circuit_test.go +++ b/internal/chat/circuit_test.go @@ -243,21 +243,21 @@ func TestCircuitState_String(t *testing.T) { t.Parallel() tests := []struct { - state CircuitState - expected string + state CircuitState + want string }{ - {CircuitClosed, "closed"}, - {CircuitOpen, "open"}, - {CircuitHalfOpen, "half-open"}, - {CircuitState(99), "unknown"}, + {state: CircuitClosed, want: "closed"}, + {state: CircuitOpen, want: "open"}, + {state: CircuitHalfOpen, want: "half-open"}, + {state: CircuitState(99), want: "unknown"}, } for _, tt := range tests { - t.Run(tt.expected, func(t *testing.T) { + t.Run(tt.want, func(t *testing.T) { t.Parallel() - if got := tt.state.String(); got != tt.expected { - t.Errorf("String() = %q, want %q", got, tt.expected) + if got := tt.state.String(); got != tt.want { + t.Errorf("String() = %q, want %q", got, tt.want) } }) } diff --git a/internal/agent/chat/doc.go b/internal/chat/doc.go similarity index 74% rename from internal/agent/chat/doc.go rename to internal/chat/doc.go index e0c5187..0f00798 100644 --- a/internal/agent/chat/doc.go +++ b/internal/chat/doc.go @@ -1,17 +1,17 @@ // Package chat implements Koopa's main conversational agent. // -// Chat is a stateless, LLM-powered agent that provides conversational interactions +// Agent is a stateless, LLM-powered agent that provides conversational interactions // with tool calling and knowledge base integration. It uses the Google Genkit framework // for LLM inference and tool orchestration. // // # Architecture // -// The Chat agent follows a stateless design pattern with dependency injection: +// The Agent follows a stateless design pattern with dependency injection: // // InvocationContext (input) // | // v -// Chat.ExecuteStream() or Chat.Execute() +// Agent.ExecuteStream() or Agent.Execute() // | // +-- Load session history from SessionStore // | @@ -30,27 +30,33 @@ // // # Configuration // -// Chat requires configuration via the Config struct at construction time: +// Agent requires configuration via the Config struct at construction time: // // type Config struct { // Genkit *genkit.Genkit -// Retriever ai.Retriever // SessionStore *session.Store -// Logger log.Logger +// Logger *slog.Logger // Tools []ai.Tool // // // Configuration values // ModelName string // e.g., "googleai/gemini-2.5-flash" // MaxTurns int // Maximum agentic loop turns -// RAGTopK int // Number of RAG documents to retrieve // Language string // Response language preference +// +// // Resilience configuration +// RetryConfig RetryConfig +// CircuitBreakerConfig CircuitBreakerConfig +// RateLimiter *rate.Limiter +// +// // Token management +// TokenBudget TokenBudget // } // // Required fields are validated during construction. // // # Streaming Support // -// Chat supports both streaming and non-streaming execution modes: +// Agent supports both streaming and non-streaming execution modes: // // - Execute(): Non-streaming, returns complete response // - ExecuteStream(): Streaming with optional callback for real-time output @@ -66,18 +72,14 @@ // // The package provides a Genkit Flow for HTTP and observability: // -// - InitFlow(): Initializes singleton streaming Flow (must be called once at startup) -// - GetFlow(): Returns initialized Flow (panics if InitFlow not called) +// - NewFlow(): Returns singleton streaming Flow (idempotent, safe to call multiple times) // - Flow supports both Run() and Stream() methods // - Stream() enables Server-Sent Events (SSE) for real-time responses // // Example Flow usage: // -// // Initialize Flow once during application startup -// chatFlow, err := chat.InitFlow(g, chatAgent) -// if err != nil { -// return err -// } +// // Create Flow (idempotent — first call initializes, subsequent calls return cached) +// chatFlow := chat.NewFlow(g, chatAgent) // // // Non-streaming // output, err := chatFlow.Run(ctx, chat.Input{Query: "Hello", SessionID: "..."}) @@ -113,16 +115,14 @@ // // # Example Usage // -// // Create Chat agent with required configuration -// chatAgent, err := chat.New(chat.Config{ +// // Create agent with required configuration +// agent, err := chat.New(chat.Config{ // Genkit: g, -// Retriever: retriever, // SessionStore: sessionStore, // Logger: slog.Default(), // Tools: tools, // ModelName: "googleai/gemini-2.5-flash", // MaxTurns: 10, -// RAGTopK: 5, // Language: "auto", // }) // if err != nil { @@ -130,10 +130,10 @@ // } // // // Non-streaming execution -// resp, err := chatAgent.Execute(ctx, sessionID, "What is the weather?") +// resp, err := agent.Execute(ctx, sessionID, "What is the weather?") // // // Streaming execution with callback -// resp, err := chatAgent.ExecuteStream(ctx, sessionID, "What is the weather?", +// resp, err := agent.ExecuteStream(ctx, sessionID, "What is the weather?", // func(ctx context.Context, chunk *ai.ModelResponseChunk) error { // fmt.Print(chunk.Text()) // Real-time output // return nil @@ -143,14 +143,14 @@ // // The package uses sentinel errors for categorization: // -// - agent.ErrInvalidSession: Invalid session ID format -// - agent.ErrExecutionFailed: LLM or tool execution failed +// - ErrInvalidSession: Invalid session ID format +// - ErrExecutionFailed: LLM or tool execution failed // // Empty responses are handled with a fallback message to improve UX. // // # Testing // -// Chat is designed for testability: +// Agent is designed for testability: // // - Dependencies are concrete types with clear interfaces // - Stateless design eliminates test ordering issues @@ -158,6 +158,6 @@ // // # Thread Safety // -// Chat is safe for concurrent use. The underlying dependencies (SessionStore, +// Agent is safe for concurrent use. The underlying dependencies (SessionStore, // Genkit) must also be thread-safe. package chat diff --git a/internal/agent/chat/flow.go b/internal/chat/flow.go similarity index 67% rename from internal/agent/chat/flow.go rename to internal/chat/flow.go index 145c4c6..59423fe 100644 --- a/internal/agent/chat/flow.go +++ b/internal/chat/flow.go @@ -1,27 +1,23 @@ -// Package chat provides Flow definition for Chat Agent package chat import ( "context" "fmt" "sync" - "sync/atomic" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/genkit" "github.com/google/uuid" - - "github.com/koopa0/koopa/internal/agent" ) -// Input is the input for Chat Agent Flow +// Input defines the request payload for the chat agent flow. type Input struct { Query string `json:"query"` SessionID string `json:"sessionId"` // Required field: session ID } -// Output is the output for Chat Agent Flow +// Output defines the response payload from the chat agent flow. type Output struct { Response string `json:"response"` SessionID string `json:"sessionId"` @@ -40,39 +36,20 @@ const FlowName = "koopa/chat" // Exported for use in api package with genkit.Handler(). type Flow = core.Flow[Input, Output, StreamChunk] -// Package-level singleton for Flow to prevent Panic on re-registration. -// sync.Once ensures the Flow is defined only once, even in tests. +// Package-level singleton for Flow to prevent panic on re-registration. +// sync.Once ensures genkit.DefineStreamingFlow is called only once. var ( - flowOnce sync.Once - flow *Flow - flowInitDone atomic.Bool + flowOnce sync.Once + flow *Flow ) -// InitFlow initializes the Chat Flow singleton. -// Must be called exactly once during application startup. -// Returns error if called more than once. -// -// This explicit API prevents the dangerous pattern where parameters -// are silently ignored on subsequent calls (as in the old GetFlow). -func InitFlow(g *genkit.Genkit, chatAgent *Chat) (*Flow, error) { - var initialized bool +// NewFlow returns the Chat Flow singleton, initializing it on first call. +// Subsequent calls return the existing Flow (parameters are ignored). +// This is safe because genkit.DefineStreamingFlow panics on re-registration. +func NewFlow(g *genkit.Genkit, agent *Agent) *Flow { flowOnce.Do(func() { - flow = chatAgent.DefineFlow(g) - flowInitDone.Store(true) - initialized = true + flow = agent.DefineFlow(g) }) - if !initialized && flowInitDone.Load() { - return nil, fmt.Errorf("InitFlow called more than once") - } - return flow, nil -} - -// GetFlow returns the initialized Flow singleton. -// Panics if InitFlow was not called - this indicates a programming error. -func GetFlow() *Flow { - if !flowInitDone.Load() { - panic("GetFlow called before InitFlow") - } return flow } @@ -82,14 +59,13 @@ func GetFlow() *Flow { func ResetFlowForTesting() { flowOnce = sync.Once{} flow = nil - flowInitDone.Store(false) } // DefineFlow defines the Genkit Streaming Flow for Chat Agent. // Supports both streaming (via callback) and non-streaming modes. // -// IMPORTANT: Use GetFlow() instead of calling DefineFlow() directly. -// DefineFlow registers a global Flow; calling it twice causes Panic. +// IMPORTANT: Use NewFlow() instead of calling DefineFlow() directly. +// DefineFlow registers a global Flow; calling it twice causes panic. // // Each Agent has its own dedicated Flow, responsible for: // 1. Observability (Genkit DevUI tracing) @@ -103,15 +79,13 @@ func ResetFlowForTesting() { // - Errors are now properly returned using sentinel errors from agent package // - Genkit tracing will correctly show error spans // - HTTP handlers can use errors.Is() to determine error type and HTTP status -// -//nolint:gocognit // Genkit Flow requires orchestration logic in single function -func (c *Chat) DefineFlow(g *genkit.Genkit) *Flow { +func (a *Agent) DefineFlow(g *genkit.Genkit) *Flow { return genkit.DefineStreamingFlow(g, FlowName, func(ctx context.Context, input Input, streamCb func(context.Context, StreamChunk) error) (Output, error) { // Parse session ID from input sessionID, err := uuid.Parse(input.SessionID) if err != nil { - return Output{SessionID: input.SessionID}, fmt.Errorf("%w: %w", agent.ErrInvalidSession, err) + return Output{SessionID: input.SessionID}, fmt.Errorf("%w: %w", ErrInvalidSession, err) } // Create StreamCallback wrapper if streaming is enabled @@ -135,10 +109,10 @@ func (c *Chat) DefineFlow(g *genkit.Genkit) *Flow { } // Execute with streaming callback (or non-streaming if callback is nil) - resp, err := c.ExecuteStream(ctx, sessionID, input.Query, agentCallback) + resp, err := a.ExecuteStream(ctx, sessionID, input.Query, agentCallback) if err != nil { // Genkit will mark this span as failed, enabling proper observability - return Output{SessionID: input.SessionID}, fmt.Errorf("%w: %w", agent.ErrExecutionFailed, err) + return Output{SessionID: input.SessionID}, fmt.Errorf("%w: %w", ErrExecutionFailed, err) } return Output{ diff --git a/internal/chat/flow_test.go b/internal/chat/flow_test.go new file mode 100644 index 0000000..c18251b --- /dev/null +++ b/internal/chat/flow_test.go @@ -0,0 +1,52 @@ +package chat + +import ( + "errors" + "testing" +) + +// TestSentinelErrors_CanBeChecked tests that sentinel errors work correctly with errors.Is +func TestSentinelErrors_CanBeChecked(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + sentinel error + }{ + {name: "ErrInvalidSession", err: ErrInvalidSession, sentinel: ErrInvalidSession}, + {name: "ErrExecutionFailed", err: ErrExecutionFailed, sentinel: ErrExecutionFailed}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if !errors.Is(tt.err, tt.sentinel) { + t.Errorf("errors.Is(%v, %v) = false, want true", tt.err, tt.sentinel) + } + }) + } +} + +// TestWrappedErrors_PreserveSentinel tests that wrapped errors preserve sentinel checking +func TestWrappedErrors_PreserveSentinel(t *testing.T) { + t.Parallel() + + t.Run("wrapped invalid session error", func(t *testing.T) { + t.Parallel() + err := errors.New("original error") + wrapped := errors.Join(ErrInvalidSession, err) + if !errors.Is(wrapped, ErrInvalidSession) { + t.Errorf("errors.Is(wrapped, ErrInvalidSession) = false, want true") + } + }) + + t.Run("wrapped execution failed error", func(t *testing.T) { + t.Parallel() + err := errors.New("LLM timeout") + wrapped := errors.Join(ErrExecutionFailed, err) + if !errors.Is(wrapped, ErrExecutionFailed) { + t.Errorf("errors.Is(wrapped, ErrExecutionFailed) = false, want true") + } + }) +} diff --git a/internal/agent/chat/integration_rag_test.go b/internal/chat/integration_rag_test.go similarity index 71% rename from internal/agent/chat/integration_rag_test.go rename to internal/chat/integration_rag_test.go index 73582aa..c825f00 100644 --- a/internal/agent/chat/integration_rag_test.go +++ b/internal/chat/integration_rag_test.go @@ -13,10 +13,6 @@ import ( "github.com/koopa0/koopa/internal/rag" ) -// ============================================================================= -// Phase 0.2: RAG Integration Tests -// ============================================================================= - // TestChatAgent_RAGIntegration_EndToEnd verifies the complete RAG workflow: // 1. Index document into knowledge store // 2. Query triggers retrieval @@ -25,16 +21,9 @@ import ( // Per Proposal 030: topK > 0 with real retriever, verify documents returned // and integrated into prompt. func TestChatAgent_RAGIntegration_EndToEnd(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() - + framework := SetupTest(t) ctx := context.Background() - // Ensure RAG is enabled - if framework.Config.RAGTopK <= 0 { - t.Fatal("RAG must be enabled for this test (RAGTopK > 0)") - } - // STEP 1: Index test document using DocStore docID := uuid.New() testContent := "The secret password is KOOPA_TEST_123. This is a unique test string." @@ -46,8 +35,7 @@ func TestChatAgent_RAGIntegration_EndToEnd(t *testing.T) { t.Logf("Indexed document %s with test content", docID) // STEP 2: Query should trigger RAG retrieval - invCtx, sessionID := newInvocationContext(ctx, framework.SessionID) - resp, err := framework.Agent.ExecuteStream(invCtx, sessionID, + resp, err := framework.Agent.ExecuteStream(ctx, framework.SessionID, "What is the secret password from the test document?", nil, ) @@ -73,9 +61,7 @@ func TestChatAgent_RAGIntegration_EndToEnd(t *testing.T) { // returns documents when topK > 0. // Per Proposal 030: Verify documents returned and integrated into prompt. func TestRetrieveRAGContext_ActualRetrieval(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() - + framework := SetupTest(t) ctx := context.Background() // Index multiple documents for retrieval @@ -95,8 +81,7 @@ func TestRetrieveRAGContext_ActualRetrieval(t *testing.T) { t.Log("Indexed 2 test documents") // Query that should trigger retrieval - invCtx, sessionID := newInvocationContext(ctx, framework.SessionID) - resp, err := framework.Agent.ExecuteStream(invCtx, sessionID, + resp, err := framework.Agent.ExecuteStream(ctx, framework.SessionID, "Tell me about Go programming language", nil, ) @@ -119,62 +104,14 @@ func TestRetrieveRAGContext_ActualRetrieval(t *testing.T) { t.Logf("Response with RAG: %s", resp.FinalText) } -// TestRetrieveRAGContext_DisabledWhenTopKZero verifies that RAG is skipped -// when topK is 0. -func TestRetrieveRAGContext_DisabledWhenTopKZero(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() - - ctx := context.Background() - - // Override config to disable RAG - originalRAGTopK := framework.Config.RAGTopK - framework.Config.RAGTopK = 0 - defer func() { framework.Config.RAGTopK = originalRAGTopK }() - - // Index document (should NOT be retrieved) - docID := uuid.New() - framework.IndexDocument(t, "This content should be ignored because RAG is disabled.", map[string]any{ - "id": docID.String(), - "filename": "ignored.txt", - "source_type": rag.SourceTypeFile, - }) - - // Query - should NOT trigger retrieval - invCtx, sessionID := newInvocationContext(ctx, framework.SessionID) - resp, err := framework.Agent.ExecuteStream(invCtx, sessionID, - "What does the ignored document say?", - nil, - ) - - if err != nil { - t.Fatalf("ExecuteStream() unexpected error: %v", err) - } - if resp == nil { - t.Fatal("ExecuteStream() response is nil, want non-nil when error is nil") - } - if resp.FinalText == "" { - t.Error("ExecuteStream() response.FinalText is empty, want non-empty") - } - - // Response should NOT contain the ignored content - if strings.Contains(resp.FinalText, "should be ignored") { - t.Errorf("ExecuteStream() response = %q, should not contain %q (RAG disabled)", resp.FinalText, "should be ignored") - } - t.Logf("Response without RAG (topK=0): %s", resp.FinalText) -} - // TestRetrieveRAGContext_EmptyKnowledgeBase verifies graceful handling // when knowledge base has no matching documents. func TestRetrieveRAGContext_EmptyKnowledgeBase(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() - + framework := SetupTest(t) ctx := context.Background() // Query with empty knowledge base (no documents indexed) - invCtx, sessionID := newInvocationContext(ctx, framework.SessionID) - resp, err := framework.Agent.ExecuteStream(invCtx, sessionID, + resp, err := framework.Agent.ExecuteStream(ctx, framework.SessionID, "What is in the knowledge base?", nil, ) @@ -194,9 +131,7 @@ func TestRetrieveRAGContext_EmptyKnowledgeBase(t *testing.T) { // TestRetrieveRAGContext_MultipleRelevantDocuments verifies that when multiple // documents match the query, the RAG system retrieves and uses them. func TestRetrieveRAGContext_MultipleRelevantDocuments(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() - + framework := SetupTest(t) ctx := context.Background() // Index 5 related documents @@ -222,8 +157,7 @@ func TestRetrieveRAGContext_MultipleRelevantDocuments(t *testing.T) { t.Logf("Indexed %d related documents", len(topics)) // Query should retrieve multiple relevant documents - invCtx, sessionID := newInvocationContext(ctx, framework.SessionID) - resp, err := framework.Agent.ExecuteStream(invCtx, sessionID, + resp, err := framework.Agent.ExecuteStream(ctx, framework.SessionID, "Summarize the key features of Go programming language", nil, ) diff --git a/internal/agent/chat/integration_streaming_test.go b/internal/chat/integration_streaming_test.go similarity index 85% rename from internal/agent/chat/integration_streaming_test.go rename to internal/chat/integration_streaming_test.go index 8ec1b04..784c94c 100644 --- a/internal/agent/chat/integration_streaming_test.go +++ b/internal/chat/integration_streaming_test.go @@ -12,19 +12,15 @@ import ( "github.com/firebase/genkit/go/ai" ) -// ============================================================================= -// Phase 0.3: Streaming Integration Tests -// ============================================================================= - // TestChatAgent_StreamingCallbackError verifies that errors from streaming // callbacks are properly propagated and stop the stream. // Per Proposal 030: Callback returns error after N chunks, verify stream stops, // error propagated. func TestChatAgent_StreamingCallbackError(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() + framework := SetupTest(t) + ctx := context.Background() + sessionID := framework.SessionID - ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID) chunks := 0 maxChunks := 3 @@ -66,10 +62,10 @@ func TestChatAgent_StreamingCallbackError(t *testing.T) { // TestChatAgent_StreamingCallbackSuccess verifies that streaming works // correctly when callback always succeeds. func TestChatAgent_StreamingCallbackSuccess(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() + framework := SetupTest(t) + ctx := context.Background() + sessionID := framework.SessionID - ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID) chunks := 0 var receivedTexts []string @@ -105,9 +101,7 @@ func TestChatAgent_StreamingCallbackSuccess(t *testing.T) { // TestChatAgent_StreamingVsNonStreaming verifies that streaming and non-streaming // modes produce equivalent results. func TestChatAgent_StreamingVsNonStreaming(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() - + framework := SetupTest(t) ctx := context.Background() query := "What is 2+2? Answer with just the number." @@ -116,8 +110,7 @@ func TestChatAgent_StreamingVsNonStreaming(t *testing.T) { // This is a standard Go idiom (nil function = skip optional behavior). // Contract: When callback is nil, the method returns only after full completion. session1 := framework.CreateTestSession(t, "Non-streaming test") - invCtx1, sessionID1 := newInvocationContext(ctx, session1) - respNoStream, err := framework.Agent.ExecuteStream(invCtx1, sessionID1, + respNoStream, err := framework.Agent.ExecuteStream(ctx, session1, query, nil, // No callback = non-streaming mode (returns complete response) ) @@ -133,14 +126,13 @@ func TestChatAgent_StreamingVsNonStreaming(t *testing.T) { // Streaming execution session2 := framework.CreateTestSession(t, "Streaming test") - invCtx2, sessionID2 := newInvocationContext(ctx, session2) var streamedResponse string callback := func(ctx context.Context, chunk *ai.ModelResponseChunk) error { streamedResponse += chunk.Text() return nil } - respStream, err := framework.Agent.ExecuteStream(invCtx2, sessionID2, + respStream, err := framework.Agent.ExecuteStream(ctx, session2, query, callback, ) @@ -169,17 +161,14 @@ func TestChatAgent_StreamingVsNonStreaming(t *testing.T) { // TestChatAgent_StreamingContextCancellation verifies that canceling the // context stops streaming. func TestChatAgent_StreamingContextCancellation(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() + framework := SetupTest(t) // Cancel BEFORE starting stream to guarantee cancellation // This is deterministic - no race between stream completion and cancellation ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately - _, sessionID := newInvocationContext(context.Background(), framework.SessionID) - - resp, err := framework.Agent.ExecuteStream(ctx, sessionID, + resp, err := framework.Agent.ExecuteStream(ctx, framework.SessionID, "Write a very long story", nil, ) @@ -200,10 +189,10 @@ func TestChatAgent_StreamingContextCancellation(t *testing.T) { // TestChatAgent_StreamingEmptyChunks verifies handling of empty chunks // in streaming mode. func TestChatAgent_StreamingEmptyChunks(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() + framework := SetupTest(t) + ctx := context.Background() + sessionID := framework.SessionID - ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID) totalChunks := 0 emptyChunks := 0 diff --git a/internal/agent/chat/integration_test.go b/internal/chat/integration_test.go similarity index 80% rename from internal/agent/chat/integration_test.go rename to internal/chat/integration_test.go index e69e6e0..7d9906f 100644 --- a/internal/agent/chat/integration_test.go +++ b/internal/chat/integration_test.go @@ -16,15 +16,14 @@ import ( "github.com/firebase/genkit/go/ai" - "github.com/koopa0/koopa/internal/agent/chat" + "github.com/koopa0/koopa/internal/chat" ) // TestChatAgent_BasicExecution tests basic chat agent execution func TestChatAgent_BasicExecution(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() - - ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID) + framework := SetupTest(t) + ctx := context.Background() + sessionID := framework.SessionID t.Run("simple question", func(t *testing.T) { resp, err := framework.Agent.Execute(ctx, sessionID, "Hello, how are you?") @@ -42,10 +41,9 @@ func TestChatAgent_BasicExecution(t *testing.T) { // TestChatAgent_SessionPersistence tests conversation history persistence func TestChatAgent_SessionPersistence(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() - - ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID) + framework := SetupTest(t) + ctx := context.Background() + sessionID := framework.SessionID t.Run("first message creates history", func(t *testing.T) { resp, err := framework.Agent.Execute(ctx, sessionID, "My name is Koopa") @@ -77,10 +75,9 @@ func TestChatAgent_SessionPersistence(t *testing.T) { // TestChatAgent_ToolIntegration tests tool calling capability func TestChatAgent_ToolIntegration(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() - - ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID) + framework := SetupTest(t) + ctx := context.Background() + sessionID := framework.SessionID t.Run("can use file tools", func(t *testing.T) { // Create unique marker file to verify tool was actually invoked @@ -114,19 +111,24 @@ func TestChatAgent_ToolIntegration(t *testing.T) { // TestChatAgent_ErrorHandling tests error scenarios func TestChatAgent_ErrorHandling(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() + framework := SetupTest(t) t.Run("handles empty input gracefully", func(t *testing.T) { - ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID) + ctx := context.Background() + sessionID := framework.SessionID resp, err := framework.Agent.Execute(ctx, sessionID, "") - // Should handle empty input without crashing - // Either returns error or empty response - if err == nil { - if resp == nil { - t.Error("Execute() with empty input returned nil response and nil error") - } + // Agent should handle empty input without panicking. + // The LLM may return a valid response or an error — both are acceptable. + if err != nil { + t.Logf("Execute(\"\") returned error (acceptable): %v", err) + return + } + if resp == nil { + t.Fatal("Execute(\"\") = nil, nil — want non-nil response or non-nil error") + } + if resp.FinalText == "" { + t.Error("Execute(\"\") response.FinalText is empty, want non-empty (at minimum the fallback message)") } }) } @@ -134,12 +136,10 @@ func TestChatAgent_ErrorHandling(t *testing.T) { // TestChatAgent_NewChatValidation tests constructor validation func TestChatAgent_NewChatValidation(t *testing.T) { // Setup test framework once for all validation tests - framework, cleanup := SetupTest(t) - defer cleanup() + framework := SetupTest(t) t.Run("requires genkit", func(t *testing.T) { _, err := chat.New(chat.Config{ - Retriever: framework.Retriever, SessionStore: framework.SessionStore, Logger: slog.Default(), Tools: []ai.Tool{}, @@ -152,27 +152,11 @@ func TestChatAgent_NewChatValidation(t *testing.T) { } }) - t.Run("requires retriever", func(t *testing.T) { - _, err := chat.New(chat.Config{ - Genkit: framework.Genkit, - SessionStore: framework.SessionStore, - Logger: slog.Default(), - Tools: []ai.Tool{}, - }) - if err == nil { - t.Fatal("New() expected error, got nil") - } - if !strings.Contains(err.Error(), "retriever is required") { - t.Errorf("New() error = %q, want to contain %q", err.Error(), "retriever is required") - } - }) - t.Run("requires session store", func(t *testing.T) { _, err := chat.New(chat.Config{ - Genkit: framework.Genkit, - Retriever: framework.Retriever, - Logger: slog.Default(), - Tools: []ai.Tool{}, + Genkit: framework.Genkit, + Logger: slog.Default(), + Tools: []ai.Tool{}, }) if err == nil { t.Fatal("New() expected error, got nil") @@ -185,7 +169,6 @@ func TestChatAgent_NewChatValidation(t *testing.T) { t.Run("requires logger", func(t *testing.T) { _, err := chat.New(chat.Config{ Genkit: framework.Genkit, - Retriever: framework.Retriever, SessionStore: framework.SessionStore, Tools: []ai.Tool{}, }) @@ -200,7 +183,6 @@ func TestChatAgent_NewChatValidation(t *testing.T) { t.Run("requires at least one tool", func(t *testing.T) { _, err := chat.New(chat.Config{ Genkit: framework.Genkit, - Retriever: framework.Retriever, SessionStore: framework.SessionStore, Logger: slog.Default(), Tools: []ai.Tool{}, @@ -218,14 +200,14 @@ func TestChatAgent_NewChatValidation(t *testing.T) { // Uses mutex-protected error collection instead of assert/require in goroutines // to avoid test reliability issues with t.FailNow() from goroutines. func TestChatAgent_ConcurrentExecution(t *testing.T) { - framework, cleanup := SetupTest(t) - defer cleanup() + framework := SetupTest(t) numConcurrentQueries := 5 var wg sync.WaitGroup wg.Add(numConcurrentQueries) - ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID) + ctx := context.Background() + sessionID := framework.SessionID // Collect results safely type result struct { diff --git a/internal/agent/chat/retry.go b/internal/chat/retry.go similarity index 63% rename from internal/agent/chat/retry.go rename to internal/chat/retry.go index 9b19f33..3dbc70f 100644 --- a/internal/agent/chat/retry.go +++ b/internal/chat/retry.go @@ -25,29 +25,31 @@ func DefaultRetryConfig() RetryConfig { } } -// retryableError determines if an error should trigger a retry. +// retryablePatterns groups error substrings by category. +// Matched case-insensitively against err.Error(). +// +// NOTE: This uses string matching because Genkit and LLM provider SDKs +// do not expose typed/sentinel errors for transient failures. +// This is a documented exception to the project rule against +// strings.Contains(err.Error(), ...). +// Re-evaluate if Genkit adds structured error types in a future version. +var retryablePatterns = [][]string{ + {"rate limit", "quota exceeded", "429"}, // rate limiting + {"500", "502", "503", "504", "unavailable"}, // transient server errors + {"connection reset", "timeout", "temporary"}, // network errors +} + +// retryableError reports whether err is transient and should trigger a retry. func retryableError(err error) bool { if err == nil { return false } - errStr := err.Error() - - // Rate limit errors - always retry - if containsAny(errStr, "rate limit", "quota exceeded", "429") { - return true - } - - // Transient server errors - retry - if containsAny(errStr, "500", "502", "503", "504", "unavailable") { - return true - } - - // Network errors - retry - if containsAny(errStr, "connection reset", "timeout", "temporary") { - return true + for _, group := range retryablePatterns { + if containsAny(errStr, group...) { + return true + } } - return false } @@ -69,27 +71,27 @@ func containsAny(s string, substrs ...string) bool { // - Rate limits EACH attempt (per golang-master review) // - Tracks elapsed time for observability (per Rob Pike review) // - Exponential backoff with configurable max interval -func (c *Chat) executeWithRetry( +func (a *Agent) executeWithRetry( ctx context.Context, opts []ai.PromptExecuteOption, ) (*ai.ModelResponse, error) { var lastErr error - delay := c.retryConfig.InitialInterval + delay := a.retryConfig.InitialInterval start := time.Now() // Track elapsed time - for attempt := 0; attempt <= c.retryConfig.MaxRetries; attempt++ { + for attempt := 0; attempt <= a.retryConfig.MaxRetries; attempt++ { // Rate limit EACH attempt (per golang-master review) - if c.rateLimiter != nil { - if err := c.rateLimiter.Wait(ctx); err != nil { + if a.rateLimiter != nil { + if err := a.rateLimiter.Wait(ctx); err != nil { return nil, fmt.Errorf("rate limit wait: %w", err) } } // Attempt execution - resp, err := c.prompt.Execute(ctx, opts...) + resp, err := a.prompt.Execute(ctx, opts...) if err == nil { elapsed := time.Since(start) - c.logger.Debug("prompt executed successfully", + a.logger.Debug("prompt executed successfully", "attempts", attempt+1, "elapsed", elapsed, ) @@ -104,12 +106,12 @@ func (c *Chat) executeWithRetry( } // Last attempt - don't sleep - if attempt == c.retryConfig.MaxRetries { + if attempt == a.retryConfig.MaxRetries { break } // Exponential backoff with context cancellation check - c.logger.Debug("retrying after error", + a.logger.Debug("retrying after error", "attempt", attempt+1, "delay", delay, "elapsed", time.Since(start), @@ -120,11 +122,11 @@ func (c *Chat) executeWithRetry( case <-ctx.Done(): return nil, fmt.Errorf("context canceled during retry: %w", ctx.Err()) case <-time.After(delay): - delay = min(delay*2, c.retryConfig.MaxInterval) + delay = min(delay*2, a.retryConfig.MaxInterval) } } elapsed := time.Since(start) return nil, fmt.Errorf("prompt execute after %d retries (elapsed: %v): %w", - c.retryConfig.MaxRetries, elapsed, lastErr) + a.retryConfig.MaxRetries, elapsed, lastErr) } diff --git a/internal/chat/retry_test.go b/internal/chat/retry_test.go new file mode 100644 index 0000000..425c41d --- /dev/null +++ b/internal/chat/retry_test.go @@ -0,0 +1,196 @@ +package chat + +import ( + "errors" + "testing" +) + +func TestDefaultRetryConfig(t *testing.T) { + t.Parallel() + + cfg := DefaultRetryConfig() + + if cfg.MaxRetries <= 0 { + t.Errorf("MaxRetries should be positive, got %d", cfg.MaxRetries) + } + if cfg.InitialInterval <= 0 { + t.Errorf("InitialInterval should be positive, got %v", cfg.InitialInterval) + } + if cfg.MaxInterval <= 0 { + t.Errorf("MaxInterval should be positive, got %v", cfg.MaxInterval) + } + if cfg.MaxInterval < cfg.InitialInterval { + t.Error("MaxInterval should be >= InitialInterval") + } +} + +func TestRetryableError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + { + name: "nil error", + err: nil, + want: false, + }, + { + name: "rate limit error", + err: errors.New("rate limit exceeded"), + want: true, + }, + { + name: "quota exceeded error", + err: errors.New("quota exceeded for project"), + want: true, + }, + { + name: "429 status code", + err: errors.New("HTTP 429: Too Many Requests"), + want: true, + }, + { + name: "500 server error", + err: errors.New("HTTP 500 Internal Server Error"), + want: true, + }, + { + name: "502 bad gateway", + err: errors.New("502 Bad Gateway"), + want: true, + }, + { + name: "503 unavailable", + err: errors.New("503 Service Unavailable"), + want: true, + }, + { + name: "504 gateway timeout", + err: errors.New("504 Gateway Timeout"), + want: true, + }, + { + name: "unavailable keyword", + err: errors.New("service unavailable"), + want: true, + }, + { + name: "connection reset", + err: errors.New("connection reset by peer"), + want: true, + }, + { + name: "timeout error", + err: errors.New("request timeout"), + want: true, + }, + { + name: "temporary error", + err: errors.New("temporary failure"), + want: true, + }, + { + name: "non-retryable error", + err: errors.New("invalid API key"), + want: false, + }, + { + name: "non-retryable 400 error", + err: errors.New("HTTP 400 Bad Request"), + want: false, + }, + { + name: "non-retryable 401 error", + err: errors.New("HTTP 401 Unauthorized"), + want: false, + }, + { + name: "non-retryable 403 error", + err: errors.New("HTTP 403 Forbidden"), + want: false, + }, + { + name: "case insensitive rate limit", + err: errors.New("RATE LIMIT reached"), + want: true, + }, + { + name: "case insensitive timeout", + err: errors.New("TIMEOUT occurred"), + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := retryableError(tt.err) + if got != tt.want { + t.Errorf("retryableError(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +func TestContainsAny(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + s string + substrs []string + want bool + }{ + { + name: "empty string", + s: "", + substrs: []string{"foo"}, + want: false, + }, + { + name: "empty substrs", + s: "foo bar", + substrs: []string{}, + want: false, + }, + { + name: "contains first substr", + s: "foo bar baz", + substrs: []string{"foo", "qux"}, + want: true, + }, + { + name: "contains last substr", + s: "foo bar baz", + substrs: []string{"qux", "baz"}, + want: true, + }, + { + name: "case insensitive match", + s: "FOO BAR BAZ", + substrs: []string{"foo"}, + want: true, + }, + { + name: "no match", + s: "foo bar baz", + substrs: []string{"qux", "quux"}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := containsAny(tt.s, tt.substrs...) + if got != tt.want { + t.Errorf("containsAny(%q, %v) = %v, want %v", tt.s, tt.substrs, got, tt.want) + } + }) + } +} diff --git a/internal/agent/chat/setup_test.go b/internal/chat/setup_test.go similarity index 73% rename from internal/agent/chat/setup_test.go rename to internal/chat/setup_test.go index eef3123..8cd1bed 100644 --- a/internal/agent/chat/setup_test.go +++ b/internal/chat/setup_test.go @@ -22,7 +22,7 @@ import ( "github.com/firebase/genkit/go/plugins/postgresql" "github.com/google/uuid" - "github.com/koopa0/koopa/internal/agent/chat" + "github.com/koopa0/koopa/internal/chat" "github.com/koopa0/koopa/internal/config" "github.com/koopa0/koopa/internal/rag" "github.com/koopa0/koopa/internal/security" @@ -34,9 +34,10 @@ import ( // TestFramework provides a complete test environment for chat integration tests. // This is the chat-specific equivalent of testutil.AgentTestFramework. +// Cleanup is automatic via tb.Cleanup — no manual cleanup needed. type TestFramework struct { // Core components - Agent *chat.Chat + Agent *chat.Agent Flow *chat.Flow DocStore *postgresql.DocStore // For indexing documents in tests Retriever ai.Retriever // Genkit Retriever for RAG @@ -50,8 +51,6 @@ type TestFramework struct { // Test session (fresh per framework instance) SessionID uuid.UUID - - cleanup func() } // SetupTest creates a complete chat test environment. @@ -66,13 +65,12 @@ type TestFramework struct { // Example: // // func TestChatFeature(t *testing.T) { -// framework, cleanup := SetupTest(t) -// defer cleanup() +// framework := SetupTest(t) // // ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID) // resp, err := framework.Agent.Execute(ctx, sessionID, "test query") // } -func SetupTest(t *testing.T) (*TestFramework, func()) { +func SetupTest(t *testing.T) *TestFramework { t.Helper() apiKey := os.Getenv("GEMINI_API_KEY") @@ -82,8 +80,8 @@ func SetupTest(t *testing.T) (*TestFramework, func()) { ctx := context.Background() - // Layer 1: Use testutil primitives - dbContainer, dbCleanup := testutil.SetupTestDB(t) + // Layer 1: Use testutil primitives (cleanup is automatic via tb.Cleanup) + dbContainer := testutil.SetupTestDB(t) // Setup RAG with Genkit PostgreSQL plugin ragSetup := testutil.SetupRAG(t, dbContainer.Pool) @@ -94,10 +92,9 @@ func SetupTest(t *testing.T) (*TestFramework, func()) { cfg := &config.Config{ ModelName: "googleai/gemini-2.5-flash", - EmbedderModel: "text-embedding-004", + EmbedderModel: "gemini-embedding-001", Temperature: 0.7, MaxTokens: 8192, - RAGTopK: 5, PostgresHost: "localhost", PostgresPort: 5432, PostgresUser: "koopa_test", @@ -109,58 +106,47 @@ func SetupTest(t *testing.T) (*TestFramework, func()) { } // Create test session - testSession, err := sessionStore.CreateSession(ctx, "Chat Integration Test", cfg.ModelName, "") + testSession, err := sessionStore.CreateSession(ctx, "Chat Integration Test") if err != nil { - dbCleanup() - t.Fatalf("Failed to create test session: %v", err) + t.Fatalf("creating test session: %v", err) } // Create toolsets pathValidator, err := security.NewPath([]string{os.TempDir()}) if err != nil { - dbCleanup() - t.Fatalf("Failed to create path validator: %v", err) + t.Fatalf("creating path validator: %v", err) } // Create file tools instance - fileToolset, err := tools.NewFileTools(pathValidator, slog.Default()) + fileToolset, err := tools.NewFile(pathValidator, slog.Default()) if err != nil { - dbCleanup() - t.Fatalf("Failed to create file tools: %v", err) + t.Fatalf("creating file tools: %v", err) } // Register file tools with Genkit - fileTools, err := tools.RegisterFileTools(ragSetup.Genkit, fileToolset) + fileTools, err := tools.RegisterFile(ragSetup.Genkit, fileToolset) if err != nil { - dbCleanup() - t.Fatalf("Failed to register file tools: %v", err) + t.Fatalf("registering file tools: %v", err) } // Create Chat Agent chatAgent, err := chat.New(chat.Config{ Genkit: ragSetup.Genkit, - Retriever: ragSetup.Retriever, SessionStore: sessionStore, Logger: slog.Default(), Tools: fileTools, MaxTurns: cfg.MaxTurns, - RAGTopK: cfg.RAGTopK, Language: cfg.Language, }) if err != nil { - dbCleanup() - t.Fatalf("Failed to create chat agent: %v", err) + t.Fatalf("creating chat agent: %v", err) } // Initialize Flow singleton (reset first for test isolation) chat.ResetFlowForTesting() - flow, err := chat.InitFlow(ragSetup.Genkit, chatAgent) - if err != nil { - dbCleanup() - t.Fatalf("Failed to init chat flow: %v", err) - } + flow := chat.NewFlow(ragSetup.Genkit, chatAgent) - framework := &TestFramework{ + return &TestFramework{ Agent: chatAgent, Flow: flow, DocStore: ragSetup.DocStore, @@ -171,30 +157,20 @@ func SetupTest(t *testing.T) (*TestFramework, func()) { Genkit: ragSetup.Genkit, Embedder: ragSetup.Embedder, SessionID: testSession.ID, - cleanup: dbCleanup, } - - return framework, dbCleanup } // CreateTestSession creates a new isolated session for test isolation. func (f *TestFramework) CreateTestSession(t *testing.T, name string) uuid.UUID { t.Helper() ctx := context.Background() - sess, err := f.SessionStore.CreateSession(ctx, name, f.Config.ModelName, "") + sess, err := f.SessionStore.CreateSession(ctx, name) if err != nil { - t.Fatalf("Failed to create test session: %v", err) + t.Fatalf("creating test session: %v", err) } return sess.ID } -// newInvocationContext creates a simple invocation context for integration tests. -// This helper exists to maintain test readability after removing agent.InvocationContext. -// NOTE: Branch parameter was removed in Proposal 051. -func newInvocationContext(ctx context.Context, sessionID uuid.UUID) (context.Context, uuid.UUID) { - return ctx, sessionID -} - // IndexDocument indexes a document using the Genkit DocStore. // This is a test helper for adding documents to the RAG knowledge base. func (f *TestFramework) IndexDocument(t *testing.T, content string, metadata map[string]any) { @@ -211,6 +187,6 @@ func (f *TestFramework) IndexDocument(t *testing.T, content string, metadata map doc := ai.DocumentFromText(content, metadata) if err := f.DocStore.Index(ctx, []*ai.Document{doc}); err != nil { - t.Fatalf("Failed to index document: %v", err) + t.Fatalf("indexing document: %v", err) } } diff --git a/internal/agent/chat/tokens.go b/internal/chat/tokens.go similarity index 87% rename from internal/agent/chat/tokens.go rename to internal/chat/tokens.go index 24c7dde..1572177 100644 --- a/internal/agent/chat/tokens.go +++ b/internal/chat/tokens.go @@ -12,10 +12,11 @@ type TokenBudget struct { MaxHistoryTokens int // Maximum tokens for conversation history } -// DefaultTokenBudget returns conservative defaults for Gemini models. +// DefaultTokenBudget returns defaults for modern large-context models. +// 32K tokens ≈ 64 conversation turns — balances context retention with cost. func DefaultTokenBudget() TokenBudget { return TokenBudget{ - MaxHistoryTokens: 8000, // ~8K tokens for history + MaxHistoryTokens: 32000, } } @@ -49,7 +50,7 @@ func estimateMessagesTokens(msgs []*ai.Message) int { // truncateHistory removes oldest messages to fit within budget. // Preserves system message (if present) and keeps most recent messages. -func (c *Chat) truncateHistory(msgs []*ai.Message, budget int) []*ai.Message { +func (a *Agent) truncateHistory(msgs []*ai.Message, budget int) []*ai.Message { if len(msgs) == 0 { return msgs } @@ -59,7 +60,7 @@ func (c *Chat) truncateHistory(msgs []*ai.Message, budget int) []*ai.Message { return msgs } - c.logger.Debug("truncating history", + a.logger.Debug("truncating history", "current_tokens", currentTokens, "budget", budget, "message_count", len(msgs), @@ -92,7 +93,7 @@ func (c *Chat) truncateHistory(msgs []*ai.Message, budget int) []*ai.Message { // Append kept messages after system message result = append(result, kept...) - c.logger.Debug("history truncated", + a.logger.Debug("history truncated", "original_count", len(msgs), "new_count", len(result), "tokens_used", budget-remaining, diff --git a/internal/agent/chat/tokens_test.go b/internal/chat/tokens_test.go similarity index 82% rename from internal/agent/chat/tokens_test.go rename to internal/chat/tokens_test.go index fbcaa35..650b1d9 100644 --- a/internal/agent/chat/tokens_test.go +++ b/internal/chat/tokens_test.go @@ -1,11 +1,10 @@ package chat import ( + "log/slog" "testing" "github.com/firebase/genkit/go/ai" - - "github.com/koopa0/koopa/internal/log" ) func TestDefaultTokenBudget(t *testing.T) { @@ -22,39 +21,39 @@ func TestEstimateTokens(t *testing.T) { t.Parallel() tests := []struct { - name string - text string - expected int + name string + text string + want int }{ { - name: "empty string", - text: "", - expected: 0, + name: "empty string", + text: "", + want: 0, }, { - name: "single char returns 1", - text: "a", - expected: 1, // 1 rune / 2 = 0, but min 1 for non-empty + name: "single char returns 1", + text: "a", + want: 1, // 1 rune / 2 = 0, but min 1 for non-empty }, { - name: "short english", - text: "hello", - expected: 2, // 5 runes / 2 = 2 + name: "short english", + text: "hello", + want: 2, // 5 runes / 2 = 2 }, { - name: "longer english", - text: "This is a longer test message with multiple words.", - expected: 25, // 50 runes / 2 = 25 + name: "longer english", + text: "This is a longer test message with multiple words.", + want: 25, // 50 runes / 2 = 25 }, { - name: "cjk text", - text: "你好世界", - expected: 2, // 4 runes / 2 = 2 + name: "cjk text", + text: "你好世界", + want: 2, // 4 runes / 2 = 2 }, { - name: "mixed text", - text: "Hello 世界", - expected: 4, // 8 runes / 2 = 4 + name: "mixed text", + text: "Hello 世界", + want: 4, // 8 runes / 2 = 4 }, } @@ -63,8 +62,8 @@ func TestEstimateTokens(t *testing.T) { t.Parallel() got := estimateTokens(tt.text) - if got != tt.expected { - t.Errorf("estimateTokens(%q) = %d, want %d", tt.text, got, tt.expected) + if got != tt.want { + t.Errorf("estimateTokens(%q) = %d, want %d", tt.text, got, tt.want) } }) } @@ -74,26 +73,26 @@ func TestEstimateMessagesTokens(t *testing.T) { t.Parallel() tests := []struct { - name string - msgs []*ai.Message - expected int + name string + msgs []*ai.Message + want int }{ { - name: "nil messages", - msgs: nil, - expected: 0, + name: "nil messages", + msgs: nil, + want: 0, }, { - name: "empty messages", - msgs: []*ai.Message{}, - expected: 0, + name: "empty messages", + msgs: []*ai.Message{}, + want: 0, }, { name: "single message", msgs: []*ai.Message{ ai.NewUserMessage(ai.NewTextPart("hello world")), // 11 runes / 2 = 5 }, - expected: 5, + want: 5, }, { name: "multiple messages", @@ -102,7 +101,7 @@ func TestEstimateMessagesTokens(t *testing.T) { ai.NewModelMessage(ai.NewTextPart("world")), // 5 / 2 = 2 ai.NewUserMessage(ai.NewTextPart("how are you")), // 11 / 2 = 5 }, - expected: 9, + want: 9, }, } @@ -111,8 +110,8 @@ func TestEstimateMessagesTokens(t *testing.T) { t.Parallel() got := estimateMessagesTokens(tt.msgs) - if got != tt.expected { - t.Errorf("estimateMessagesTokens() = %d, want %d", got, tt.expected) + if got != tt.want { + t.Errorf("estimateMessagesTokens() = %d, want %d", got, tt.want) } }) } @@ -122,9 +121,9 @@ func TestTruncateHistory(t *testing.T) { t.Parallel() // Helper to create a Chat with nop logger for testing - makeChat := func() *Chat { - return &Chat{ - logger: log.NewNop(), + makeAgent := func() *Agent { + return &Agent{ + logger: slog.New(slog.DiscardHandler), } } @@ -234,8 +233,8 @@ func TestTruncateHistory(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - chat := makeChat() - got := chat.truncateHistory(tt.msgs, tt.budget) + agent := makeAgent() + got := agent.truncateHistory(tt.msgs, tt.budget) if len(got) != tt.wantLen { t.Errorf("truncateHistory() len = %d, want %d", len(got), tt.wantLen) @@ -248,7 +247,7 @@ func TestTruncateHistory(t *testing.T) { // Check system message preservation if tt.wantHasSystem { if got[0].Role != ai.RoleSystem { - t.Errorf("expected first message to be system, got %s", got[0].Role) + t.Errorf("want first message to be system, got %s", got[0].Role) } } @@ -267,7 +266,7 @@ func TestTruncateHistory(t *testing.T) { // wantTexts[i] must match got[i] - this implicitly verifies ordering if len(tt.wantTexts) > 0 { if len(got) != len(tt.wantTexts) { - t.Fatalf("got %d messages but expected %d texts to verify", len(got), len(tt.wantTexts)) + t.Fatalf("got %d messages but want %d texts to verify", len(got), len(tt.wantTexts)) } for i, want := range tt.wantTexts { if len(got[i].Content) == 0 { @@ -286,7 +285,7 @@ func TestTruncateHistory(t *testing.T) { func TestTruncateHistory_ChronologicalOrder(t *testing.T) { t.Parallel() - chat := &Chat{logger: log.NewNop()} + agent := &Agent{logger: slog.New(slog.DiscardHandler)} // Create a conversation with alternating user/model messages msgs := []*ai.Message{ @@ -298,7 +297,7 @@ func TestTruncateHistory_ChronologicalOrder(t *testing.T) { } // Budget should keep only last 2-3 messages - result := chat.truncateHistory(msgs, 6) + result := agent.truncateHistory(msgs, 6) // Verify messages are still in chronological order for i := 1; i < len(result); i++ { diff --git a/internal/config/ai.go b/internal/config/ai.go deleted file mode 100644 index 7823a8d..0000000 --- a/internal/config/ai.go +++ /dev/null @@ -1,14 +0,0 @@ -package config - -// AIConfig holds AI model configuration. -// Fields are embedded in the main Config struct for backward compatibility. -// Documented separately for clarity. -// -// Configuration options: -// - Provider: AI provider ("gemini", "ollama", "openai") -// - ModelName: Model identifier (e.g., "gemini-2.5-flash", "llama3.3", "gpt-4o") -// - Temperature: 0.0 (deterministic) to 2.0 (creative) -// - MaxTokens: 1 to 2,097,152 (Gemini 2.5 max context) -// - Language: Response language ("auto", "English", "zh-TW") -// - PromptDir: Directory for .prompt files (Dotprompt) -// - OllamaHost: Ollama server address (default: "http://localhost:11434") diff --git a/internal/config/ai_test.go b/internal/config/ai_test.go deleted file mode 100644 index 5921c5c..0000000 --- a/internal/config/ai_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package config - -import "testing" - -// TestFullModelName tests that FullModelName derives correct provider-qualified names. -func TestFullModelName(t *testing.T) { - tests := []struct { - name string - provider string - modelName string - want string - }{ - {name: "gemini default", provider: "", modelName: "gemini-2.5-flash", want: "googleai/gemini-2.5-flash"}, - {name: "gemini explicit", provider: "gemini", modelName: "gemini-2.5-pro", want: "googleai/gemini-2.5-pro"}, - {name: "ollama", provider: "ollama", modelName: "llama3.3", want: "ollama/llama3.3"}, - {name: "openai", provider: "openai", modelName: "gpt-4o", want: "openai/gpt-4o"}, - {name: "already qualified", provider: "ollama", modelName: "ollama/llama3.3", want: "ollama/llama3.3"}, - {name: "cross-qualified", provider: "gemini", modelName: "openai/gpt-4o", want: "openai/gpt-4o"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cfg := &Config{ - Provider: tt.provider, - ModelName: tt.modelName, - } - got := cfg.FullModelName() - if got != tt.want { - t.Errorf("FullModelName() = %q, want %q", got, tt.want) - } - }) - } -} diff --git a/internal/config/config.go b/internal/config/config.go index 109b973..2f10ade 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,9 +6,9 @@ // 3. Default values (sensible defaults for quick start) // // Main configuration categories: -// - AI: Model selection, temperature, max tokens, embedder (see ai.go) -// - Storage: SQLite database path, PostgreSQL connection (see storage.go) -// - RAG: Number of documents to retrieve (RAGTopK) +// - AI: Model selection, temperature, max tokens, embedder +// - Storage: PostgreSQL connection (see storage.go) +// - RAG: Embedder model for vector embeddings // - MCP: Model Context Protocol server management (see tools.go) // - Observability: Datadog APM tracing (see observability.go) // @@ -32,10 +32,6 @@ import ( "github.com/spf13/viper" ) -// ============================================================================ -// Sentinel Errors -// ============================================================================ - var ( // ErrConfigNil indicates the configuration is nil. ErrConfigNil = errors.New("configuration is nil") @@ -52,12 +48,12 @@ var ( // ErrInvalidMaxTokens indicates the max tokens value is out of range. ErrInvalidMaxTokens = errors.New("invalid max tokens") - // ErrInvalidRAGTopK indicates the RAG top-k value is out of range. - ErrInvalidRAGTopK = errors.New("invalid RAG top-k") - // ErrInvalidEmbedderModel indicates the embedder model is invalid. ErrInvalidEmbedderModel = errors.New("invalid embedder model") + // ErrInvalidEmbedderDimension indicates the embedder produces incompatible vector dimensions. + ErrInvalidEmbedderDimension = errors.New("incompatible embedder dimension") + // ErrInvalidPostgresHost indicates the PostgreSQL host is invalid. ErrInvalidPostgresHost = errors.New("invalid PostgreSQL host") @@ -67,9 +63,6 @@ var ( // ErrInvalidPostgresDBName indicates the PostgreSQL database name is invalid. ErrInvalidPostgresDBName = errors.New("invalid PostgreSQL database name") - // ErrConfigParse indicates configuration parsing failed. - ErrConfigParse = errors.New("failed to parse configuration") - // ErrInvalidProvider indicates the AI provider is not supported. ErrInvalidProvider = errors.New("invalid provider") @@ -81,15 +74,20 @@ var ( // ErrInvalidPostgresSSLMode indicates the PostgreSQL SSL mode is invalid. ErrInvalidPostgresSSLMode = errors.New("invalid PostgreSQL SSL mode") -) -// ============================================================================ -// Constants -// ============================================================================ + // ErrMissingHMACSecret indicates the HMAC secret is not set. + ErrMissingHMACSecret = errors.New("missing HMAC secret") + + // ErrInvalidHMACSecret indicates the HMAC secret is too short. + ErrInvalidHMACSecret = errors.New("invalid HMAC secret") +) const ( - // DefaultEmbedderModel is the default embedder model for vector embeddings. - DefaultEmbedderModel = "text-embedding-004" + // DefaultGeminiEmbedderModel is the default Gemini embedder model. + // gemini-embedding-001 outputs 3072 dimensions by default, but supports + // truncation to 768 via OutputDimensionality (Matryoshka Representation Learning). + // Our pgvector schema uses 768 dimensions; see rag.VectorDimension. + DefaultGeminiEmbedderModel = "gemini-embedding-001" // DefaultMaxHistoryMessages is the default number of messages to load. DefaultMaxHistoryMessages int32 = 100 @@ -101,9 +99,13 @@ const ( MinHistoryMessages int32 = 10 ) -// ============================================================================ -// Config Struct -// ============================================================================ +// AI provider identifiers used in Config.Provider. +const ( + ProviderGemini = "gemini" + ProviderOllama = "ollama" + ProviderOpenAI = "openai" + ProviderGoogleAI = "googleai" +) // Config stores application configuration. // SECURITY: Sensitive fields are explicitly masked in MarshalJSON(). @@ -133,13 +135,8 @@ type Config struct { PostgresSSLMode string `mapstructure:"postgres_ssl_mode" json:"postgres_ssl_mode"` // RAG configuration - RAGTopK int `mapstructure:"rag_top_k" json:"rag_top_k"` EmbedderModel string `mapstructure:"embedder_model" json:"embedder_model"` - // MCP configuration (see tools.go for type definitions) - MCP MCPConfig `mapstructure:"mcp" json:"mcp"` - MCPServers map[string]MCPServer `mapstructure:"mcp_servers" json:"mcp_servers"` - // Tool configuration (see tools.go for type definitions) SearXNG SearXNGConfig `mapstructure:"searxng" json:"searxng"` WebScraper WebScraperConfig `mapstructure:"web_scraper" json:"web_scraper"` @@ -153,24 +150,20 @@ type Config struct { TrustProxy bool `mapstructure:"trust_proxy" json:"trust_proxy"` // Trust X-Real-IP/X-Forwarded-For headers (set true behind reverse proxy) } -// ============================================================================ -// Load Function -// ============================================================================ - // Load loads configuration. // Priority: Environment variables > Configuration file > Default values func Load() (*Config, error) { // Configuration directory: ~/.koopa/ home, err := os.UserHomeDir() if err != nil { - return nil, fmt.Errorf("failed to get user home directory: %w", err) + return nil, fmt.Errorf("getting user home directory: %w", err) } configDir := filepath.Join(home, ".koopa") // Ensure directory exists (use 0750 permission for better security) if err := os.MkdirAll(configDir, 0o750); err != nil { - return nil, fmt.Errorf("failed to create config directory: %w", err) + return nil, fmt.Errorf("creating config directory: %w", err) } // Configure Viper @@ -180,7 +173,7 @@ func Load() (*Config, error) { viper.AddConfigPath(".") // Also support current directory // Set default values - setDefaults(configDir) + setDefaults() // Bind environment variables bindEnvVariables() @@ -190,7 +183,7 @@ func Load() (*Config, error) { // Configuration file not found is not an error, use default values var configNotFound viper.ConfigFileNotFoundError if !errors.As(err, &configNotFound) { - return nil, fmt.Errorf("failed to read config file: %w", err) + return nil, fmt.Errorf("reading config file: %w", err) } slog.Debug("configuration file not found, using default values", "search_paths", []string{configDir, "."}, @@ -200,26 +193,26 @@ func Load() (*Config, error) { // Use Unmarshal to automatically map to struct (type-safe) var cfg Config if err := viper.Unmarshal(&cfg); err != nil { - return nil, fmt.Errorf("failed to parse configuration: %w", err) + return nil, fmt.Errorf("parsing configuration: %w", err) } // Parse DATABASE_URL if set (highest priority for PostgreSQL config) if err := cfg.parseDatabaseURL(); err != nil { - return nil, fmt.Errorf("failed to parse DATABASE_URL: %w", err) + return nil, fmt.Errorf("parsing DATABASE_URL: %w", err) } // CRITICAL: Validate immediately (fail-fast) if err := cfg.Validate(); err != nil { - return nil, fmt.Errorf("configuration validation failed: %w", err) + return nil, fmt.Errorf("validating configuration: %w", err) } return &cfg, nil } // setDefaults sets all default configuration values. -func setDefaults(configDir string) { +func setDefaults() { // AI defaults - viper.SetDefault("provider", "gemini") + viper.SetDefault("provider", ProviderGemini) viper.SetDefault("model_name", "gemini-2.5-flash") viper.SetDefault("temperature", 0.7) viper.SetDefault("max_tokens", 2048) @@ -229,8 +222,6 @@ func setDefaults(configDir string) { // Ollama defaults viper.SetDefault("ollama_host", "http://localhost:11434") - viper.SetDefault("database_path", filepath.Join(configDir, "koopa.db")) - // PostgreSQL defaults (matching docker-compose.yml) viper.SetDefault("postgres_host", "localhost") viper.SetDefault("postgres_port", 5432) @@ -240,8 +231,7 @@ func setDefaults(configDir string) { viper.SetDefault("postgres_ssl_mode", "disable") // RAG defaults - viper.SetDefault("rag_top_k", 3) - viper.SetDefault("embedder_model", DefaultEmbedderModel) + viper.SetDefault("embedder_model", DefaultGeminiEmbedderModel) // MCP defaults viper.SetDefault("mcp.timeout", 5) @@ -302,10 +292,6 @@ func bindEnvVariables() { // Validation checks their presence based on the selected provider in cfg.Validate() } -// ============================================================================ -// Sensitive Data Masking -// ============================================================================ - // maskedValue is the placeholder for masked sensitive data. // Using ████████ (full-width blocks U+2588) to avoid substring matching // Previous attempts: @@ -346,7 +332,6 @@ func maskSecret(s string) string { // - PostgresPassword // - HMACSecret // - Datadog.APIKey (via DatadogConfig.MarshalJSON) -// - MCPServers[*].Env (via MCPServer.MarshalJSON) // // When adding new sensitive fields, update this method or the nested struct's MarshalJSON. // The compiler will remind you when tests fail. @@ -355,7 +340,7 @@ func (c Config) MarshalJSON() ([]byte, error) { a := alias(c) a.PostgresPassword = maskSecret(a.PostgresPassword) a.HMACSecret = maskSecret(a.HMACSecret) - // Note: Datadog.APIKey and MCPServers[*].Env are handled by their own MarshalJSON + // Note: Datadog.APIKey is handled by its own MarshalJSON data, err := json.Marshal(a) if err != nil { return nil, fmt.Errorf("marshal config: %w", err) @@ -371,12 +356,12 @@ func (c *Config) FullModelName() string { return c.ModelName } switch c.Provider { - case "ollama": - return "ollama/" + c.ModelName - case "openai": - return "openai/" + c.ModelName + case ProviderOllama: + return ProviderOllama + "/" + c.ModelName + case ProviderOpenAI: + return ProviderOpenAI + "/" + c.ModelName default: - return "googleai/" + c.ModelName + return ProviderGoogleAI + "/" + c.ModelName } } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 4ce02b0..8be18fd 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -21,13 +21,13 @@ func TestLoadDefaults(t *testing.T) { originalHome := os.Getenv("HOME") defer func() { if err := os.Setenv("HOME", originalHome); err != nil { - t.Errorf("Failed to restore HOME: %v", err) + t.Errorf("restoring HOME: %v", err) } }() // Set HOME to temp directory (no existing config.yaml) if err := os.Setenv("HOME", tmpDir); err != nil { - t.Fatalf("Failed to set HOME: %v", err) + t.Fatalf("setting HOME: %v", err) } // Save and restore original environment @@ -35,11 +35,11 @@ func TestLoadDefaults(t *testing.T) { defer func() { if originalAPIKey != "" { if err := os.Setenv("GEMINI_API_KEY", originalAPIKey); err != nil { - t.Errorf("Failed to restore GEMINI_API_KEY: %v", err) + t.Errorf("restoring GEMINI_API_KEY: %v", err) } } else { if err := os.Unsetenv("GEMINI_API_KEY"); err != nil { - t.Errorf("Failed to unset GEMINI_API_KEY: %v", err) + t.Errorf("unsetting GEMINI_API_KEY: %v", err) } } }() @@ -55,58 +55,51 @@ func TestLoadDefaults(t *testing.T) { // Set API key for validation if err := os.Setenv("GEMINI_API_KEY", "test-api-key"); err != nil { - t.Fatalf("Failed to set GEMINI_API_KEY: %v", err) + t.Fatalf("setting GEMINI_API_KEY: %v", err) } cfg, err := Load() if err != nil { - t.Fatalf("Load() failed: %v", err) + t.Fatalf("Load() unexpected error: %v", err) } // Verify default values if cfg.ModelName != "gemini-2.5-flash" { - t.Errorf("expected default ModelName 'gemini-2.5-flash', got %q", cfg.ModelName) + t.Errorf("Load().ModelName = %q, want %q", cfg.ModelName, "gemini-2.5-flash") } if cfg.Temperature != 0.7 { - t.Errorf("expected default Temperature 0.7, got %f", cfg.Temperature) + t.Errorf("Load().Temperature = %f, want %f", cfg.Temperature, 0.7) } if cfg.MaxTokens != 2048 { - t.Errorf("expected default MaxTokens 2048, got %d", cfg.MaxTokens) + t.Errorf("Load().MaxTokens = %d, want %d", cfg.MaxTokens, 2048) } if cfg.MaxHistoryMessages != DefaultMaxHistoryMessages { - t.Errorf("expected default MaxHistoryMessages %d, got %d", DefaultMaxHistoryMessages, cfg.MaxHistoryMessages) + t.Errorf("Load().MaxHistoryMessages = %d, want %d", cfg.MaxHistoryMessages, DefaultMaxHistoryMessages) } if cfg.PostgresHost != "localhost" { - t.Errorf("expected default PostgresHost 'localhost', got %q", cfg.PostgresHost) + t.Errorf("Load().PostgresHost = %q, want %q", cfg.PostgresHost, "localhost") } if cfg.PostgresPort != 5432 { - t.Errorf("expected default PostgresPort 5432, got %d", cfg.PostgresPort) + t.Errorf("Load().PostgresPort = %d, want %d", cfg.PostgresPort, 5432) } if cfg.PostgresUser != "koopa" { - t.Errorf("expected default PostgresUser 'koopa', got %q", cfg.PostgresUser) + t.Errorf("Load().PostgresUser = %q, want %q", cfg.PostgresUser, "koopa") } if cfg.PostgresDBName != "koopa" { - t.Errorf("expected default PostgresDBName 'koopa', got %q", cfg.PostgresDBName) + t.Errorf("Load().PostgresDBName = %q, want %q", cfg.PostgresDBName, "koopa") } - if cfg.RAGTopK != 3 { - t.Errorf("expected default RAGTopK 3, got %d", cfg.RAGTopK) + if cfg.EmbedderModel != "gemini-embedding-001" { + t.Errorf("Load().EmbedderModel = %q, want %q", cfg.EmbedderModel, "gemini-embedding-001") } - if cfg.EmbedderModel != "text-embedding-004" { - t.Errorf("expected default EmbedderModel 'text-embedding-004', got %q", cfg.EmbedderModel) - } - - if cfg.MCP.Timeout != 5 { - t.Errorf("expected default MCP timeout 5, got %d", cfg.MCP.Timeout) - } } // TestLoadConfigFile tests loading configuration from a file @@ -119,20 +112,20 @@ func TestLoadConfigFile(t *testing.T) { originalHome := os.Getenv("HOME") defer func() { if err := os.Setenv("HOME", originalHome); err != nil { - t.Errorf("Failed to restore HOME: %v", err) + t.Errorf("restoring HOME: %v", err) } }() // Set HOME to temp directory if err := os.Setenv("HOME", tmpDir); err != nil { - t.Fatalf("Failed to set HOME: %v", err) + t.Fatalf("setting HOME: %v", err) } if err := os.Setenv("GEMINI_API_KEY", "test-api-key"); err != nil { - t.Fatalf("Failed to set GEMINI_API_KEY: %v", err) + t.Fatalf("setting GEMINI_API_KEY: %v", err) } defer func() { if err := os.Unsetenv("GEMINI_API_KEY"); err != nil { - t.Errorf("Failed to unset GEMINI_API_KEY: %v", err) + t.Errorf("unsetting GEMINI_API_KEY: %v", err) } }() @@ -148,55 +141,50 @@ func TestLoadConfigFile(t *testing.T) { // Create .koopa directory koopaDir := filepath.Join(tmpDir, ".koopa") if err := os.MkdirAll(koopaDir, 0o750); err != nil { - t.Fatalf("failed to create koopa dir: %v", err) + t.Fatalf("creating koopa dir: %v", err) } // Create config file configContent := `model_name: gemini-2.5-pro temperature: 0.9 max_tokens: 4096 -rag_top_k: 5 postgres_host: test-host postgres_port: 5433 postgres_db_name: test_db ` configPath := filepath.Join(koopaDir, "config.yaml") if err := os.WriteFile(configPath, []byte(configContent), 0o600); err != nil { - t.Fatalf("failed to write config file: %v", err) + t.Fatalf("writing config file: %v", err) } cfg, err := Load() if err != nil { - t.Fatalf("Load() failed: %v", err) + t.Fatalf("Load() unexpected error: %v", err) } // Verify values from config file if cfg.ModelName != "gemini-2.5-pro" { - t.Errorf("expected ModelName 'gemini-2.5-pro', got %q", cfg.ModelName) + t.Errorf("Load().ModelName = %q, want %q", cfg.ModelName, "gemini-2.5-pro") } if cfg.Temperature != 0.9 { - t.Errorf("expected Temperature 0.9, got %f", cfg.Temperature) + t.Errorf("Load().Temperature = %f, want %f", cfg.Temperature, 0.9) } if cfg.MaxTokens != 4096 { - t.Errorf("expected MaxTokens 4096, got %d", cfg.MaxTokens) - } - - if cfg.RAGTopK != 5 { - t.Errorf("expected RAGTopK 5, got %d", cfg.RAGTopK) + t.Errorf("Load().MaxTokens = %d, want %d", cfg.MaxTokens, 4096) } if cfg.PostgresHost != "test-host" { - t.Errorf("expected PostgresHost 'test-host', got %q", cfg.PostgresHost) + t.Errorf("Load().PostgresHost = %q, want %q", cfg.PostgresHost, "test-host") } if cfg.PostgresPort != 5433 { - t.Errorf("expected PostgresPort 5433, got %d", cfg.PostgresPort) + t.Errorf("Load().PostgresPort = %d, want %d", cfg.PostgresPort, 5433) } if cfg.PostgresDBName != "test_db" { - t.Errorf("expected PostgresDBName 'test_db', got %q", cfg.PostgresDBName) + t.Errorf("Load().PostgresDBName = %q, want %q", cfg.PostgresDBName, "test_db") } } @@ -228,25 +216,25 @@ func TestConfigDirectoryCreation(t *testing.T) { originalHome := os.Getenv("HOME") defer func() { if err := os.Setenv("HOME", originalHome); err != nil { - t.Errorf("Failed to restore HOME: %v", err) + t.Errorf("restoring HOME: %v", err) } }() if err := os.Setenv("HOME", tmpDir); err != nil { - t.Fatalf("Failed to set HOME: %v", err) + t.Fatalf("setting HOME: %v", err) } if err := os.Setenv("GEMINI_API_KEY", "test-key"); err != nil { - t.Fatalf("Failed to set GEMINI_API_KEY: %v", err) + t.Fatalf("setting GEMINI_API_KEY: %v", err) } defer func() { if err := os.Unsetenv("GEMINI_API_KEY"); err != nil { - t.Errorf("Failed to unset GEMINI_API_KEY: %v", err) + t.Errorf("unsetting GEMINI_API_KEY: %v", err) } }() _, err := Load() if err != nil { - t.Fatalf("Load() failed: %v", err) + t.Fatalf("Load() unexpected error: %v", err) } // Check that .koopa directory was created @@ -257,14 +245,13 @@ func TestConfigDirectoryCreation(t *testing.T) { } if !info.IsDir() { - t.Error("expected .koopa to be a directory") + t.Error("Load() created .koopa as file, want directory") } // Check permissions (0750 = drwxr-x---) perm := info.Mode().Perm() - expectedPerm := os.FileMode(0o750) - if perm != expectedPerm { - t.Errorf("expected permissions %o, got %o", expectedPerm, perm) + if perm != 0o750 { + t.Errorf("Load() dir permissions = %o, want %o", perm, 0o750) } } @@ -274,26 +261,26 @@ func TestEnvironmentVariableOverride(t *testing.T) { originalHome := os.Getenv("HOME") defer func() { if err := os.Setenv("HOME", originalHome); err != nil { - t.Errorf("Failed to restore HOME: %v", err) + t.Errorf("restoring HOME: %v", err) } }() if err := os.Setenv("HOME", tmpDir); err != nil { - t.Fatalf("Failed to set HOME: %v", err) + t.Fatalf("setting HOME: %v", err) } if err := os.Setenv("GEMINI_API_KEY", "test-key"); err != nil { - t.Fatalf("Failed to set GEMINI_API_KEY: %v", err) + t.Fatalf("setting GEMINI_API_KEY: %v", err) } defer func() { if err := os.Unsetenv("GEMINI_API_KEY"); err != nil { - t.Errorf("Failed to unset GEMINI_API_KEY: %v", err) + t.Errorf("unsetting GEMINI_API_KEY: %v", err) } }() // Create .koopa directory and config file koopaDir := filepath.Join(tmpDir, ".koopa") if err := os.MkdirAll(koopaDir, 0o750); err != nil { - t.Fatalf("failed to create koopa dir: %v", err) + t.Fatalf("creating koopa dir: %v", err) } configContent := `model_name: gemini-2.5-pro @@ -302,7 +289,7 @@ max_tokens: 1024 ` configPath := filepath.Join(koopaDir, "config.yaml") if err := os.WriteFile(configPath, []byte(configContent), 0o600); err != nil { - t.Fatalf("failed to write config file: %v", err) + t.Fatalf("writing config file: %v", err) } // KOOPA_* env vars NO LONGER supported (removed AutomaticEnv) @@ -310,10 +297,10 @@ max_tokens: 1024 testHMACSecret := "test-hmac-secret-minimum-32-chars-long" if err := os.Setenv("DD_API_KEY", testAPIKey); err != nil { - t.Fatalf("Failed to set DD_API_KEY: %v", err) + t.Fatalf("setting DD_API_KEY: %v", err) } if err := os.Setenv("HMAC_SECRET", testHMACSecret); err != nil { - t.Fatalf("Failed to set HMAC_SECRET: %v", err) + t.Fatalf("setting HMAC_SECRET: %v", err) } defer func() { _ = os.Unsetenv("DD_API_KEY") @@ -322,29 +309,29 @@ max_tokens: 1024 cfg, err := Load() if err != nil { - t.Fatalf("Load() failed: %v", err) + t.Fatalf("Load() unexpected error: %v", err) } // Config values should come from config.yaml (NOT env vars) if cfg.ModelName != "gemini-2.5-pro" { - t.Errorf("expected ModelName from config 'gemini-2.5-pro', got %q", cfg.ModelName) + t.Errorf("Load().ModelName = %q, want %q", cfg.ModelName, "gemini-2.5-pro") } if cfg.Temperature != 0.5 { - t.Errorf("expected Temperature from config 0.5, got %f", cfg.Temperature) + t.Errorf("Load().Temperature = %f, want %f", cfg.Temperature, 0.5) } if cfg.MaxTokens != 1024 { - t.Errorf("expected MaxTokens from config 1024, got %d", cfg.MaxTokens) + t.Errorf("Load().MaxTokens = %d, want %d", cfg.MaxTokens, 1024) } // Sensitive env vars should be bound if cfg.Datadog.APIKey != testAPIKey { - t.Errorf("expected Datadog.APIKey from env %q, got %q", testAPIKey, cfg.Datadog.APIKey) + t.Errorf("Load().Datadog.APIKey = %q, want %q", cfg.Datadog.APIKey, testAPIKey) } if cfg.HMACSecret != testHMACSecret { - t.Errorf("expected HMACSecret from env %q, got %q", testHMACSecret, cfg.HMACSecret) + t.Errorf("Load().HMACSecret = %q, want %q", cfg.HMACSecret, testHMACSecret) } } @@ -354,26 +341,26 @@ func TestLoadInvalidYAML(t *testing.T) { originalHome := os.Getenv("HOME") defer func() { if err := os.Setenv("HOME", originalHome); err != nil { - t.Errorf("Failed to restore HOME: %v", err) + t.Errorf("restoring HOME: %v", err) } }() if err := os.Setenv("HOME", tmpDir); err != nil { - t.Fatalf("Failed to set HOME: %v", err) + t.Fatalf("setting HOME: %v", err) } if err := os.Setenv("GEMINI_API_KEY", "test-key"); err != nil { - t.Fatalf("Failed to set GEMINI_API_KEY: %v", err) + t.Fatalf("setting GEMINI_API_KEY: %v", err) } defer func() { if err := os.Unsetenv("GEMINI_API_KEY"); err != nil { - t.Errorf("Failed to unset GEMINI_API_KEY: %v", err) + t.Errorf("unsetting GEMINI_API_KEY: %v", err) } }() // Create .koopa directory koopaDir := filepath.Join(tmpDir, ".koopa") if err := os.MkdirAll(koopaDir, 0o750); err != nil { - t.Fatalf("failed to create koopa dir: %v", err) + t.Fatalf("creating koopa dir: %v", err) } // Create invalid YAML config file @@ -384,12 +371,12 @@ max_tokens: not_a_number ` configPath := filepath.Join(koopaDir, "config.yaml") if err := os.WriteFile(configPath, []byte(invalidYAML), 0o600); err != nil { - t.Fatalf("failed to write invalid config file: %v", err) + t.Fatalf("writing invalid config file: %v", err) } _, err := Load() if err == nil { - t.Error("expected error for invalid YAML, got none") + t.Error("Load() error = nil, want error for invalid YAML") } } @@ -399,26 +386,26 @@ func TestLoadUnmarshalError(t *testing.T) { originalHome := os.Getenv("HOME") defer func() { if err := os.Setenv("HOME", originalHome); err != nil { - t.Errorf("Failed to restore HOME: %v", err) + t.Errorf("restoring HOME: %v", err) } }() if err := os.Setenv("HOME", tmpDir); err != nil { - t.Fatalf("Failed to set HOME: %v", err) + t.Fatalf("setting HOME: %v", err) } if err := os.Setenv("GEMINI_API_KEY", "test-key"); err != nil { - t.Fatalf("Failed to set GEMINI_API_KEY: %v", err) + t.Fatalf("setting GEMINI_API_KEY: %v", err) } defer func() { if err := os.Unsetenv("GEMINI_API_KEY"); err != nil { - t.Errorf("Failed to unset GEMINI_API_KEY: %v", err) + t.Errorf("unsetting GEMINI_API_KEY: %v", err) } }() // Create .koopa directory koopaDir := filepath.Join(tmpDir, ".koopa") if err := os.MkdirAll(koopaDir, 0o750); err != nil { - t.Fatalf("failed to create koopa dir: %v", err) + t.Fatalf("creating koopa dir: %v", err) } // Create config with type mismatch @@ -428,30 +415,28 @@ max_tokens: "this should also be a number" ` configPath := filepath.Join(koopaDir, "config.yaml") if err := os.WriteFile(configPath, []byte(invalidTypeYAML), 0o600); err != nil { - t.Fatalf("failed to write config file: %v", err) + t.Fatalf("writing config file: %v", err) } - // This will succeed because viper is flexible with type conversion - // but we document this test to show that invalid types are handled - _, err := Load() - // Note: viper may successfully parse string "0.7" as float, so we don't assert error here - _ = err + // Viper is flexible with type conversion so this may succeed or fail. + // The test verifies Load() doesn't panic on type-mismatched YAML. + _, _ = Load() } // BenchmarkLoad benchmarks configuration loading func BenchmarkLoad(b *testing.B) { if err := os.Setenv("GEMINI_API_KEY", "test-key"); err != nil { - b.Fatalf("Failed to set GEMINI_API_KEY: %v", err) + b.Fatalf("setting GEMINI_API_KEY: %v", err) } defer func() { if err := os.Unsetenv("GEMINI_API_KEY"); err != nil { - b.Errorf("Failed to unset GEMINI_API_KEY: %v", err) + b.Errorf("unsetting GEMINI_API_KEY: %v", err) } }() // Verify Load() works before starting benchmark if _, err := Load(); err != nil { - b.Fatalf("Load() failed: %v", err) + b.Fatalf("Load() unexpected error: %v", err) } b.ResetTimer() @@ -473,7 +458,7 @@ func TestConfig_MarshalJSON_MasksSensitiveFields(t *testing.T) { data, err := json.Marshal(cfg) if err != nil { - t.Fatalf("MarshalJSON failed: %v", err) + t.Fatalf("MarshalJSON() unexpected error: %v", err) } jsonStr := string(data) @@ -488,9 +473,9 @@ func TestConfig_MarshalJSON_MasksSensitiveFields(t *testing.T) { // 1. Not be the original password // 2. Contain masking characters (****) // 3. Be present in the JSON output - var result map[string]interface{} + var result map[string]any if err := json.Unmarshal(data, &result); err != nil { - t.Fatalf("failed to unmarshal result: %v", err) + t.Fatalf("unmarshaling result: %v", err) } maskedPwd, ok := result["postgres_password"].(string) @@ -522,17 +507,17 @@ func TestConfig_MarshalJSON_EmptyPassword(t *testing.T) { data, err := json.Marshal(cfg) if err != nil { - t.Fatalf("MarshalJSON failed: %v", err) + t.Fatalf("MarshalJSON() unexpected error: %v", err) } // Empty password should remain empty, not cause panic - var result map[string]interface{} + var result map[string]any if err := json.Unmarshal(data, &result); err != nil { - t.Fatalf("failed to unmarshal result: %v", err) + t.Fatalf("unmarshaling result: %v", err) } if result["postgres_password"] != "" { - t.Errorf("expected empty password to remain empty, got %v", result["postgres_password"]) + t.Errorf("MarshalJSON() postgres_password = %v, want empty string", result["postgres_password"]) } } @@ -544,7 +529,7 @@ func TestConfig_MarshalJSON_ShortPassword(t *testing.T) { data, err := json.Marshal(cfg) if err != nil { - t.Fatalf("MarshalJSON failed: %v", err) + t.Fatalf("MarshalJSON() unexpected error: %v", err) } jsonStr := string(data) @@ -555,7 +540,7 @@ func TestConfig_MarshalJSON_ShortPassword(t *testing.T) { } if !strings.Contains(jsonStr, `"postgres_password":"████████"`) { - t.Errorf("expected fully masked password '████████', got: %s", jsonStr) + t.Errorf("MarshalJSON() short password not fully masked, got: %s", jsonStr) } } @@ -582,23 +567,16 @@ func TestConfig_SensitiveFieldsAreMasked(t *testing.T) { Datadog: DatadogConfig{ APIKey: "datadogapikey789", }, - MCPServers: map[string]MCPServer{ - "test": { - Env: map[string]string{ - "API_KEY": "envvarkey123", - }, - }, - }, } data, err := json.Marshal(cfg) if err != nil { - t.Fatalf("MarshalJSON failed: %v", err) + t.Fatalf("MarshalJSON() unexpected error: %v", err) } jsonStr := string(data) // Verify no raw secrets appear in output - secrets := []string{"secretpassword123", "hmacsecret456", "datadogapikey789", "envvarkey123"} + secrets := []string{"secretpassword123", "hmacsecret456", "datadogapikey789"} for _, secret := range secrets { if strings.Contains(jsonStr, secret) { t.Errorf("sensitive value %q should be masked in JSON output", secret) @@ -612,17 +590,6 @@ func TestConfig_MarshalJSON_NestedStructs(t *testing.T) { cfg := Config{ ModelName: "test-model", PostgresPassword: "secretpassword", - MCP: MCPConfig{ - Timeout: 10, - Allowed: []string{"server1", "server2"}, - }, - MCPServers: map[string]MCPServer{ - "test-server": { - Command: "npx", - Args: []string{"-y", "test-mcp"}, - Timeout: 30, - }, - }, SearXNG: SearXNGConfig{ BaseURL: "http://localhost:8080", }, @@ -635,48 +602,30 @@ func TestConfig_MarshalJSON_NestedStructs(t *testing.T) { data, err := json.Marshal(cfg) if err != nil { - t.Fatalf("MarshalJSON failed: %v", err) + t.Fatalf("MarshalJSON() unexpected error: %v", err) } - var result map[string]interface{} + var result map[string]any if err := json.Unmarshal(data, &result); err != nil { - t.Fatalf("failed to unmarshal result: %v", err) - } - - // Verify nested MCP config is present - mcp, ok := result["mcp"].(map[string]interface{}) - if !ok { - t.Fatal("mcp should be a nested object in JSON output") - } - if mcp["timeout"] != float64(10) { - t.Errorf("expected mcp.timeout = 10, got %v", mcp["timeout"]) - } - - // Verify MCPServers map is present - servers, ok := result["mcp_servers"].(map[string]interface{}) - if !ok { - t.Fatal("mcp_servers should be a map in JSON output") - } - if _, exists := servers["test-server"]; !exists { - t.Error("expected test-server in mcp_servers") + t.Fatalf("unmarshaling result: %v", err) } // Verify SearXNG config - searxng, ok := result["searxng"].(map[string]interface{}) + searxng, ok := result["searxng"].(map[string]any) if !ok { t.Fatal("searxng should be a nested object") } if searxng["base_url"] != "http://localhost:8080" { - t.Errorf("expected searxng.base_url = 'http://localhost:8080', got %v", searxng["base_url"]) + t.Errorf("MarshalJSON() searxng.base_url = %v, want %q", searxng["base_url"], "http://localhost:8080") } // Verify Datadog config - datadog, ok := result["datadog"].(map[string]interface{}) + datadog, ok := result["datadog"].(map[string]any) if !ok { t.Fatal("datadog should be a nested object") } if datadog["environment"] != "test" { - t.Errorf("expected datadog.environment = 'test', got %v", datadog["environment"]) + t.Errorf("MarshalJSON() datadog.environment = %v, want %q", datadog["environment"], "test") } // CRITICAL: Verify sensitive field is still masked @@ -686,95 +635,6 @@ func TestConfig_MarshalJSON_NestedStructs(t *testing.T) { } } -// TestConfig_MarshalJSON_MCPServerEnvMasked verifies that MCPServer.Env (sensitive map) is masked -// SECURITY: MCPServer.Env commonly contains API keys, tokens, and secrets -func TestConfig_MarshalJSON_MCPServerEnvMasked(t *testing.T) { - cfg := Config{ - MCPServers: map[string]MCPServer{ - "github-mcp": { - Command: "npx", - Args: []string{"-y", "@modelcontextprotocol/server-github"}, - Env: map[string]string{ - "GITHUB_TOKEN": "ghp_supersecrettoken12345678", - "API_KEY": "sk-proj-secretapikey67890", - "OPENAI_API_KEY": "sk-openai-verysecretkey", - "ANTHROPIC_KEY": "anthropic-secret-key-xxx", - "DATABASE_PASSWORD": "dbpassword123", - }, - Timeout: 30, - }, - "another-server": { - Command: "node", - Env: map[string]string{ - "SECRET_TOKEN": "another_secret_value", - }, - }, - }, - } - - data, err := json.Marshal(cfg) - if err != nil { - t.Fatalf("MarshalJSON failed: %v", err) - } - - jsonStr := string(data) - - // CRITICAL: All secret values in Env must be masked - secrets := []string{ - "ghp_supersecrettoken12345678", - "sk-proj-secretapikey67890", - "sk-openai-verysecretkey", - "anthropic-secret-key-xxx", - "dbpassword123", - "another_secret_value", - } - - for _, secret := range secrets { - if strings.Contains(jsonStr, secret) { - t.Errorf("SECURITY: MCPServer.Env secret leaked in JSON output: %s", secret) - } - } - - // Verify the Env field is present but masked - var result map[string]interface{} - if err := json.Unmarshal(data, &result); err != nil { - t.Fatalf("failed to unmarshal result: %v", err) - } - - servers, ok := result["mcp_servers"].(map[string]interface{}) - if !ok { - t.Fatal("mcp_servers should be present in JSON output") - } - - githubServer, ok := servers["github-mcp"].(map[string]interface{}) - if !ok { - t.Fatal("github-mcp server should be present") - } - - // Env map values should be masked individually (keys visible, values masked) - env, ok := githubServer["env"].(map[string]interface{}) - if !ok { - t.Fatal("MCPServer.Env should be a map") - } - // Check that original secrets are not present - for _, v := range env { - strVal, ok := v.(string) - if !ok { - t.Error("Env values should be strings") - continue - } - // Masked values should contain maskedValue (████████) - if !strings.Contains(strVal, "████████") && strVal != "" { - t.Errorf("Env value should be masked, got: %s", strVal) - } - } - - // Verify non-sensitive fields are NOT masked - if githubServer["command"] != "npx" { - t.Error("non-sensitive field Command should not be masked") - } -} - // NOTE: TestConfig_MarshalJSON_AllSensitiveFields was removed because: // 1. We no longer use `sensitive:"true"` tags (replaced with explicit MarshalJSON) // 2. TestConfig_SensitiveFieldsAreMasked provides equivalent coverage @@ -821,7 +681,7 @@ func TestMaskSecret_Unicode(t *testing.T) { // Verify masking pattern is present (when expected) if tt.wantContains != "" && !strings.Contains(masked, tt.wantContains) { - t.Errorf("expected masked output to contain %q, got: %q", tt.wantContains, masked) + t.Errorf("maskSecret(%q) = %q, want contains %q", tt.input, masked, tt.wantContains) } // CRITICAL: Original value must NEVER appear in masked output @@ -866,7 +726,7 @@ func TestConfig_MarshalJSON_UnicodePasswords(t *testing.T) { data, err := json.Marshal(cfg) if err != nil { - t.Fatalf("MarshalJSON failed: %v", err) + t.Fatalf("MarshalJSON() unexpected error: %v", err) } jsonStr := string(data) @@ -878,7 +738,7 @@ func TestConfig_MarshalJSON_UnicodePasswords(t *testing.T) { // Verify masking was applied if !strings.Contains(jsonStr, "████████") { - t.Errorf("expected masked output to contain '████████', got: %s", jsonStr) + t.Errorf("MarshalJSON() output missing mask chars, got: %s", jsonStr) } }) } @@ -1082,19 +942,6 @@ func BenchmarkConfig_MarshalJSON(b *testing.B) { PostgresPort: 5432, PostgresUser: "koopa", PostgresDBName: "koopa", - MCP: MCPConfig{ - Timeout: 5, - Allowed: []string{"server1", "server2"}, - }, - MCPServers: map[string]MCPServer{ - "github": { - Command: "npx", - Args: []string{"-y", "@modelcontextprotocol/server-github"}, - Env: map[string]string{ - "GITHUB_TOKEN": "ghp_secrettoken12345", - }, - }, - }, } b.ResetTimer() @@ -1104,16 +951,40 @@ func BenchmarkConfig_MarshalJSON(b *testing.B) { } } +// TestFullModelName tests that FullModelName derives correct provider-qualified names. +func TestFullModelName(t *testing.T) { + tests := []struct { + name string + provider string + modelName string + want string + }{ + {name: "gemini default", provider: "", modelName: "gemini-2.5-flash", want: "googleai/gemini-2.5-flash"}, + {name: "gemini explicit", provider: "gemini", modelName: "gemini-2.5-pro", want: "googleai/gemini-2.5-pro"}, + {name: "ollama", provider: "ollama", modelName: "llama3.3", want: "ollama/llama3.3"}, + {name: "openai", provider: "openai", modelName: "gpt-4o", want: "openai/gpt-4o"}, + {name: "already qualified", provider: "ollama", modelName: "ollama/llama3.3", want: "ollama/llama3.3"}, + {name: "cross-qualified", provider: "gemini", modelName: "openai/gpt-4o", want: "openai/gpt-4o"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &Config{ + Provider: tt.provider, + ModelName: tt.modelName, + } + got := cfg.FullModelName() + if got != tt.want { + t.Errorf("FullModelName() = %q, want %q", got, tt.want) + } + }) + } +} + // BenchmarkConfig_MarshalJSON_Parallel benchmarks concurrent Config marshaling func BenchmarkConfig_MarshalJSON_Parallel(b *testing.B) { cfg := Config{ PostgresPassword: "supersecretpassword123", - MCPServers: map[string]MCPServer{ - "test": { - Command: "npx", - Env: map[string]string{"SECRET": "value"}, - }, - }, } b.ResetTimer() diff --git a/internal/config/observability.go b/internal/config/observability.go index 8ecf78a..c542795 100644 --- a/internal/config/observability.go +++ b/internal/config/observability.go @@ -8,7 +8,7 @@ import ( // DatadogConfig holds Datadog APM tracing configuration. // // Tracing uses the local Datadog Agent for OTLP ingestion. -// See internal/observability/datadog.go for detailed setup instructions. +// Setup is inlined in internal/app/setup.go (provideOtelShutdown). type DatadogConfig struct { // APIKey is the Datadog API key (optional, for observability) APIKey string `mapstructure:"api_key" json:"api_key"` diff --git a/internal/config/storage.go b/internal/config/storage.go index 5771070..f2230be 100644 --- a/internal/config/storage.go +++ b/internal/config/storage.go @@ -8,43 +8,39 @@ import ( "strings" ) -// StorageConfig documentation. -// Fields are embedded in the main Config struct for backward compatibility. -// -// PostgreSQL (for pgvector): -// - PostgresHost: Database host (default: localhost) -// - PostgresPort: Database port (default: 5432) -// - PostgresUser: Database user (default: koopa) -// - PostgresPassword: Database password -// - PostgresDBName: Database name (default: koopa) -// - PostgresSSLMode: SSL mode (default: disable) -// -// RAG: -// - RAGTopK: Number of documents to retrieve (1-10, default: 3) -// - EmbedderModel: Embedding model name (default: text-embedding-004) +// quoteDSNValue quotes a value for PostgreSQL key=value DSN format. +// Within single quotes, backslashes and single quotes are escaped. +// This prevents parsing errors when values contain spaces or special characters. +func quoteDSNValue(s string) string { + s = strings.ReplaceAll(s, `\`, `\\`) + s = strings.ReplaceAll(s, `'`, `\'`) + return "'" + s + "'" +} // PostgresConnectionString returns the PostgreSQL DSN for pgx driver. +// Password is single-quoted to handle special characters (spaces, =, quotes). func (c *Config) PostgresConnectionString() string { return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", c.PostgresHost, c.PostgresPort, c.PostgresUser, - c.PostgresPassword, + quoteDSNValue(c.PostgresPassword), c.PostgresDBName, c.PostgresSSLMode, ) } // PostgresURL returns the PostgreSQL URL for golang-migrate. +// Uses url.URL for proper encoding of special characters in credentials. func (c *Config) PostgresURL() string { - return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s", - c.PostgresUser, - c.PostgresPassword, - c.PostgresHost, - c.PostgresPort, - c.PostgresDBName, - c.PostgresSSLMode, - ) + u := &url.URL{ + Scheme: "postgres", + User: url.UserPassword(c.PostgresUser, c.PostgresPassword), + Host: fmt.Sprintf("%s:%d", c.PostgresHost, c.PostgresPort), + Path: c.PostgresDBName, + RawQuery: fmt.Sprintf("sslmode=%s", c.PostgresSSLMode), + } + return u.String() } // parseDatabaseURL parses DATABASE_URL environment variable and sets PostgreSQL config. diff --git a/internal/config/storage_test.go b/internal/config/storage_test.go index 22ddad4..8f43b45 100644 --- a/internal/config/storage_test.go +++ b/internal/config/storage_test.go @@ -23,7 +23,7 @@ func TestPostgresConnectionString(t *testing.T) { "host=test-host", "port=5433", "user=test-user", - "password=test-password", + "password='test-password'", "dbname=test-db", "sslmode=require", } @@ -35,22 +35,130 @@ func TestPostgresConnectionString(t *testing.T) { } } +// TestPostgresConnectionStringSpecialChars tests DSN quoting handles special characters. +// The password is single-quoted in DSN format, so injection payloads like +// "' host=evil.com" become part of the quoted value, not separate DSN keys. +func TestPostgresConnectionStringSpecialChars(t *testing.T) { + tests := []struct { + name string + password string + wantQuoted string // expected quoted form in DSN + }{ + { + name: "password with spaces", + password: "my secret pass", + wantQuoted: "'my secret pass'", + }, + { + name: "password with single quotes", + password: "pass'word", + wantQuoted: `'pass\'word'`, + }, + { + name: "DSN injection attempt", + password: "' host=evil.com user=hacker password='", + wantQuoted: `'\' host=evil.com user=hacker password=\''`, + }, + { + name: "password with backslash", + password: `pass\word`, + wantQuoted: `'pass\\word'`, + }, + { + name: "password with equals", + password: "pass=word", + wantQuoted: "'pass=word'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &Config{ + PostgresHost: "localhost", + PostgresPort: 5432, + PostgresUser: "user", + PostgresPassword: tt.password, + PostgresDBName: "db", + PostgresSSLMode: "disable", + } + dsn := cfg.PostgresConnectionString() + + // Verify password is correctly quoted in DSN + wantPart := "password=" + tt.wantQuoted + if !strings.Contains(dsn, wantPart) { + t.Errorf("PostgresConnectionString() missing %q\ngot: %s", wantPart, dsn) + } + + // Verify DSN ends with expected structure after password + if !strings.HasSuffix(dsn, "dbname=db sslmode=disable") { + t.Errorf("PostgresConnectionString() has corrupted suffix\ngot: %s", dsn) + } + }) + } +} + +// TestQuoteDSNValue tests the DSN value quoting helper. +func TestQuoteDSNValue(t *testing.T) { + tests := []struct { + input string + want string + }{ + {input: "simple", want: "'simple'"}, + {input: "with space", want: "'with space'"}, + {input: "with'quote", want: `'with\'quote'`}, + {input: `with\backslash`, want: `'with\\backslash'`}, + {input: "", want: "''"}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := quoteDSNValue(tt.input) + if got != tt.want { + t.Errorf("quoteDSNValue(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + // TestPostgresURL tests PostgreSQL URL generation for golang-migrate func TestPostgresURL(t *testing.T) { - cfg := &Config{ - PostgresHost: "test-host", - PostgresPort: 5433, - PostgresUser: "test-user", - PostgresPassword: "test-password", - PostgresDBName: "test-db", - PostgresSSLMode: "require", + tests := []struct { + name string + cfg *Config + want string + }{ + { + name: "simple credentials", + cfg: &Config{ + PostgresHost: "test-host", + PostgresPort: 5433, + PostgresUser: "test-user", + PostgresPassword: "test-password", + PostgresDBName: "test-db", + PostgresSSLMode: "require", + }, + want: "postgres://test-user:test-password@test-host:5433/test-db?sslmode=require", + }, + { + name: "password with special characters", + cfg: &Config{ + PostgresHost: "localhost", + PostgresPort: 5432, + PostgresUser: "koopa", + PostgresPassword: "p@ss/word#123", + PostgresDBName: "koopa", + PostgresSSLMode: "disable", + }, + want: "postgres://koopa:p%40ss%2Fword%23123@localhost:5432/koopa?sslmode=disable", + }, } - url := cfg.PostgresURL() - - expected := "postgres://test-user:test-password@test-host:5433/test-db?sslmode=require" - if url != expected { - t.Errorf("PostgresURL() = %q, want %q", url, expected) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.cfg.PostgresURL() + if got != tt.want { + t.Errorf("PostgresURL() = %q, want %q", got, tt.want) + } + }) } } diff --git a/internal/config/testdata/fuzz/FuzzConfigMarshalJSON/771e938e4458e983 b/internal/config/testdata/fuzz/FuzzConfigMarshalJSON/771e938e4458e983 deleted file mode 100644 index ee3f339..0000000 --- a/internal/config/testdata/fuzz/FuzzConfigMarshalJSON/771e938e4458e983 +++ /dev/null @@ -1,2 +0,0 @@ -go test fuzz v1 -string("0") diff --git a/internal/config/tools.go b/internal/config/tools.go index efb3a0e..4ceb5c8 100644 --- a/internal/config/tools.go +++ b/internal/config/tools.go @@ -1,46 +1,5 @@ package config -import ( - "encoding/json" - "fmt" -) - -// MCPConfig controls global MCP (Model Context Protocol) behavior. -type MCPConfig struct { - Allowed []string `mapstructure:"allowed" json:"allowed"` // Whitelist of server names (empty = all configured servers) - Excluded []string `mapstructure:"excluded" json:"excluded"` // Blacklist of server names (higher priority than Allowed) - Timeout int `mapstructure:"timeout" json:"timeout"` // Connection timeout in seconds (default: 5) -} - -// MCPServer defines a single MCP server configuration. -type MCPServer struct { - Command string `mapstructure:"command" json:"command"` // Required: executable path (e.g., "npx") - Args []string `mapstructure:"args" json:"args"` // Optional: command arguments - Env map[string]string `mapstructure:"env" json:"env"` // Optional: environment variables - SECURITY: May contain API keys/tokens - Timeout int `mapstructure:"timeout" json:"timeout"` // Optional: per-server timeout (overrides global) - IncludeTools []string `mapstructure:"include_tools" json:"include_tools"` // Optional: tool whitelist - ExcludeTools []string `mapstructure:"exclude_tools" json:"exclude_tools"` // Optional: tool blacklist -} - -// MarshalJSON implements json.Marshaler with sensitive field masking. -// Masks all values in the Env map as they may contain API keys/tokens. -func (m MCPServer) MarshalJSON() ([]byte, error) { - type alias MCPServer - a := alias(m) - if a.Env != nil { - maskedEnv := make(map[string]string, len(a.Env)) - for k, v := range a.Env { - maskedEnv[k] = maskSecret(v) - } - a.Env = maskedEnv - } - data, err := json.Marshal(a) - if err != nil { - return nil, fmt.Errorf("marshal mcp server: %w", err) - } - return data, nil -} - // SearXNGConfig holds SearXNG service configuration for web search. type SearXNGConfig struct { // BaseURL is the SearXNG instance URL (e.g., http://searxng:8080) diff --git a/internal/config/validation.go b/internal/config/validation.go index dd88a05..43d7e2a 100644 --- a/internal/config/validation.go +++ b/internal/config/validation.go @@ -8,7 +8,7 @@ import ( ) // supportedProviders lists all valid AI provider values. -var supportedProviders = []string{"gemini", "ollama", "openai"} +var supportedProviders = []string{ProviderGemini, ProviderOllama, ProviderOpenAI} // Validate validates configuration values. // Returns sentinel errors that can be checked with errors.Is(). @@ -54,17 +54,17 @@ func (c *Config) validateAI() error { } // Ollama host - if c.resolvedProvider() == "ollama" && c.OllamaHost == "" { + if c.resolvedProvider() == ProviderOllama && c.OllamaHost == "" { return fmt.Errorf("%w: ollama_host cannot be empty when provider is ollama", ErrInvalidOllamaHost) } - // RAG - if c.RAGTopK <= 0 || c.RAGTopK > 10 { - return fmt.Errorf("%w: must be between 1 and 10, got %d", ErrInvalidRAGTopK, c.RAGTopK) - } + // RAG embedder if c.EmbedderModel == "" { return fmt.Errorf("%w: embedder_model cannot be empty", ErrInvalidEmbedderModel) } + if err := c.validateEmbedder(); err != nil { + return err + } return nil } @@ -104,37 +104,98 @@ func (c *Config) validatePostgresSSL() error { ErrInvalidPostgresSSLMode) } if !slices.Contains(validSSLModes, c.PostgresSSLMode) { - return fmt.Errorf("%w: %q is not valid, must be one of: %v\n"+ - "Note: 'allow' and 'prefer' modes are deprecated (vulnerable to MITM attacks)", + return fmt.Errorf("%w: %q is not valid, must be one of: %v (allow/prefer excluded: MITM vulnerable)", ErrInvalidPostgresSSLMode, c.PostgresSSLMode, validSSLModes) } return nil } +// knownEmbedderDimensions maps provider → model → native output dimension. +// Used to catch dimension mismatches at startup before hitting pgvector errors. +var knownEmbedderDimensions = map[string]map[string]int{ + ProviderGemini: { + "gemini-embedding-001": 3072, + "text-embedding-004": 768, + }, + ProviderOpenAI: { + "text-embedding-3-small": 1536, + "text-embedding-3-large": 3072, + }, +} + +// requiredVectorDimension must match the pgvector schema: embedding vector(768). +const requiredVectorDimension = 768 + +// validateEmbedder checks that the configured embedder model produces vectors +// compatible with the database schema. For known models whose native dimension +// differs from requiredVectorDimension, this returns an error so operators +// know to set OutputDimensionality (handled by rag.NewDocStoreConfig). +// +// Unknown providers or models pass validation silently — the operator may +// know what they are doing (e.g., a custom Ollama embedder producing 768-dim). +func (c *Config) validateEmbedder() error { + models, ok := knownEmbedderDimensions[c.resolvedProvider()] + if !ok { + return nil // unknown provider (e.g., ollama) — skip + } + dim, ok := models[c.EmbedderModel] + if !ok { + return nil // unknown model — skip + } + if dim != requiredVectorDimension { + slog.Warn("embedder native dimension differs from schema", + "model", c.EmbedderModel, + "native_dim", dim, + "schema_dim", requiredVectorDimension, + "note", "rag.NewDocStoreConfig truncates output via OutputDimensionality") + } + return nil +} + // resolvedProvider returns the effective provider, defaulting to "gemini". func (c *Config) resolvedProvider() string { if c.Provider == "" { - return "gemini" + return ProviderGemini } return c.Provider } +// ValidateServe validates configuration specific to serve mode. +// HMAC_SECRET is required for CSRF protection in HTTP mode. +func (c *Config) ValidateServe() error { + if err := c.Validate(); err != nil { + return err + } + if c.HMACSecret == "" { + return fmt.Errorf("%w: HMAC_SECRET environment variable is required for serve mode (min 32 characters)", + ErrMissingHMACSecret) + } + if len(c.HMACSecret) < 32 { + return fmt.Errorf("%w: must be at least 32 characters, got %d", + ErrInvalidHMACSecret, len(c.HMACSecret)) + } + if c.TrustProxy { + slog.Warn("trust_proxy is enabled — ensure this server is behind a reverse proxy") + } + return nil +} + // validateProviderAPIKey checks that the required API key is set for the configured provider. func (c *Config) validateProviderAPIKey() error { switch c.resolvedProvider() { - case "gemini": + case ProviderGemini: if os.Getenv("GEMINI_API_KEY") == "" { return fmt.Errorf("%w: GEMINI_API_KEY environment variable is required for provider %q\n"+ "Get your API key at: https://ai.google.dev/gemini-api/docs/api-key", ErrMissingAPIKey, c.resolvedProvider()) } - case "openai": + case ProviderOpenAI: if os.Getenv("OPENAI_API_KEY") == "" { return fmt.Errorf("%w: OPENAI_API_KEY environment variable is required for provider %q\n"+ "Get your API key at: https://platform.openai.com/api-keys", ErrMissingAPIKey, c.resolvedProvider()) } - case "ollama": + case ProviderOllama: // Ollama runs locally, no API key required } return nil diff --git a/internal/config/validation_test.go b/internal/config/validation_test.go index 7d0d826..05a2799 100644 --- a/internal/config/validation_test.go +++ b/internal/config/validation_test.go @@ -14,8 +14,7 @@ func validBaseConfig(provider string) *Config { ModelName: "gemini-2.5-flash", Temperature: 0.7, MaxTokens: 2048, - RAGTopK: 3, - EmbedderModel: "text-embedding-004", + EmbedderModel: "gemini-embedding-001", PostgresHost: "localhost", PostgresPort: 5432, PostgresPassword: "test_password", @@ -39,12 +38,12 @@ func setEnvForProvider(t *testing.T, provider string) func() { switch provider { case "gemini", "": if err := os.Setenv("GEMINI_API_KEY", "test-api-key"); err != nil { - t.Fatalf("Failed to set GEMINI_API_KEY: %v", err) + t.Fatalf("setting GEMINI_API_KEY: %v", err) } return func() { os.Unsetenv("GEMINI_API_KEY") } case "openai": if err := os.Setenv("OPENAI_API_KEY", "test-openai-key"); err != nil { - t.Fatalf("Failed to set OPENAI_API_KEY: %v", err) + t.Fatalf("setting OPENAI_API_KEY: %v", err) } return func() { os.Unsetenv("OPENAI_API_KEY") } case "ollama": @@ -69,7 +68,7 @@ func TestValidateSuccess(t *testing.T) { cfg := validBaseConfig(provider) if err := cfg.Validate(); err != nil { - t.Errorf("Validate() failed with valid config (provider %q): %v", provider, err) + t.Errorf("Validate() unexpected error with valid config (provider %q): %v", provider, err) } }) } @@ -230,43 +229,6 @@ func TestValidateOllamaHost(t *testing.T) { } } -// TestValidateRAGTopK tests RAG top K validation. -func TestValidateRAGTopK(t *testing.T) { - cleanup := setEnvForProvider(t, "gemini") - defer cleanup() - - tests := []struct { - name string - ragTopK int - wantErr bool - }{ - {name: "valid min", ragTopK: 1}, - {name: "valid mid", ragTopK: 5}, - {name: "valid max", ragTopK: 10}, - {name: "invalid zero", ragTopK: 0, wantErr: true}, - {name: "invalid negative", ragTopK: -1, wantErr: true}, - {name: "invalid too high", ragTopK: 11, wantErr: true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cfg := validBaseConfig("gemini") - cfg.RAGTopK = tt.ragTopK - - err := cfg.Validate() - if tt.wantErr && err == nil { - t.Errorf("expected error for rag_top_k %d, got nil", tt.ragTopK) - } - if !tt.wantErr && err != nil { - t.Errorf("unexpected error for rag_top_k %d: %v", tt.ragTopK, err) - } - if tt.wantErr && err != nil && !errors.Is(err, ErrInvalidRAGTopK) { - t.Errorf("error should be ErrInvalidRAGTopK, got: %v", err) - } - }) - } -} - // TestValidateEmbedderModel tests embedder model validation. func TestValidateEmbedderModel(t *testing.T) { cleanup := setEnvForProvider(t, "gemini") @@ -279,8 +241,8 @@ func TestValidateEmbedderModel(t *testing.T) { if err == nil { t.Fatal("expected error for empty embedder_model, got nil") } - if !strings.Contains(err.Error(), "embedder_model") { - t.Errorf("error should mention embedder_model, got: %v", err) + if !errors.Is(err, ErrInvalidEmbedderModel) { + t.Errorf("Validate() error = %v, want ErrInvalidEmbedderModel", err) } } @@ -442,14 +404,14 @@ func TestValidatePostgresSSLMode(t *testing.T) { // BenchmarkValidate benchmarks configuration validation. func BenchmarkValidate(b *testing.B) { if err := os.Setenv("GEMINI_API_KEY", "test-key"); err != nil { - b.Fatalf("Failed to set GEMINI_API_KEY: %v", err) + b.Fatalf("setting GEMINI_API_KEY: %v", err) } defer os.Unsetenv("GEMINI_API_KEY") cfg := validBaseConfig("gemini") if err := cfg.Validate(); err != nil { - b.Fatalf("Validate() failed: %v", err) + b.Fatalf("Validate() unexpected error: %v", err) } b.ResetTimer() diff --git a/internal/log/log.go b/internal/log/log.go deleted file mode 100644 index 099922b..0000000 --- a/internal/log/log.go +++ /dev/null @@ -1,109 +0,0 @@ -// Package log provides a unified logging infrastructure for the koopa application. -// -// This package provides: -// - A type alias for *slog.Logger to use as DI dependency -// - Factory functions to create configured loggers -// - A Nop logger for testing -// -// Design Philosophy: -// - Use Dependency Injection (DI) for loggers, not globals -// - Each component receives a logger via constructor -// - Components can add context via logger.With() -// -// Usage: -// -// // Create a logger at application startup -// logger := log.New(log.Config{Level: slog.LevelDebug}) -// -// // Inject into components with context -// fileToolset := tools.NewFileToolset(pathVal, logger.With("component", "file")) -// agent := agent.New(logger.With("component", "agent"), ...) -// -// // In tests, use Nop logger or capture to buffer -// testLogger := log.NewNop() -// // or -// var buf bytes.Buffer -// testLogger := log.NewWithWriter(&buf, log.Config{}) -package log - -import ( - "io" - "log/slog" - "os" -) - -// Logger is a type alias for *slog.Logger. -// Using the standard library type directly provides: -// - Full compatibility with slog ecosystem -// - Access to With() for adding context -// - No need for custom interface definitions -// -// Components should accept log.Logger as a dependency. -type Logger = *slog.Logger - -// Config defines logger configuration options. -type Config struct { - // Level sets the minimum log level. Default: slog.LevelInfo - Level slog.Level - - // JSON enables JSON format output. Default: false (text format) - JSON bool - - // AddSource adds source file information to log entries. Default: false - AddSource bool -} - -// New creates a new logger with the given configuration. -// Output is written to os.Stderr by default. -// -// Example: -// -// logger := log.New(log.Config{ -// Level: slog.LevelDebug, -// JSON: true, -// }) -func New(cfg Config) Logger { - return NewWithWriter(os.Stderr, cfg) -} - -// NewWithWriter creates a new logger that writes to the specified writer. -// Useful for testing or custom output destinations. -// -// Example: -// -// var buf bytes.Buffer -// logger := log.NewWithWriter(&buf, log.Config{}) -// // ... use logger -// fmt.Println(buf.String()) // inspect log output -func NewWithWriter(w io.Writer, cfg Config) Logger { - opts := &slog.HandlerOptions{ - Level: cfg.Level, - AddSource: cfg.AddSource, - } - - var handler slog.Handler - if cfg.JSON { - handler = slog.NewJSONHandler(w, opts) - } else { - handler = slog.NewTextHandler(w, opts) - } - - return slog.New(handler) -} - -// NewNop creates a logger that discards all output. -// -// WARNING: This should ONLY be used in tests. Never use NewNop() in production -// code as it will silently discard all logs, making debugging impossible. -// Production code should always use New() or NewWithWriter() with proper configuration. -// -// Example: -// -// func TestSomething(t *testing.T) { -// logger := log.NewNop() -// sut := NewMyComponent(logger) -// // ... test without log noise -// } -func NewNop() Logger { - return slog.New(slog.DiscardHandler) -} diff --git a/internal/log/log_test.go b/internal/log/log_test.go deleted file mode 100644 index 0a79fd6..0000000 --- a/internal/log/log_test.go +++ /dev/null @@ -1,128 +0,0 @@ -package log - -import ( - "bytes" - "log/slog" - "strings" - "testing" -) - -func TestNew(t *testing.T) { - logger := New(Config{}) - if logger == nil { - t.Fatal("New() returned nil") - } -} - -func TestNewWithWriter(t *testing.T) { - var buf bytes.Buffer - - logger := NewWithWriter(&buf, Config{ - Level: slog.LevelDebug, - }) - - logger.Info("test message", "key", "value") - - output := buf.String() - if !strings.Contains(output, "test message") { - t.Errorf("expected output to contain 'test message', got: %s", output) - } - if !strings.Contains(output, "key=value") { - t.Errorf("expected output to contain 'key=value', got: %s", output) - } -} - -func TestNewWithWriter_JSON(t *testing.T) { - var buf bytes.Buffer - - logger := NewWithWriter(&buf, Config{ - Level: slog.LevelInfo, - JSON: true, - }) - - logger.Info("json test", "foo", "bar") - - output := buf.String() - if !strings.Contains(output, `"msg":"json test"`) { - t.Errorf("expected JSON output with msg field, got: %s", output) - } -} - -func TestNewNop(t *testing.T) { - logger := NewNop() - if logger == nil { - t.Fatal("NewNop() returned nil") - } - - // Should not panic - logger.Info("this should be discarded") - logger.Error("this too") -} - -func TestLogger_With(t *testing.T) { - var buf bytes.Buffer - - logger := NewWithWriter(&buf, Config{ - Level: slog.LevelInfo, - }) - - // Add component context - componentLogger := logger.With("component", "test") - componentLogger.Info("component log") - - output := buf.String() - if !strings.Contains(output, "component=test") { - t.Errorf("expected output to contain 'component=test', got: %s", output) - } -} - -func TestLogger_Levels(t *testing.T) { - var buf bytes.Buffer - - logger := NewWithWriter(&buf, Config{ - Level: slog.LevelDebug, - }) - - logger.Debug("debug msg") - logger.Info("info msg") - logger.Warn("warn msg") - logger.Error("error msg") - - output := buf.String() - - levels := []string{"DEBUG", "INFO", "WARN", "ERROR"} - for _, level := range levels { - if !strings.Contains(output, level) { - t.Errorf("expected output to contain %s level", level) - } - } -} - -func TestLogger_LevelFiltering(t *testing.T) { - var buf bytes.Buffer - - // Only INFO and above - logger := NewWithWriter(&buf, Config{ - Level: slog.LevelInfo, - }) - - logger.Debug("debug should not appear") - logger.Info("info should appear") - - output := buf.String() - - if strings.Contains(output, "debug should not appear") { - t.Error("DEBUG message should be filtered out") - } - if !strings.Contains(output, "info should appear") { - t.Error("INFO message should appear") - } -} - -func TestLoggerTypeAlias(t *testing.T) { - // Verify that Logger is compatible with *slog.Logger - logger := slog.Default() - if logger == nil { - t.Fatal("Logger type alias should be compatible with *slog.Logger") - } -} diff --git a/internal/mcp/benchmark_test.go b/internal/mcp/benchmark_test.go index afec82a..cf0791d 100644 --- a/internal/mcp/benchmark_test.go +++ b/internal/mcp/benchmark_test.go @@ -16,46 +16,48 @@ import ( // BenchmarkServer_Creation benchmarks MCP server creation. // Run with: go test -bench=BenchmarkServer_Creation -benchmem ./internal/mcp/... func BenchmarkServer_Creation(b *testing.B) { - for b.Loop() { - tmpDir := b.TempDir() - pathVal, err := security.NewPath([]string{tmpDir}) - if err != nil { - b.Fatalf("Failed to create path validator: %v", err) - } + // Setup toolsets once — benchmark only NewServer + tool registration. + tmpDir := b.TempDir() + pathVal, err := security.NewPath([]string{tmpDir}) + if err != nil { + b.Fatalf("creating path validator: %v", err) + } - fileTools, err := tools.NewFileTools(pathVal, slog.Default()) - if err != nil { - b.Fatalf("Failed to create file tools: %v", err) - } + fileTools, err := tools.NewFile(pathVal, slog.Default()) + if err != nil { + b.Fatalf("creating file tools: %v", err) + } - cmdVal := security.NewCommand() - envVal := security.NewEnv() - systemTools, err := tools.NewSystemTools(cmdVal, envVal, slog.Default()) - if err != nil { - b.Fatalf("Failed to create system tools: %v", err) - } + cmdVal := security.NewCommand() + envVal := security.NewEnv() + systemTools, err := tools.NewSystem(cmdVal, envVal, slog.Default()) + if err != nil { + b.Fatalf("creating system tools: %v", err) + } - networkTools, err := tools.NewNetworkTools(tools.NetworkConfig{ - SearchBaseURL: "http://localhost:8080", - FetchParallelism: 2, - FetchDelay: 100 * time.Millisecond, - FetchTimeout: 30 * time.Second, - }, slog.Default()) - if err != nil { - b.Fatalf("Failed to create network tools: %v", err) - } + networkTools, err := tools.NewNetwork(tools.NetConfig{ + SearchBaseURL: "http://localhost:8080", + FetchParallelism: 2, + FetchDelay: 100 * time.Millisecond, + FetchTimeout: 30 * time.Second, + }, slog.Default()) + if err != nil { + b.Fatalf("creating network tools: %v", err) + } - cfg := Config{ - Name: "benchmark-server", - Version: "1.0.0", - FileTools: fileTools, - SystemTools: systemTools, - NetworkTools: networkTools, - } + cfg := Config{ + Name: "benchmark-server", + Version: "1.0.0", + File: fileTools, + System: systemTools, + Network: networkTools, + } + b.ResetTimer() + for b.Loop() { _, err = NewServer(cfg) if err != nil { - b.Fatalf("NewServer failed: %v", err) + b.Fatalf("NewServer(): %v", err) } } } @@ -70,7 +72,7 @@ func BenchmarkJSONRPC_Parse(b *testing.B) { for b.Loop() { var request map[string]any if err := json.Unmarshal([]byte(requestJSON), &request); err != nil { - b.Fatalf("JSON unmarshal failed: %v", err) + b.Fatalf("unmarshaling JSON: %v", err) } } } @@ -96,7 +98,7 @@ func BenchmarkJSONRPC_Parse_LargePayload(b *testing.B) { for b.Loop() { var parsed map[string]any if err := json.Unmarshal(responseJSON, &parsed); err != nil { - b.Fatalf("JSON unmarshal failed: %v", err) + b.Fatalf("unmarshaling JSON: %v", err) } } } @@ -120,7 +122,7 @@ func BenchmarkJSONRPC_Serialize(b *testing.B) { for b.Loop() { _, err := json.Marshal(response) if err != nil { - b.Fatalf("JSON marshal failed: %v", err) + b.Fatalf("marshaling JSON: %v", err) } } } @@ -144,7 +146,7 @@ func BenchmarkJSONRPC_Serialize_LargePayload(b *testing.B) { for b.Loop() { _, err := json.Marshal(response) if err != nil { - b.Fatalf("JSON marshal failed: %v", err) + b.Fatalf("marshaling JSON: %v", err) } } } @@ -157,7 +159,7 @@ func BenchmarkReadFileInput_Parse(b *testing.B) { for b.Loop() { var input tools.ReadFileInput if err := json.Unmarshal([]byte(inputJSON), &input); err != nil { - b.Fatalf("JSON unmarshal failed: %v", err) + b.Fatalf("unmarshaling JSON: %v", err) } } } @@ -167,37 +169,37 @@ func BenchmarkConfig_Validation(b *testing.B) { tmpDir := b.TempDir() pathVal, err := security.NewPath([]string{tmpDir}) if err != nil { - b.Fatalf("Failed to create path validator: %v", err) + b.Fatalf("creating path validator: %v", err) } - fileTools, err := tools.NewFileTools(pathVal, slog.Default()) + fileTools, err := tools.NewFile(pathVal, slog.Default()) if err != nil { - b.Fatalf("Failed to create file tools: %v", err) + b.Fatalf("creating file tools: %v", err) } cmdVal := security.NewCommand() envVal := security.NewEnv() - systemTools, err := tools.NewSystemTools(cmdVal, envVal, slog.Default()) + systemTools, err := tools.NewSystem(cmdVal, envVal, slog.Default()) if err != nil { - b.Fatalf("Failed to create system tools: %v", err) + b.Fatalf("creating system tools: %v", err) } - networkTools, err := tools.NewNetworkTools(tools.NetworkConfig{ + networkTools, err := tools.NewNetwork(tools.NetConfig{ SearchBaseURL: "http://localhost:8080", FetchParallelism: 2, FetchDelay: 100 * time.Millisecond, FetchTimeout: 30 * time.Second, }, slog.Default()) if err != nil { - b.Fatalf("Failed to create network tools: %v", err) + b.Fatalf("creating network tools: %v", err) } cfg := Config{ - Name: "validation-test", - Version: "1.0.0", - FileTools: fileTools, - SystemTools: systemTools, - NetworkTools: networkTools, + Name: "validation-test", + Version: "1.0.0", + File: fileTools, + System: systemTools, + Network: networkTools, } b.ResetTimer() diff --git a/internal/mcp/doc.go b/internal/mcp/doc.go index 5534034..7247cc8 100644 --- a/internal/mcp/doc.go +++ b/internal/mcp/doc.go @@ -21,7 +21,7 @@ // +-- Handler Methods (ReadFile, WriteFile, ExecuteCommand, ...) // | // v -// Toolsets (FileTools, SystemTools, NetworkTools, KnowledgeTools) +// Toolsets (File, System, Network, Knowledge) // | // v // Execution Results @@ -34,7 +34,7 @@ // // Network tools (2): web_search, web_fetch // -// Knowledge tools (3, optional): search_history, search_documents, search_system_knowledge +// Knowledge tools (3-4, optional): search_history, search_documents, search_system_knowledge, knowledge_store // // # Tool Handler Pattern // diff --git a/internal/mcp/file.go b/internal/mcp/file.go index b13024b..a4b35bc 100644 --- a/internal/mcp/file.go +++ b/internal/mcp/file.go @@ -10,61 +10,73 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -// registerFileTools registers all file operation tools to the MCP server. +// registerFile registers all file operation tools to the MCP server. // Tools: read_file, write_file, list_files, delete_file, get_file_info -func (s *Server) registerFileTools() error { +func (s *Server) registerFile() error { // read_file readFileSchema, err := jsonschema.For[tools.ReadFileInput](nil) if err != nil { - return fmt.Errorf("schema for %s: %w", tools.ToolReadFile, err) + return fmt.Errorf("schema for %s: %w", tools.ReadFileName, err) } mcp.AddTool(s.mcpServer, &mcp.Tool{ - Name: tools.ToolReadFile, - Description: "Read the complete content of any text-based file.", + Name: tools.ReadFileName, + Description: "Read the complete content of a text-based file. " + + "Use this to examine source code, configuration files, logs, or documentation. " + + "Supports files up to 10MB. Binary files are not supported and will return an error. " + + "Returns: file path, content (UTF-8), size in bytes, and line count.", InputSchema: readFileSchema, }, s.ReadFile) // write_file writeFileSchema, err := jsonschema.For[tools.WriteFileInput](nil) if err != nil { - return fmt.Errorf("schema for %s: %w", tools.ToolWriteFile, err) + return fmt.Errorf("schema for %s: %w", tools.WriteFileName, err) } mcp.AddTool(s.mcpServer, &mcp.Tool{ - Name: tools.ToolWriteFile, - Description: "Write or create any text-based file.", + Name: tools.WriteFileName, + Description: "Write or create a text-based file with the specified content. " + + "Creates parent directories automatically if they don't exist. " + + "Overwrites existing files without confirmation. " + + "Returns: file path, bytes written, whether file was created or updated.", InputSchema: writeFileSchema, }, s.WriteFile) // list_files listFilesSchema, err := jsonschema.For[tools.ListFilesInput](nil) if err != nil { - return fmt.Errorf("schema for %s: %w", tools.ToolListFiles, err) + return fmt.Errorf("schema for %s: %w", tools.ListFilesName, err) } mcp.AddTool(s.mcpServer, &mcp.Tool{ - Name: tools.ToolListFiles, - Description: "List all files and subdirectories in a directory.", + Name: tools.ListFilesName, + Description: "List files and subdirectories in a directory. " + + "Returns file names, sizes, types (file/directory), and modification times. " + + "Does not recurse into subdirectories.", InputSchema: listFilesSchema, }, s.ListFiles) // delete_file deleteFileSchema, err := jsonschema.For[tools.DeleteFileInput](nil) if err != nil { - return fmt.Errorf("schema for %s: %w", tools.ToolDeleteFile, err) + return fmt.Errorf("schema for %s: %w", tools.DeleteFileName, err) } mcp.AddTool(s.mcpServer, &mcp.Tool{ - Name: tools.ToolDeleteFile, - Description: "Delete a file permanently.", + Name: tools.DeleteFileName, + Description: "Permanently delete a file or empty directory. " + + "WARNING: This action cannot be undone. " + + "Only deletes empty directories.", InputSchema: deleteFileSchema, }, s.DeleteFile) // get_file_info getFileInfoSchema, err := jsonschema.For[tools.GetFileInfoInput](nil) if err != nil { - return fmt.Errorf("schema for %s: %w", tools.ToolGetFileInfo, err) + return fmt.Errorf("schema for %s: %w", tools.FileInfoName, err) } mcp.AddTool(s.mcpServer, &mcp.Tool{ - Name: tools.ToolGetFileInfo, - Description: "Get detailed metadata about a file.", + Name: tools.FileInfoName, + Description: "Get detailed metadata about a file without reading its contents. " + + "Returns: file size, modification time, permissions, and type (file/directory). " + + "More efficient than read_file when you only need metadata.", InputSchema: getFileInfoSchema, }, s.GetFileInfo) @@ -74,54 +86,54 @@ func (s *Server) registerFileTools() error { // ReadFile handles the readFile MCP tool call. func (s *Server) ReadFile(ctx context.Context, _ *mcp.CallToolRequest, input tools.ReadFileInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} - result, err := s.fileTools.ReadFile(toolCtx, input) + result, err := s.file.ReadFile(toolCtx, input) if err != nil { - return nil, nil, fmt.Errorf("readFile failed: %w", err) + return nil, nil, fmt.Errorf("reading file: %w", err) } - return resultToMCP(result), nil, nil + return resultToMCP(result, s.logger), nil, nil } // WriteFile handles the writeFile MCP tool call. func (s *Server) WriteFile(ctx context.Context, _ *mcp.CallToolRequest, input tools.WriteFileInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} - result, err := s.fileTools.WriteFile(toolCtx, input) + result, err := s.file.WriteFile(toolCtx, input) if err != nil { - return nil, nil, fmt.Errorf("writeFile failed: %w", err) + return nil, nil, fmt.Errorf("writing file: %w", err) } - return resultToMCP(result), nil, nil + return resultToMCP(result, s.logger), nil, nil } // ListFiles handles the listFiles MCP tool call. func (s *Server) ListFiles(ctx context.Context, _ *mcp.CallToolRequest, input tools.ListFilesInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} - result, err := s.fileTools.ListFiles(toolCtx, input) + result, err := s.file.ListFiles(toolCtx, input) if err != nil { - return nil, nil, fmt.Errorf("listFiles failed: %w", err) + return nil, nil, fmt.Errorf("listing files: %w", err) } - return resultToMCP(result), nil, nil + return resultToMCP(result, s.logger), nil, nil } // DeleteFile handles the deleteFile MCP tool call. func (s *Server) DeleteFile(ctx context.Context, _ *mcp.CallToolRequest, input tools.DeleteFileInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} - result, err := s.fileTools.DeleteFile(toolCtx, input) + result, err := s.file.DeleteFile(toolCtx, input) if err != nil { - return nil, nil, fmt.Errorf("deleteFile failed: %w", err) + return nil, nil, fmt.Errorf("deleting file: %w", err) } - return resultToMCP(result), nil, nil + return resultToMCP(result, s.logger), nil, nil } // GetFileInfo handles the getFileInfo MCP tool call. func (s *Server) GetFileInfo(ctx context.Context, _ *mcp.CallToolRequest, input tools.GetFileInfoInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} - result, err := s.fileTools.GetFileInfo(toolCtx, input) + result, err := s.file.GetFileInfo(toolCtx, input) if err != nil { - return nil, nil, fmt.Errorf("getFileInfo failed: %w", err) + return nil, nil, fmt.Errorf("getting file info: %w", err) } - return resultToMCP(result), nil, nil + return resultToMCP(result, s.logger), nil, nil } diff --git a/internal/mcp/file_test.go b/internal/mcp/file_test.go index deff7f6..c8b4efc 100644 --- a/internal/mcp/file_test.go +++ b/internal/mcp/file_test.go @@ -16,14 +16,14 @@ func TestReadFile_Success(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } // Create test file testFile := filepath.Join(h.tempDir, "test.txt") testContent := "hello world" if err := os.WriteFile(testFile, []byte(testContent), 0o600); err != nil { - t.Fatalf("failed to create test file: %v", err) + t.Fatalf("creating test file: %v", err) } // Call ReadFile handler @@ -32,7 +32,7 @@ func TestReadFile_Success(t *testing.T) { }) if err != nil { - t.Fatalf("ReadFile failed: %v", err) + t.Fatalf("ReadFile(): %v", err) } if result.IsError { @@ -51,7 +51,7 @@ func TestReadFile_FileNotFound(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } // Call ReadFile with non-existent file @@ -74,7 +74,7 @@ func TestReadFile_SecurityViolation(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } // Try to read file outside allowed directory @@ -97,7 +97,7 @@ func TestWriteFile_Success(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } testFile := filepath.Join(h.tempDir, "write_test.txt") @@ -109,7 +109,7 @@ func TestWriteFile_Success(t *testing.T) { }) if err != nil { - t.Fatalf("WriteFile failed: %v", err) + t.Fatalf("WriteFile(): %v", err) } if result.IsError { @@ -119,7 +119,7 @@ func TestWriteFile_Success(t *testing.T) { // Verify file was created content, err := os.ReadFile(testFile) if err != nil { - t.Fatalf("failed to read written file: %v", err) + t.Fatalf("reading written file: %v", err) } if string(content) != testContent { @@ -133,7 +133,7 @@ func TestWriteFile_CreatesDirectory(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } // Write to a file in a non-existent subdirectory @@ -145,7 +145,7 @@ func TestWriteFile_CreatesDirectory(t *testing.T) { }) if err != nil { - t.Fatalf("WriteFile failed: %v", err) + t.Fatalf("WriteFile(): %v", err) } if result.IsError { @@ -164,19 +164,19 @@ func TestListFiles_Success(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } // Create some test files for _, name := range []string{"file1.txt", "file2.txt"} { if err := os.WriteFile(filepath.Join(h.tempDir, name), []byte("test"), 0o600); err != nil { - t.Fatalf("failed to create test file: %v", err) + t.Fatalf("creating test file: %v", err) } } // Create a subdirectory if err := os.Mkdir(filepath.Join(h.tempDir, "subdir"), 0o750); err != nil { - t.Fatalf("failed to create subdirectory: %v", err) + t.Fatalf("creating subdirectory: %v", err) } result, _, err := server.ListFiles(context.Background(), &mcp.CallToolRequest{}, tools.ListFilesInput{ @@ -184,7 +184,7 @@ func TestListFiles_Success(t *testing.T) { }) if err != nil { - t.Fatalf("ListFiles failed: %v", err) + t.Fatalf("ListFiles(): %v", err) } if result.IsError { @@ -198,7 +198,7 @@ func TestListFiles_DirectoryNotFound(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } result, _, err := server.ListFiles(context.Background(), &mcp.CallToolRequest{}, tools.ListFilesInput{ @@ -220,13 +220,13 @@ func TestDeleteFile_Success(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } // Create test file testFile := filepath.Join(h.tempDir, "to_delete.txt") if err := os.WriteFile(testFile, []byte("delete me"), 0o600); err != nil { - t.Fatalf("failed to create test file: %v", err) + t.Fatalf("creating test file: %v", err) } result, _, err := server.DeleteFile(context.Background(), &mcp.CallToolRequest{}, tools.DeleteFileInput{ @@ -234,7 +234,7 @@ func TestDeleteFile_Success(t *testing.T) { }) if err != nil { - t.Fatalf("DeleteFile failed: %v", err) + t.Fatalf("DeleteFile(): %v", err) } if result.IsError { @@ -253,7 +253,7 @@ func TestDeleteFile_FileNotFound(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } result, _, err := server.DeleteFile(context.Background(), &mcp.CallToolRequest{}, tools.DeleteFileInput{ @@ -275,14 +275,14 @@ func TestGetFileInfo_Success(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } // Create test file testFile := filepath.Join(h.tempDir, "info_test.txt") testContent := "test content for info" if err := os.WriteFile(testFile, []byte(testContent), 0o600); err != nil { - t.Fatalf("failed to create test file: %v", err) + t.Fatalf("creating test file: %v", err) } result, _, err := server.GetFileInfo(context.Background(), &mcp.CallToolRequest{}, tools.GetFileInfoInput{ @@ -290,7 +290,7 @@ func TestGetFileInfo_Success(t *testing.T) { }) if err != nil { - t.Fatalf("GetFileInfo failed: %v", err) + t.Fatalf("GetFileInfo(): %v", err) } if result.IsError { @@ -304,7 +304,7 @@ func TestGetFileInfo_FileNotFound(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } result, _, err := server.GetFileInfo(context.Background(), &mcp.CallToolRequest{}, tools.GetFileInfoInput{ diff --git a/internal/mcp/integration_test.go b/internal/mcp/integration_test.go index 1e0a42b..3d64418 100644 --- a/internal/mcp/integration_test.go +++ b/internal/mcp/integration_test.go @@ -31,35 +31,35 @@ func createIntegrationTestConfig(t *testing.T, name string) Config { t.Fatalf("security.NewPath(%q) unexpected error: %v", realTmpDir, err) } - fileTools, err := tools.NewFileTools(pathVal, slog.Default()) + file, err := tools.NewFile(pathVal, slog.Default()) if err != nil { - t.Fatalf("tools.NewFileTools() unexpected error: %v", err) + t.Fatalf("tools.NewFile() unexpected error: %v", err) } cmdVal := security.NewCommand() envVal := security.NewEnv() - systemTools, err := tools.NewSystemTools(cmdVal, envVal, slog.Default()) + system, err := tools.NewSystem(cmdVal, envVal, slog.Default()) if err != nil { - t.Fatalf("tools.NewSystemTools() unexpected error: %v", err) + t.Fatalf("tools.NewSystem() unexpected error: %v", err) } - networkCfg := tools.NetworkConfig{ + networkCfg := tools.NetConfig{ SearchBaseURL: "http://localhost:8080", FetchParallelism: 2, FetchDelay: 100 * time.Millisecond, FetchTimeout: 30 * time.Second, } - networkTools, err := tools.NewNetworkTools(networkCfg, slog.Default()) + network, err := tools.NewNetwork(networkCfg, slog.Default()) if err != nil { - t.Fatalf("tools.NewNetworkTools() unexpected error: %v", err) + t.Fatalf("tools.NewNetwork() unexpected error: %v", err) } return Config{ - Name: name, - Version: "1.0.0", - FileTools: fileTools, - SystemTools: systemTools, - NetworkTools: networkTools, + Name: name, + Version: "1.0.0", + File: file, + System: system, + Network: network, } } @@ -69,24 +69,29 @@ func createIntegrationTestConfig(t *testing.T, name string) Config { // Run with: go test -race ./internal/mcp/... func TestServer_ConcurrentCreation(t *testing.T) { const numGoroutines = 10 + + // Pre-create configs outside goroutines — createIntegrationTestConfig + // calls t.Fatalf which is undefined behavior from goroutines. + configs := make([]Config, numGoroutines) + for i := range configs { + configs[i] = createIntegrationTestConfig(t, "race-test-server") + } + var wg sync.WaitGroup errors := make(chan error, numGoroutines) servers := make(chan *Server, numGoroutines) for i := 0; i < numGoroutines; i++ { wg.Add(1) - go func(id int) { + go func(cfg Config) { defer wg.Done() - - cfg := createIntegrationTestConfig(t, "race-test-server") - server, err := NewServer(cfg) if err != nil { errors <- err return } servers <- server - }(i) + }(configs[i]) } wg.Wait() @@ -143,11 +148,9 @@ func TestServer_ConcurrentToolsetAccess(t *testing.T) { go func() { defer wg.Done() // Access toolset fields (read-only) - _ = server.fileTools - _ = server.systemTools - _ = server.networkTools - _ = server.name - _ = server.version + _ = server.file + _ = server.system + _ = server.network }() } @@ -171,36 +174,24 @@ func TestServer_RaceDetector(t *testing.T) { // Concurrent field access (read-only operations) for i := 0; i < numOps; i++ { - wg.Add(5) - - // Read name - go func() { - defer wg.Done() - _ = server.name - }() - - // Read version - go func() { - defer wg.Done() - _ = server.version - }() + wg.Add(3) // Read file tools go func() { defer wg.Done() - _ = server.fileTools + _ = server.file }() // Read system tools go func() { defer wg.Done() - _ = server.systemTools + _ = server.system }() // Read network tools go func() { defer wg.Done() - _ = server.networkTools + _ = server.network }() } @@ -215,14 +206,14 @@ func TestConfig_ConcurrentValidation(t *testing.T) { validCfg := createIntegrationTestConfig(t, "valid") configs := []Config{ - {Name: "server1", Version: "1.0.0", FileTools: validCfg.FileTools, SystemTools: validCfg.SystemTools, NetworkTools: validCfg.NetworkTools}, - {Name: "server2", Version: "2.0.0", FileTools: validCfg.FileTools, SystemTools: validCfg.SystemTools, NetworkTools: validCfg.NetworkTools}, - {Name: "", Version: "1.0.0", FileTools: validCfg.FileTools, SystemTools: validCfg.SystemTools, NetworkTools: validCfg.NetworkTools}, // Invalid: no name - {Name: "server3", Version: "", FileTools: validCfg.FileTools, SystemTools: validCfg.SystemTools, NetworkTools: validCfg.NetworkTools}, // Invalid: no version - {Name: "server4", Version: "1.0.0", FileTools: nil, SystemTools: validCfg.SystemTools, NetworkTools: validCfg.NetworkTools}, // Invalid: no file tools - {Name: "server5", Version: "1.0.0", FileTools: validCfg.FileTools, SystemTools: nil, NetworkTools: validCfg.NetworkTools}, // Invalid: no system tools - {Name: "server6", Version: "1.0.0", FileTools: validCfg.FileTools, SystemTools: validCfg.SystemTools, NetworkTools: nil}, // Invalid: no network tools - {Name: "server7", Version: "3.0.0", FileTools: validCfg.FileTools, SystemTools: validCfg.SystemTools, NetworkTools: validCfg.NetworkTools}, + {Name: "server1", Version: "1.0.0", File: validCfg.File, System: validCfg.System, Network: validCfg.Network}, + {Name: "server2", Version: "2.0.0", File: validCfg.File, System: validCfg.System, Network: validCfg.Network}, + {Name: "", Version: "1.0.0", File: validCfg.File, System: validCfg.System, Network: validCfg.Network}, // Invalid: no name + {Name: "server3", Version: "", File: validCfg.File, System: validCfg.System, Network: validCfg.Network}, // Invalid: no version + {Name: "server4", Version: "1.0.0", File: nil, System: validCfg.System, Network: validCfg.Network}, // Invalid: no file tools + {Name: "server5", Version: "1.0.0", File: validCfg.File, System: nil, Network: validCfg.Network}, // Invalid: no system tools + {Name: "server6", Version: "1.0.0", File: validCfg.File, System: validCfg.System, Network: nil}, // Invalid: no network tools + {Name: "server7", Version: "3.0.0", File: validCfg.File, System: validCfg.System, Network: validCfg.Network}, } var wg sync.WaitGroup diff --git a/internal/mcp/knowledge.go b/internal/mcp/knowledge.go index 853f7c6..f1e1428 100644 --- a/internal/mcp/knowledge.go +++ b/internal/mcp/knowledge.go @@ -10,90 +10,104 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -// registerKnowledgeTools registers all knowledge tools to the MCP server. +// registerKnowledge registers all knowledge tools to the MCP server. // Tools: search_history, search_documents, search_system_knowledge, knowledge_store -func (s *Server) registerKnowledgeTools() error { +func (s *Server) registerKnowledge() error { searchSchema, err := jsonschema.For[tools.KnowledgeSearchInput](nil) if err != nil { return fmt.Errorf("schema for knowledge search tools: %w", err) } mcp.AddTool(s.mcpServer, &mcp.Tool{ - Name: tools.ToolSearchHistory, + Name: tools.SearchHistoryName, Description: "Search conversation history using semantic similarity. " + - "Finds past exchanges related to the query.", + "Finds past exchanges that are conceptually related to the query. " + + "Returns: matched conversation turns with timestamps and similarity scores. " + + "Use this to: recall past discussions, find context from earlier conversations. " + + "Default topK: 3. Maximum topK: 10.", InputSchema: searchSchema, }, s.SearchHistory) mcp.AddTool(s.mcpServer, &mcp.Tool{ - Name: tools.ToolSearchDocuments, + Name: tools.SearchDocumentsName, Description: "Search indexed documents (PDFs, code files, notes) using semantic similarity. " + - "Finds document sections related to the query.", + "Finds document sections that are conceptually related to the query. " + + "Returns: document titles, content excerpts, and similarity scores. " + + "Use this to: find relevant documentation, locate code examples, research topics. " + + "Default topK: 5. Maximum topK: 10.", InputSchema: searchSchema, }, s.SearchDocuments) mcp.AddTool(s.mcpServer, &mcp.Tool{ - Name: tools.ToolSearchSystemKnowledge, + Name: tools.SearchSystemKnowledgeName, Description: "Search system knowledge base (tool usage, commands, patterns) using semantic similarity. " + - "Finds internal system documentation and usage patterns.", + "Finds internal system documentation and usage patterns. " + + "Returns: knowledge entries with descriptions and examples. " + + "Use this to: understand tool capabilities, find command syntax, learn system patterns. " + + "Default topK: 3. Maximum topK: 10.", InputSchema: searchSchema, }, s.SearchSystemKnowledge) - storeSchema, err := jsonschema.For[tools.KnowledgeStoreInput](nil) - if err != nil { - return fmt.Errorf("schema for knowledge store tool: %w", err) + // Register knowledge_store only when DocStore is available. + if s.knowledge.HasDocStore() { + storeSchema, err := jsonschema.For[tools.KnowledgeStoreInput](nil) + if err != nil { + return fmt.Errorf("schema for knowledge store tool: %w", err) + } + + mcp.AddTool(s.mcpServer, &mcp.Tool{ + Name: tools.StoreKnowledgeName, + Description: "Store a knowledge entry for later retrieval via search_documents. " + + "Use this to save important information, notes, or learnings " + + "that the user wants to remember across sessions. " + + "Each entry gets a unique ID and is indexed for semantic search.", + InputSchema: storeSchema, + }, s.StoreKnowledge) } - mcp.AddTool(s.mcpServer, &mcp.Tool{ - Name: tools.ToolStoreKnowledge, - Description: "Store a knowledge entry for later retrieval via search_documents. " + - "Saves important information, notes, or learnings across sessions.", - InputSchema: storeSchema, - }, s.StoreKnowledge) - return nil } // SearchHistory handles the search_history MCP tool call. func (s *Server) SearchHistory(ctx context.Context, _ *mcp.CallToolRequest, input tools.KnowledgeSearchInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} - result, err := s.knowledgeTools.SearchHistory(toolCtx, input) + result, err := s.knowledge.SearchHistory(toolCtx, input) if err != nil { - return nil, nil, fmt.Errorf("searchHistory failed: %w", err) + return nil, nil, fmt.Errorf("searching history: %w", err) } - return resultToMCP(result), nil, nil + return resultToMCP(result, s.logger), nil, nil } // SearchDocuments handles the search_documents MCP tool call. func (s *Server) SearchDocuments(ctx context.Context, _ *mcp.CallToolRequest, input tools.KnowledgeSearchInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} - result, err := s.knowledgeTools.SearchDocuments(toolCtx, input) + result, err := s.knowledge.SearchDocuments(toolCtx, input) if err != nil { - return nil, nil, fmt.Errorf("searchDocuments failed: %w", err) + return nil, nil, fmt.Errorf("searching documents: %w", err) } - return resultToMCP(result), nil, nil + return resultToMCP(result, s.logger), nil, nil } // SearchSystemKnowledge handles the search_system_knowledge MCP tool call. func (s *Server) SearchSystemKnowledge(ctx context.Context, _ *mcp.CallToolRequest, input tools.KnowledgeSearchInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} - result, err := s.knowledgeTools.SearchSystemKnowledge(toolCtx, input) + result, err := s.knowledge.SearchSystemKnowledge(toolCtx, input) if err != nil { - return nil, nil, fmt.Errorf("searchSystemKnowledge failed: %w", err) + return nil, nil, fmt.Errorf("searching system knowledge: %w", err) } - return resultToMCP(result), nil, nil + return resultToMCP(result, s.logger), nil, nil } // StoreKnowledge handles the knowledge_store MCP tool call. func (s *Server) StoreKnowledge(ctx context.Context, _ *mcp.CallToolRequest, input tools.KnowledgeStoreInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} - result, err := s.knowledgeTools.StoreKnowledge(toolCtx, input) + result, err := s.knowledge.StoreKnowledge(toolCtx, input) if err != nil { - return nil, nil, fmt.Errorf("storeKnowledge failed: %w", err) + return nil, nil, fmt.Errorf("storing knowledge: %w", err) } - return resultToMCP(result), nil, nil + return resultToMCP(result, s.logger), nil, nil } diff --git a/internal/mcp/network.go b/internal/mcp/network.go index b475248..dcd8276 100644 --- a/internal/mcp/network.go +++ b/internal/mcp/network.go @@ -10,28 +10,34 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -// registerNetworkTools registers all network operation tools to the MCP server. +// registerNetwork registers all network operation tools to the MCP server. // Tools: web_search, web_fetch -func (s *Server) registerNetworkTools() error { +func (s *Server) registerNetwork() error { // web_search searchSchema, err := jsonschema.For[tools.SearchInput](nil) if err != nil { - return fmt.Errorf("schema for %s: %w", tools.ToolWebSearch, err) + return fmt.Errorf("schema for %s: %w", tools.WebSearchName, err) } mcp.AddTool(s.mcpServer, &mcp.Tool{ - Name: tools.ToolWebSearch, - Description: "Search the web for information. Returns relevant results with titles, URLs, and content snippets.", + Name: tools.WebSearchName, + Description: "Search the web for information. Returns relevant results with titles, URLs, and content snippets. " + + "Use this to find current information, news, or facts from the internet.", InputSchema: searchSchema, }, s.WebSearch) // web_fetch fetchSchema, err := jsonschema.For[tools.FetchInput](nil) if err != nil { - return fmt.Errorf("schema for %s: %w", tools.ToolWebFetch, err) + return fmt.Errorf("schema for %s: %w", tools.WebFetchName, err) } mcp.AddTool(s.mcpServer, &mcp.Tool{ - Name: tools.ToolWebFetch, - Description: "Fetch and extract content from one or more URLs (max 10). Supports HTML, JSON, and plain text.", + Name: tools.WebFetchName, + Description: "Fetch and extract content from one or more URLs (max 10). " + + "Supports HTML pages, JSON APIs, and plain text. " + + "For HTML: uses Readability algorithm to extract main content. " + + "Supports parallel fetching with rate limiting. " + + "Returns extracted content (max 50KB per URL). " + + "Note: Does not render JavaScript - for SPA pages, content may be incomplete.", InputSchema: fetchSchema, }, s.WebFetch) @@ -39,28 +45,22 @@ func (s *Server) registerNetworkTools() error { } // WebSearch handles the web_search MCP tool call. -// Architecture: Direct method call (consistent with file.go and system.go). func (s *Server) WebSearch(ctx context.Context, _ *mcp.CallToolRequest, input tools.SearchInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} - - // Direct method call - O(1), consistent with FileToolset.ReadFile() pattern - result, err := s.networkTools.Search(toolCtx, input) + result, err := s.network.Search(toolCtx, input) if err != nil { - return nil, nil, fmt.Errorf("web_search failed: %w", err) + return nil, nil, fmt.Errorf("searching web: %w", err) } return dataToMCP(result), nil, nil } // WebFetch handles the web_fetch MCP tool call. -// Architecture: Direct method call (consistent with file.go and system.go). func (s *Server) WebFetch(ctx context.Context, _ *mcp.CallToolRequest, input tools.FetchInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} - - // Direct method call - O(1), consistent with FileToolset.ReadFile() pattern - result, err := s.networkTools.Fetch(toolCtx, input) + result, err := s.network.Fetch(toolCtx, input) if err != nil { - return nil, nil, fmt.Errorf("web_fetch failed: %w", err) + return nil, nil, fmt.Errorf("fetching web: %w", err) } return dataToMCP(result), nil, nil diff --git a/internal/mcp/protocol_test.go b/internal/mcp/protocol_test.go index c03885e..667fcdd 100644 --- a/internal/mcp/protocol_test.go +++ b/internal/mcp/protocol_test.go @@ -10,15 +10,12 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -// connectTestServer creates a Koopa MCP server and an SDK client connected -// via in-memory transports. Returns the client session for making protocol calls. -// The server session and client session are cleaned up via t.Cleanup. -func connectTestServer(t *testing.T) *mcp.ClientSession { +// connectServer creates a Koopa MCP server from the given config and an SDK +// client connected via in-memory transports. Returns the client session for +// making protocol calls. Both sessions are cleaned up via t.Cleanup. +func connectServer(t *testing.T, cfg Config) *mcp.ClientSession { t.Helper() - h := newTestHelper(t) - cfg := h.createValidConfig() - server, err := NewServer(cfg) if err != nil { t.Fatalf("NewServer() unexpected error: %v", err) @@ -27,14 +24,12 @@ func connectTestServer(t *testing.T) *mcp.ClientSession { ctx := context.Background() serverTransport, clientTransport := mcp.NewInMemoryTransports() - // Connect server side serverSession, err := server.mcpServer.Connect(ctx, serverTransport, nil) if err != nil { t.Fatalf("server.Connect() unexpected error: %v", err) } t.Cleanup(func() { _ = serverSession.Close() }) - // Connect client side client := mcp.NewClient(&mcp.Implementation{ Name: "test-client", Version: "1.0.0", @@ -49,6 +44,22 @@ func connectTestServer(t *testing.T) *mcp.ClientSession { return clientSession } +// connectTestServer creates a Koopa MCP server without knowledge tools +// and an SDK client connected via in-memory transports. +func connectTestServer(t *testing.T) *mcp.ClientSession { + t.Helper() + h := newTestHelper(t) + return connectServer(t, h.createValidConfig()) +} + +// connectTestServerWithKnowledge creates a Koopa MCP server including +// knowledge tools (backed by a mock retriever) and an SDK client. +func connectTestServerWithKnowledge(t *testing.T) *mcp.ClientSession { + t.Helper() + h := newTestHelper(t) + return connectServer(t, h.createConfigWithKnowledge()) +} + // TestProtocol_ListTools verifies that the MCP JSON-RPC tools/list // endpoint returns all registered tools with correct names. func TestProtocol_ListTools(t *testing.T) { @@ -137,7 +148,7 @@ func TestProtocol_CallTool_CurrentTime(t *testing.T) { // Parse the JSON text result (contains mixed types: strings and numbers) var timeResult map[string]any if err := json.Unmarshal([]byte(textContent.Text), &timeResult); err != nil { - t.Fatalf("CallTool(current_time) failed to parse JSON: %v\ntext: %s", err, textContent.Text) + t.Fatalf("CallTool(current_time) parsing JSON: %v\ntext: %s", err, textContent.Text) } // Should contain time fields @@ -164,3 +175,127 @@ func TestProtocol_CallTool_UnknownTool(t *testing.T) { t.Errorf("CallTool(nonexistent_tool) error = %q, want to contain tool name", err.Error()) } } + +// TestProtocol_ListTools_WithKnowledge verifies that knowledge search tools +// are registered when Knowledge is configured. knowledge_store is excluded +// because the test helper creates Knowledge with docStore=nil. +func TestProtocol_ListTools_WithKnowledge(t *testing.T) { + session := connectTestServerWithKnowledge(t) + + result, err := session.ListTools(context.Background(), nil) + if err != nil { + t.Fatalf("ListTools() unexpected error: %v", err) + } + + var names []string + for _, tool := range result.Tools { + names = append(names, tool.Name) + } + sort.Strings(names) + + // 10 base + 3 knowledge search tools (knowledge_store excluded: docStore=nil) + wantNames := []string{ + "current_time", + "delete_file", + "execute_command", + "get_env", + "get_file_info", + "list_files", + "read_file", + "search_documents", + "search_history", + "search_system_knowledge", + "web_fetch", + "web_search", + "write_file", + } + + if len(names) != len(wantNames) { + t.Fatalf("ListTools() returned %d tools, want %d\ngot: %v\nwant: %v", len(names), len(wantNames), names, wantNames) + } + + for i, got := range names { + if got != wantNames[i] { + t.Errorf("ListTools() tool[%d] = %q, want %q", i, got, wantNames[i]) + } + } +} + +// TestProtocol_CallTool_KnowledgeSearch verifies that each knowledge search +// tool can be called through the MCP JSON-RPC layer and returns results. +func TestProtocol_CallTool_KnowledgeSearch(t *testing.T) { + session := connectTestServerWithKnowledge(t) + + tests := []struct { + name string + toolName string + }{ + {name: "search_history", toolName: "search_history"}, + {name: "search_documents", toolName: "search_documents"}, + {name: "search_system_knowledge", toolName: "search_system_knowledge"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := session.CallTool(context.Background(), &mcp.CallToolParams{ + Name: tt.toolName, + Arguments: map[string]any{ + "query": "test query", + "topK": 3, + }, + }) + if err != nil { + t.Fatalf("CallTool(%q) unexpected error: %v", tt.toolName, err) + } + + if result.IsError { + t.Fatalf("CallTool(%q) returned error result", tt.toolName) + } + + if len(result.Content) == 0 { + t.Fatalf("CallTool(%q) returned empty content", tt.toolName) + } + + textContent, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatalf("CallTool(%q) content[0] type = %T, want *mcp.TextContent", tt.toolName, result.Content[0]) + } + + // Parse JSON and verify result structure + var parsed map[string]any + if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil { + t.Fatalf("CallTool(%q) parsing JSON: %v\ntext: %s", tt.toolName, err, textContent.Text) + } + + if parsed["query"] != "test query" { + t.Errorf("CallTool(%q) query = %v, want %q", tt.toolName, parsed["query"], "test query") + } + + // mock retriever returns 1 document + if count, ok := parsed["result_count"].(float64); !ok || count != 1 { + t.Errorf("CallTool(%q) result_count = %v, want 1", tt.toolName, parsed["result_count"]) + } + }) + } +} + +// TestProtocol_CallTool_KnowledgeStore_NoDocStore verifies that +// knowledge_store is not registered when docStore is nil. +func TestProtocol_CallTool_KnowledgeStore_NoDocStore(t *testing.T) { + session := connectTestServerWithKnowledge(t) + + _, err := session.CallTool(context.Background(), &mcp.CallToolParams{ + Name: "knowledge_store", + Arguments: map[string]any{ + "title": "test", + "content": "test content", + }, + }) + if err == nil { + t.Fatal("CallTool(knowledge_store) expected error for unregistered tool, got nil") + } + + if !strings.Contains(err.Error(), "knowledge_store") { + t.Errorf("CallTool(knowledge_store) error = %q, want to contain tool name", err.Error()) + } +} diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 164015c..8e5dd7a 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -2,7 +2,9 @@ package mcp import ( "context" + "errors" "fmt" + "log/slog" "github.com/koopa0/koopa/internal/tools" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -11,42 +13,42 @@ import ( // Server wraps the MCP SDK server and Koopa's tool handlers. // It exposes Koopa's tools via the Model Context Protocol. type Server struct { - mcpServer *mcp.Server - fileTools *tools.FileTools - systemTools *tools.SystemTools - networkTools *tools.NetworkTools - knowledgeTools *tools.KnowledgeTools // nil when knowledge search is unavailable - name string - version string + mcpServer *mcp.Server + logger *slog.Logger + file *tools.File + system *tools.System + network *tools.Network + knowledge *tools.Knowledge // nil when knowledge search is unavailable } // Config holds MCP server configuration. type Config struct { - Name string - Version string - FileTools *tools.FileTools - SystemTools *tools.SystemTools - NetworkTools *tools.NetworkTools - KnowledgeTools *tools.KnowledgeTools // Optional: nil disables knowledge search tools + Name string + Version string + Logger *slog.Logger // Optional: nil uses slog.Default() + File *tools.File + System *tools.System + Network *tools.Network + Knowledge *tools.Knowledge // Optional: nil disables knowledge search tools } // NewServer creates a new MCP server with the given configuration. func NewServer(cfg Config) (*Server, error) { // Validate required config if cfg.Name == "" { - return nil, fmt.Errorf("server name is required") + return nil, errors.New("server name is required") } if cfg.Version == "" { - return nil, fmt.Errorf("server version is required") + return nil, errors.New("server version is required") } - if cfg.FileTools == nil { - return nil, fmt.Errorf("file tools is required") + if cfg.File == nil { + return nil, errors.New("file tools is required") } - if cfg.SystemTools == nil { - return nil, fmt.Errorf("system tools is required") + if cfg.System == nil { + return nil, errors.New("system tools is required") } - if cfg.NetworkTools == nil { - return nil, fmt.Errorf("network tools is required") + if cfg.Network == nil { + return nil, errors.New("network tools is required") } // Create MCP server (using official SDK) @@ -55,19 +57,23 @@ func NewServer(cfg Config) (*Server, error) { Version: cfg.Version, }, nil) + logger := cfg.Logger + if logger == nil { + logger = slog.Default() + } + s := &Server{ - mcpServer: mcpServer, - fileTools: cfg.FileTools, - systemTools: cfg.SystemTools, - networkTools: cfg.NetworkTools, - knowledgeTools: cfg.KnowledgeTools, - name: cfg.Name, - version: cfg.Version, + mcpServer: mcpServer, + logger: logger, + file: cfg.File, + system: cfg.System, + network: cfg.Network, + knowledge: cfg.Knowledge, } // Register all tools if err := s.registerTools(); err != nil { - return nil, fmt.Errorf("failed to register tools: %w", err) + return nil, fmt.Errorf("registering tools: %w", err) } return s, nil @@ -77,28 +83,28 @@ func NewServer(cfg Config) (*Server, error) { // This is a blocking call that handles all MCP protocol communication. func (s *Server) Run(ctx context.Context, transport mcp.Transport) error { if err := s.mcpServer.Run(ctx, transport); err != nil { - return fmt.Errorf("MCP server run failed: %w", err) + return fmt.Errorf("running mcp server: %w", err) } return nil } // registerTools registers all Toolset tools to the MCP server. func (s *Server) registerTools() error { - if err := s.registerFileTools(); err != nil { + if err := s.registerFile(); err != nil { return fmt.Errorf("register file tools: %w", err) } - if err := s.registerSystemTools(); err != nil { + if err := s.registerSystem(); err != nil { return fmt.Errorf("register system tools: %w", err) } - if err := s.registerNetworkTools(); err != nil { + if err := s.registerNetwork(); err != nil { return fmt.Errorf("register network tools: %w", err) } // Knowledge tools are optional (require DB + embedder) - if s.knowledgeTools != nil { - if err := s.registerKnowledgeTools(); err != nil { + if s.knowledge != nil { + if err := s.registerKnowledge(); err != nil { return fmt.Errorf("register knowledge tools: %w", err) } } diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 7a77b96..d066f45 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -1,17 +1,36 @@ package mcp import ( + "context" "log/slog" "net/http" "net/http/httptest" "path/filepath" + "strings" "testing" "time" + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/api" + "github.com/koopa0/koopa/internal/security" "github.com/koopa0/koopa/internal/tools" ) +// mcpTestRetriever implements ai.Retriever for MCP protocol tests. +// Returns a single mock document so tests can verify result structure. +type mcpTestRetriever struct{} + +func (*mcpTestRetriever) Name() string { return "mock-retriever" } +func (*mcpTestRetriever) Retrieve(_ context.Context, _ *ai.RetrieverRequest) (*ai.RetrieverResponse, error) { + return &ai.RetrieverResponse{ + Documents: []*ai.Document{ + ai.DocumentFromText("mock result for protocol test", nil), + }, + }, nil +} +func (*mcpTestRetriever) Register(_ api.Registry) {} + // testHelper provides common test utilities. type testHelper struct { t *testing.T @@ -24,7 +43,7 @@ func newTestHelper(t *testing.T) *testHelper { tempDir := t.TempDir() realTempDir, err := filepath.EvalSymlinks(tempDir) if err != nil { - t.Fatalf("failed to resolve temp dir symlinks: %v", err) + t.Fatalf("resolving temp dir symlinks: %v", err) } return &testHelper{ t: t, @@ -32,33 +51,33 @@ func newTestHelper(t *testing.T) *testHelper { } } -func (h *testHelper) createFileTools() *tools.FileTools { +func (h *testHelper) createFile() *tools.File { h.t.Helper() pathVal, err := security.NewPath([]string{h.tempDir}) if err != nil { - h.t.Fatalf("failed to create path validator: %v", err) + h.t.Fatalf("creating path validator: %v", err) } - ft, err := tools.NewFileTools(pathVal, slog.Default()) + ft, err := tools.NewFile(pathVal, slog.Default()) if err != nil { - h.t.Fatalf("failed to create file tools: %v", err) + h.t.Fatalf("creating file tools: %v", err) } return ft } -func (h *testHelper) createSystemTools() *tools.SystemTools { +func (h *testHelper) createSystem() *tools.System { h.t.Helper() cmdVal := security.NewCommand() envVal := security.NewEnv() - st, err := tools.NewSystemTools(cmdVal, envVal, slog.Default()) + st, err := tools.NewSystem(cmdVal, envVal, slog.Default()) if err != nil { - h.t.Fatalf("failed to create system tools: %v", err) + h.t.Fatalf("creating system tools: %v", err) } return st } -func (h *testHelper) createNetworkTools() *tools.NetworkTools { +func (h *testHelper) createNetwork() *tools.Network { h.t.Helper() // Use httptest.NewServer instead of hardcoded localhost URL @@ -72,8 +91,8 @@ func (h *testHelper) createNetworkTools() *tools.NetworkTools { // SSRF protection is only checked during Fetch(), not at construction. // These tests only construct tools for Config, they don't execute network operations. - nt, err := tools.NewNetworkTools( - tools.NetworkConfig{ + nt, err := tools.NewNetwork( + tools.NetConfig{ SearchBaseURL: mockServer.URL, FetchParallelism: 2, FetchDelay: 100 * time.Millisecond, @@ -82,22 +101,38 @@ func (h *testHelper) createNetworkTools() *tools.NetworkTools { slog.Default(), ) if err != nil { - h.t.Fatalf("failed to create network tools: %v", err) + h.t.Fatalf("creating network tools: %v", err) } return nt } +func (h *testHelper) createKnowledge() *tools.Knowledge { + h.t.Helper() + kt, err := tools.NewKnowledge(&mcpTestRetriever{}, nil, slog.New(slog.DiscardHandler)) + if err != nil { + h.t.Fatalf("creating knowledge tools: %v", err) + } + return kt +} + func (h *testHelper) createValidConfig() Config { h.t.Helper() return Config{ - Name: "test-server", - Version: "1.0.0", - FileTools: h.createFileTools(), - SystemTools: h.createSystemTools(), - NetworkTools: h.createNetworkTools(), + Name: "test-server", + Version: "1.0.0", + File: h.createFile(), + System: h.createSystem(), + Network: h.createNetwork(), } } +func (h *testHelper) createConfigWithKnowledge() Config { + h.t.Helper() + cfg := h.createValidConfig() + cfg.Knowledge = h.createKnowledge() + return cfg +} + // TestNewServer_Success tests successful server creation with all tools. func TestNewServer_Success(t *testing.T) { h := newTestHelper(t) @@ -105,41 +140,32 @@ func TestNewServer_Success(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) - } - - // Verify server fields are correctly set - if server.name != "test-server" { - t.Errorf("server.name = %q, want %q", server.name, "test-server") - } - - if server.version != "1.0.0" { - t.Errorf("server.version = %q, want %q", server.version, "1.0.0") + t.Fatalf("NewServer(): %v", err) } if server.mcpServer == nil { t.Error("server.mcpServer is nil") } - if server.fileTools == nil { - t.Error("server.fileTools is nil") + if server.file == nil { + t.Error("server.file is nil") } - if server.systemTools == nil { - t.Error("server.systemTools is nil") + if server.system == nil { + t.Error("server.system is nil") } - if server.networkTools == nil { - t.Error("server.networkTools is nil") + if server.network == nil { + t.Error("server.network is nil") } } // TestNewServer_ValidationErrors tests config validation. func TestNewServer_ValidationErrors(t *testing.T) { h := newTestHelper(t) - validFile := h.createFileTools() - validSystem := h.createSystemTools() - validNetwork := h.createNetworkTools() + validFile := h.createFile() + validSystem := h.createSystem() + validNetwork := h.createNetwork() tests := []struct { name string @@ -149,50 +175,50 @@ func TestNewServer_ValidationErrors(t *testing.T) { { name: "missing name", config: Config{ - Version: "1.0.0", - FileTools: validFile, - SystemTools: validSystem, - NetworkTools: validNetwork, + Version: "1.0.0", + File: validFile, + System: validSystem, + Network: validNetwork, }, wantErr: "server name is required", }, { name: "missing version", config: Config{ - Name: "test", - FileTools: validFile, - SystemTools: validSystem, - NetworkTools: validNetwork, + Name: "test", + File: validFile, + System: validSystem, + Network: validNetwork, }, wantErr: "server version is required", }, { name: "missing file tools", config: Config{ - Name: "test", - Version: "1.0.0", - SystemTools: validSystem, - NetworkTools: validNetwork, + Name: "test", + Version: "1.0.0", + System: validSystem, + Network: validNetwork, }, wantErr: "file tools is required", }, { name: "missing system tools", config: Config{ - Name: "test", - Version: "1.0.0", - FileTools: validFile, - NetworkTools: validNetwork, + Name: "test", + Version: "1.0.0", + File: validFile, + Network: validNetwork, }, wantErr: "system tools is required", }, { name: "missing network tools", config: Config{ - Name: "test", - Version: "1.0.0", - FileTools: validFile, - SystemTools: validSystem, + Name: "test", + Version: "1.0.0", + File: validFile, + System: validSystem, }, wantErr: "network tools is required", }, @@ -204,42 +230,9 @@ func TestNewServer_ValidationErrors(t *testing.T) { if err == nil { t.Fatal("NewServer succeeded, want error") } - if !contains(err.Error(), tt.wantErr) { + if !strings.Contains(err.Error(), tt.wantErr) { t.Errorf("NewServer error = %q, want to contain %q", err.Error(), tt.wantErr) } }) } } - -// TestRegisterTools_AllToolsRegistered verifies all 10 tools are registered. -func TestRegisterTools_AllToolsRegistered(t *testing.T) { - h := newTestHelper(t) - cfg := h.createValidConfig() - - server, err := NewServer(cfg) - if err != nil { - t.Fatalf("NewServer failed: %v", err) - } - - // Verify server was created successfully (tools registered in constructor) - if server.mcpServer == nil { - t.Fatal("mcpServer is nil") - } - - // Note: We can't directly verify tool registration without accessing - // internal MCP server state. The fact that NewServer succeeded without - // error means registerTools() completed successfully for all 10 tools: - // - File: read_file, write_file, list_files, delete_file, get_file_info - // - System: current_time, execute_command, get_env - // - Network: web_search, web_fetch -} - -// contains checks if s contains substr. -func contains(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} diff --git a/internal/mcp/system.go b/internal/mcp/system.go index ac6bdd3..c74a450 100644 --- a/internal/mcp/system.go +++ b/internal/mcp/system.go @@ -10,39 +10,48 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -// registerSystemTools registers all system operation tools to the MCP server. +// registerSystem registers all system operation tools to the MCP server. // Tools: current_time, execute_command, get_env -func (s *Server) registerSystemTools() error { +func (s *Server) registerSystem() error { // current_time currentTimeSchema, err := jsonschema.For[tools.CurrentTimeInput](nil) if err != nil { - return fmt.Errorf("schema for %s: %w", tools.ToolCurrentTime, err) + return fmt.Errorf("schema for %s: %w", tools.CurrentTimeName, err) } mcp.AddTool(s.mcpServer, &mcp.Tool{ - Name: tools.ToolCurrentTime, - Description: "Get the current system date and time in formatted string.", + Name: tools.CurrentTimeName, + Description: "Get the current system date and time. " + + "Returns: formatted time string, Unix timestamp, and ISO 8601 format. " + + "Always returns the server's local time zone.", InputSchema: currentTimeSchema, }, s.CurrentTime) // execute_command executeCommandSchema, err := jsonschema.For[tools.ExecuteCommandInput](nil) if err != nil { - return fmt.Errorf("schema for %s: %w", tools.ToolExecuteCommand, err) + return fmt.Errorf("schema for %s: %w", tools.ExecuteCommandName, err) } mcp.AddTool(s.mcpServer, &mcp.Tool{ - Name: tools.ToolExecuteCommand, - Description: "Execute a shell command with security validation. Dangerous commands (rm -rf, sudo, etc.) are blocked.", + Name: tools.ExecuteCommandName, + Description: "Execute a shell command from the allowed list with security validation. " + + "Allowed commands: git, npm, yarn, go, make, docker, kubectl, ls, cat, grep, find, pwd, echo. " + + "Commands run with a timeout to prevent hanging. " + + "Returns: stdout, stderr, exit code, and execution time. " + + "Security: Dangerous commands (rm -rf, sudo, chmod, etc.) are blocked.", InputSchema: executeCommandSchema, }, s.ExecuteCommand) // get_env getEnvSchema, err := jsonschema.For[tools.GetEnvInput](nil) if err != nil { - return fmt.Errorf("schema for %s: %w", tools.ToolGetEnv, err) + return fmt.Errorf("schema for %s: %w", tools.GetEnvName, err) } mcp.AddTool(s.mcpServer, &mcp.Tool{ - Name: tools.ToolGetEnv, - Description: "Read an environment variable value. Sensitive variables (*KEY*, *SECRET*, *TOKEN*) are protected.", + Name: tools.GetEnvName, + Description: "Read an environment variable value from the system. " + + "Returns: the variable name and its value. " + + "Use this to: check configuration, verify paths, read non-sensitive settings. " + + "Security: Sensitive variables containing KEY, SECRET, TOKEN, or PASSWORD in their names are protected and will not be returned.", InputSchema: getEnvSchema, }, s.GetEnv) @@ -52,30 +61,30 @@ func (s *Server) registerSystemTools() error { // CurrentTime handles the currentTime MCP tool call. func (s *Server) CurrentTime(ctx context.Context, _ *mcp.CallToolRequest, input tools.CurrentTimeInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} - result, err := s.systemTools.CurrentTime(toolCtx, input) + result, err := s.system.CurrentTime(toolCtx, input) if err != nil { - return nil, nil, fmt.Errorf("currentTime failed: %w", err) + return nil, nil, fmt.Errorf("getting current time: %w", err) } - return resultToMCP(result), nil, nil + return resultToMCP(result, s.logger), nil, nil } // ExecuteCommand handles the executeCommand MCP tool call. func (s *Server) ExecuteCommand(ctx context.Context, _ *mcp.CallToolRequest, input tools.ExecuteCommandInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} - result, err := s.systemTools.ExecuteCommand(toolCtx, input) + result, err := s.system.ExecuteCommand(toolCtx, input) if err != nil { // Only infrastructure errors (context cancellation) return Go error - return nil, nil, fmt.Errorf("executeCommand failed: %w", err) + return nil, nil, fmt.Errorf("executing command: %w", err) } - return resultToMCP(result), nil, nil + return resultToMCP(result, s.logger), nil, nil } // GetEnv handles the getEnv MCP tool call. func (s *Server) GetEnv(ctx context.Context, _ *mcp.CallToolRequest, input tools.GetEnvInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} - result, err := s.systemTools.GetEnv(toolCtx, input) + result, err := s.system.GetEnv(toolCtx, input) if err != nil { - return nil, nil, fmt.Errorf("getEnv failed: %w", err) + return nil, nil, fmt.Errorf("getting env: %w", err) } - return resultToMCP(result), nil, nil + return resultToMCP(result, s.logger), nil, nil } diff --git a/internal/mcp/system_test.go b/internal/mcp/system_test.go index 3db437e..73cd4f5 100644 --- a/internal/mcp/system_test.go +++ b/internal/mcp/system_test.go @@ -15,13 +15,13 @@ func TestCurrentTime_Success(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } result, _, err := server.CurrentTime(context.Background(), &mcp.CallToolRequest{}, tools.CurrentTimeInput{}) if err != nil { - t.Fatalf("CurrentTime failed: %v", err) + t.Fatalf("CurrentTime(): %v", err) } if result.IsError { @@ -60,7 +60,7 @@ func TestExecuteCommand_Success(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } result, _, err := server.ExecuteCommand(context.Background(), &mcp.CallToolRequest{}, tools.ExecuteCommandInput{ @@ -69,7 +69,7 @@ func TestExecuteCommand_Success(t *testing.T) { }) if err != nil { - t.Fatalf("ExecuteCommand failed: %v", err) + t.Fatalf("ExecuteCommand(): %v", err) } if result.IsError { @@ -97,7 +97,7 @@ func TestExecuteCommand_DangerousCommandBlocked(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } // Try to execute a dangerous command @@ -122,7 +122,7 @@ func TestExecuteCommand_CommandNotFound(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } result, _, err := server.ExecuteCommand(context.Background(), &mcp.CallToolRequest{}, tools.ExecuteCommandInput{ @@ -146,7 +146,7 @@ func TestGetEnv_Success(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } // Set a test environment variable using t.Setenv for automatic cleanup @@ -159,7 +159,7 @@ func TestGetEnv_Success(t *testing.T) { }) if err != nil { - t.Fatalf("GetEnv failed: %v", err) + t.Fatalf("GetEnv(): %v", err) } if result.IsError { @@ -187,7 +187,7 @@ func TestGetEnv_NotSet(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } result, _, err := server.GetEnv(context.Background(), &mcp.CallToolRequest{}, tools.GetEnvInput{ @@ -195,7 +195,7 @@ func TestGetEnv_NotSet(t *testing.T) { }) if err != nil { - t.Fatalf("GetEnv failed: %v", err) + t.Fatalf("GetEnv(): %v", err) } if result.IsError { @@ -224,7 +224,7 @@ func TestGetEnv_SensitiveVariableBlocked(t *testing.T) { server, err := NewServer(cfg) if err != nil { - t.Fatalf("NewServer failed: %v", err) + t.Fatalf("NewServer(): %v", err) } // Try to access a sensitive variable diff --git a/internal/mcp/util.go b/internal/mcp/util.go index e5a54a2..2356ecd 100644 --- a/internal/mcp/util.go +++ b/internal/mcp/util.go @@ -26,7 +26,12 @@ import ( // resultToMCP converts a tools.Result to mcp.CallToolResult. // This follows the Direct Inline Handling principle but extracts the common pattern. -func resultToMCP(result tools.Result) *mcp.CallToolResult { +// If logger is nil, falls back to slog.Default(). +func resultToMCP(result tools.Result, logger *slog.Logger) *mcp.CallToolResult { + if logger == nil { + logger = slog.Default() + } + if result.Status == tools.StatusError { errorText := fmt.Sprintf("[%s] %s", result.Error.Code, result.Error.Message) if result.Error.Details != nil { @@ -36,7 +41,7 @@ func resultToMCP(result tools.Result) *mcp.CallToolResult { detailsJSON, err := json.Marshal(sanitized) if err != nil { // Log internal error, don't expose to client - slog.Warn("failed to marshal sanitized error details", "error", err) + logger.Warn("marshaling sanitized error details", "error", err) errorText += "\nDetails: (see server logs)" } else { errorText += fmt.Sprintf("\nDetails: %s", string(detailsJSON)) @@ -44,7 +49,7 @@ func resultToMCP(result tools.Result) *mcp.CallToolResult { } // Always log full details server-side for debugging - slog.Debug("MCP error details", "details", result.Error.Details) + logger.Debug("MCP error details", "details", result.Error.Details) } return &mcp.CallToolResult{ diff --git a/internal/mcp/util_test.go b/internal/mcp/util_test.go index 5422290..f18be55 100644 --- a/internal/mcp/util_test.go +++ b/internal/mcp/util_test.go @@ -1,6 +1,7 @@ package mcp import ( + "strings" "testing" "github.com/koopa0/koopa/internal/tools" @@ -13,7 +14,7 @@ func TestResultToMCP_Success(t *testing.T) { Data: map[string]any{"result": "value", "count": 42}, } - mcpResult := resultToMCP(result) + mcpResult := resultToMCP(result, nil) if mcpResult.IsError { t.Error("resultToMCP should not set IsError for success status") @@ -29,7 +30,7 @@ func TestResultToMCP_Success(t *testing.T) { } // Data should be JSON marshaled - if !contains(textContent.Text, "result") || !contains(textContent.Text, "value") { + if !strings.Contains(textContent.Text, "result") || !strings.Contains(textContent.Text, "value") { t.Errorf("resultToMCP text should contain JSON data: %s", textContent.Text) } } @@ -43,7 +44,7 @@ func TestResultToMCP_Error(t *testing.T) { }, } - mcpResult := resultToMCP(result) + mcpResult := resultToMCP(result, nil) if !mcpResult.IsError { t.Error("resultToMCP should set IsError for error status") @@ -59,11 +60,11 @@ func TestResultToMCP_Error(t *testing.T) { } // Should contain error code and message - if !contains(textContent.Text, string(tools.ErrCodeNotFound)) { + if !strings.Contains(textContent.Text, string(tools.ErrCodeNotFound)) { t.Errorf("resultToMCP text should contain error code: %s", textContent.Text) } - if !contains(textContent.Text, "File not found") { + if !strings.Contains(textContent.Text, "File not found") { t.Errorf("resultToMCP text should contain error message: %s", textContent.Text) } } @@ -79,7 +80,7 @@ func TestResultToMCP_ErrorWithDetails(t *testing.T) { }, } - mcpResult := resultToMCP(result) + mcpResult := resultToMCP(result, nil) if !mcpResult.IsError { t.Error("resultToMCP should set IsError for error status") @@ -95,15 +96,11 @@ func TestResultToMCP_ErrorWithDetails(t *testing.T) { } // Should contain "Details:" with whitelisted fields - if !contains(textContent.Text, "Details:") { + if !strings.Contains(textContent.Text, "Details:") { t.Errorf("resultToMCP text should contain 'Details:': %s", textContent.Text) } } -// ============================================================================= -// dataToMCP Tests -// ============================================================================= - func TestDataToMCP_ValidData(t *testing.T) { data := map[string]any{"key": "value", "count": 42} result := dataToMCP(data) @@ -121,7 +118,7 @@ func TestDataToMCP_ValidData(t *testing.T) { t.Fatal("dataToMCP content is not TextContent") } - if !contains(textContent.Text, "key") || !contains(textContent.Text, "value") { + if !strings.Contains(textContent.Text, "key") || !strings.Contains(textContent.Text, "value") { t.Errorf("dataToMCP should contain JSON data: %s", textContent.Text) } } @@ -161,7 +158,7 @@ func TestDataToMCP_SliceData(t *testing.T) { t.Fatal("dataToMCP content is not TextContent") } - if !contains(textContent.Text, "item1") { + if !strings.Contains(textContent.Text, "item1") { t.Errorf("dataToMCP should contain JSON array: %s", textContent.Text) } } @@ -187,7 +184,7 @@ func TestDataToMCP_NestedStruct(t *testing.T) { t.Fatal("dataToMCP content is not TextContent") } - if !contains(textContent.Text, "test") || !contains(textContent.Text, "42") { + if !strings.Contains(textContent.Text, "test") || !strings.Contains(textContent.Text, "42") { t.Errorf("dataToMCP should contain nested JSON: %s", textContent.Text) } } @@ -222,7 +219,7 @@ func TestDataToMCP_ResultNilData(t *testing.T) { Data: nil, } - mcpResult := resultToMCP(result) + mcpResult := resultToMCP(result, nil) if mcpResult.IsError { t.Error("resultToMCP should not set IsError for success with nil data") @@ -238,10 +235,6 @@ func TestDataToMCP_ResultNilData(t *testing.T) { } } -// ============================================================================= -// sanitizeErrorDetails Tests -// ============================================================================= - func TestSanitizeErrorDetails(t *testing.T) { tests := []struct { name string diff --git a/internal/observability/datadog.go b/internal/observability/datadog.go deleted file mode 100644 index 60c0990..0000000 --- a/internal/observability/datadog.go +++ /dev/null @@ -1,172 +0,0 @@ -// Package observability provides OpenTelemetry integration for distributed tracing. -// -// # Architecture Decision: Datadog Agent Mode -// -// We use the Datadog Agent for OTLP ingestion instead of direct API endpoint. -// This decision was made because: -// -// - Direct OTLP Traces API is in Preview status (as of Nov 2025) -// - Agent provides better reliability with local buffering and retry -// - Lower latency (localhost vs internet roundtrip) -// - Agent handles authentication - no need to pass DD_API_KEY in app -// - Supports all Datadog features (metrics, logs, traces in one agent) -// -// # Prerequisites -// -// 1. Datadog Account with US5 region (or your region) -// 2. DD_API_KEY from https://us5.datadoghq.com → Organization Settings → API Keys -// -// # macOS Installation -// -// Install Datadog Agent: -// -// DD_API_KEY="your-key" DD_SITE="us5.datadoghq.com" \ -// bash -c "$(curl -L https://install.datadoghq.com/scripts/install_mac_os.sh)" -// -// # Enable OTLP Receiver -// -// Add to /opt/datadog-agent/etc/datadog.yaml (at the end of file): -// -// otlp_config: -// receiver: -// protocols: -// http: -// endpoint: "localhost:4318" -// traces: -// enabled: true -// span_name_as_resource_name: true -// -// # Restart Agent -// -// Option 1 - Using launchctl: -// -// sudo launchctl stop com.datadoghq.agent -// sudo launchctl start com.datadoghq.agent -// -// Option 2 - Kill and restart: -// -// sudo pkill -9 -f datadog -// sudo /opt/datadog-agent/bin/agent/agent run & -// -// # Option 3 - Use Datadog Agent GUI app -// -// # Verify OTLP is Enabled -// -// datadog-agent status | grep -A 5 "OTLP" -// -// Expected output: -// -// OTLP -// ==== -// Status: Enabled -// Collector status: Running -// -// # View Traces in Datadog -// -// After running koopa with tracing enabled: -// - Go to https://us5.datadoghq.com/apm/traces -// - Search for service:koopa or your configured service name -// - Traces appear within 1-2 minutes after app shutdown (flush) -// -// # Troubleshooting -// -// Agent not running: -// -// launchctl list | grep datadog # PID should not be "-" -// -// Check Agent logs: -// -// sudo tail -50 /var/log/datadog/agent.log -// -// Test OTLP endpoint: -// -// curl -v http://localhost:4318/v1/traces -// -// # Configuration -// -// Environment variables (optional): -// - DD_AGENT_HOST: Override agent host (default: localhost:4318) -// - DD_ENV: Environment tag (default: dev) -// - DD_SERVICE: Service name (default: koopa) -// -// Config file (~/.koopa/config.yaml): -// -// datadog: -// agent_host: "localhost:4318" -// environment: "dev" -// service_name: "koopa" -package observability - -import ( - "context" - "log/slog" - "os" - - "github.com/firebase/genkit/go/core/tracing" - "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" - sdktrace "go.opentelemetry.io/otel/sdk/trace" -) - -// Config for Datadog OTEL setup. -type Config struct { - // AgentHost is the Datadog Agent OTLP endpoint (default: localhost:4318) - AgentHost string - // Environment is the deployment environment (dev, staging, prod) - Environment string - // ServiceName is the service name shown in Datadog APM - ServiceName string -} - -// DefaultAgentHost is the default Datadog Agent OTLP HTTP endpoint. -const DefaultAgentHost = "localhost:4318" - -// SetupDatadog registers a Datadog Agent exporter with Genkit's TracerProvider. -// Traces are sent to the local Datadog Agent via OTLP HTTP protocol. -// -// Returns a shutdown function that flushes pending spans. -// If AgentHost is empty, uses DefaultAgentHost (localhost:4318). -func SetupDatadog(ctx context.Context, cfg Config) (shutdown func(context.Context) error, err error) { - agentHost := cfg.AgentHost - if agentHost == "" { - agentHost = DefaultAgentHost - } - - // Set OTEL env vars for Genkit's TracerProvider to pick up. - // SAFETY: os.Setenv is not concurrent-safe, but this function is called - // exactly once during startup in InitializeApp, before goroutines are spawned. - if cfg.ServiceName != "" { - _ = os.Setenv("OTEL_SERVICE_NAME", cfg.ServiceName) - } - if cfg.Environment != "" { - _ = os.Setenv("OTEL_RESOURCE_ATTRIBUTES", "deployment.environment="+cfg.Environment) - } - - // Create OTLP HTTP exporter pointing to local Datadog Agent - // Agent handles authentication and forwarding to Datadog backend - exporter, err := otlptracehttp.New(ctx, - otlptracehttp.WithEndpoint(agentHost), - otlptracehttp.WithInsecure(), // localhost doesn't need TLS - ) - if err != nil { - slog.Warn("failed to create datadog exporter, tracing disabled", "error", err) - return func(context.Context) error { return nil }, nil - } - - // Register BatchSpanProcessor with Genkit's TracerProvider - processor := sdktrace.NewBatchSpanProcessor(exporter) - tracing.TracerProvider().RegisterSpanProcessor(processor) - - slog.Debug("datadog tracing enabled", - "agent", agentHost, - "service", cfg.ServiceName, - "environment", cfg.Environment, - ) - - // Create a test span to verify the pipeline works - tracer := tracing.TracerProvider().Tracer("koopa-init") - _, span := tracer.Start(ctx, "koopa.init") - span.End() - slog.Debug("test span created for datadog verification") - - return tracing.TracerProvider().Shutdown, nil -} diff --git a/internal/observability/datadog_test.go b/internal/observability/datadog_test.go deleted file mode 100644 index b57937c..0000000 --- a/internal/observability/datadog_test.go +++ /dev/null @@ -1,115 +0,0 @@ -package observability - -import ( - "context" - "testing" -) - -func TestSetupDatadog_DefaultAgentHost(t *testing.T) { - // NOTE: not parallel — SetupDatadog modifies global state (os.Setenv, TracerProvider) - - cfg := Config{ - AgentHost: "", // Empty should use default - Environment: "test", - ServiceName: "test-service", - } - - ctx := context.Background() - shutdown, err := SetupDatadog(ctx, cfg) - - // Should not fail even with empty AgentHost - if err != nil { - t.Fatalf("SetupDatadog() unexpected error: %v", err) - } - if shutdown == nil { - t.Fatal("SetupDatadog() shutdown = nil, want non-nil") - } - - // Cleanup - if err := shutdown(ctx); err != nil { - t.Errorf("shutdown() unexpected error: %v", err) - } -} - -func TestSetupDatadog_CustomAgentHost(t *testing.T) { - // NOTE: not parallel — SetupDatadog modifies global state (os.Setenv, TracerProvider) - - cfg := Config{ - AgentHost: "custom-host:4318", - Environment: "staging", - ServiceName: "custom-service", - } - - ctx := context.Background() - shutdown, err := SetupDatadog(ctx, cfg) - - // Should not fail with custom host - if err != nil { - t.Fatalf("SetupDatadog() unexpected error: %v", err) - } - if shutdown == nil { - t.Fatal("SetupDatadog() shutdown = nil, want non-nil") - } - - // Cleanup - if err := shutdown(ctx); err != nil { - t.Errorf("shutdown() unexpected error: %v", err) - } -} - -func TestSetupDatadog_AgentUnavailable_GracefulDegradation(t *testing.T) { - // NOTE: not parallel — SetupDatadog modifies global state (os.Setenv, TracerProvider) - - // Point to a non-existent agent - cfg := Config{ - AgentHost: "localhost:99999", // Invalid port - Environment: "test", - ServiceName: "graceful-test", - } - - ctx := context.Background() - shutdown, err := SetupDatadog(ctx, cfg) - - // Should NOT fail - graceful degradation - // The exporter creation may succeed but spans will fail to export silently - if err != nil { - t.Fatalf("SetupDatadog() unexpected error: %v", err) - } - if shutdown == nil { - t.Fatal("SetupDatadog() shutdown = nil, want non-nil") - } - - // Shutdown should not panic - if err := shutdown(ctx); err != nil { - t.Errorf("shutdown() unexpected error: %v", err) - } -} - -func TestSetupDatadog_EmptyConfig(t *testing.T) { - // NOTE: not parallel — SetupDatadog modifies global state (os.Setenv, TracerProvider) - - // All empty config - should use defaults - cfg := Config{} - - ctx := context.Background() - shutdown, err := SetupDatadog(ctx, cfg) - - if err != nil { - t.Fatalf("SetupDatadog() unexpected error: %v", err) - } - if shutdown == nil { - t.Fatal("SetupDatadog() shutdown = nil, want non-nil") - } - - if err := shutdown(ctx); err != nil { - t.Errorf("shutdown() unexpected error: %v", err) - } -} - -func TestDefaultAgentHost_Value(t *testing.T) { - t.Parallel() - - if got, want := DefaultAgentHost, "localhost:4318"; got != want { - t.Errorf("DefaultAgentHost = %q, want %q", got, want) - } -} diff --git a/internal/rag/constants.go b/internal/rag/constants.go index 81980c6..c3a7da0 100644 --- a/internal/rag/constants.go +++ b/internal/rag/constants.go @@ -1,14 +1,9 @@ -// Package rag constants.go defines shared constants, types, and configuration for RAG operations. -// -// Contents: -// - Source type constants (SourceTypeConversation, SourceTypeFile, SourceTypeSystem) -// - Table schema constants for documents table -// - NewDocStoreConfig factory for consistent DocStore configuration package rag import ( "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/plugins/postgresql" + "google.golang.org/genai" ) // Source type constants for knowledge documents. @@ -24,6 +19,12 @@ const ( SourceTypeSystem = "system" ) +// VectorDimension is the vector dimension used by the pgvector schema. +// Must match the documents table migration: embedding vector(768). +// gemini-embedding-001 produces 3072 dimensions by default; +// we truncate to 768 via OutputDimensionality in EmbedderOptions. +const VectorDimension int32 = 768 + // Table schema constants for Genkit PostgreSQL plugin. // These match the documents table in db/migrations. const ( @@ -37,7 +38,9 @@ const ( // NewDocStoreConfig creates a postgresql.Config for the documents table. // This factory ensures consistent configuration across production and tests. +// EmbedderOptions sets OutputDimensionality to match the pgvector schema. func NewDocStoreConfig(embedder ai.Embedder) *postgresql.Config { + dim := VectorDimension return &postgresql.Config{ TableName: DocumentsTableName, SchemaName: DocumentsSchemaName, @@ -47,5 +50,6 @@ func NewDocStoreConfig(embedder ai.Embedder) *postgresql.Config { MetadataJSONColumn: DocumentsMetadataCol, MetadataColumns: []string{"source_type"}, // For filtering by type Embedder: embedder, + EmbedderOptions: &genai.EmbedContentConfig{OutputDimensionality: &dim}, } } diff --git a/internal/rag/doc.go b/internal/rag/doc.go index da78850..d324943 100644 --- a/internal/rag/doc.go +++ b/internal/rag/doc.go @@ -8,17 +8,11 @@ // RAG enhances LLM responses by augmenting prompts with relevant context from a knowledge base. // The rag package manages: // -// - System knowledge indexing (IndexSystemKnowledge function) +// - DocStore configuration for vector storage // - Integration with Genkit's PostgreSQL DocStore // // # Architecture // -// System Knowledge Docs -// | -// v -// IndexSystemKnowledge() -// | -// v // Genkit PostgreSQL DocStore // | // +-- Vector embedding (via AI Embedder) @@ -35,10 +29,9 @@ // // # Key Components // -// IndexSystemKnowledge: Package-level function that indexes built-in knowledge: -// - Go best practices and coding standards -// - Agent capabilities and tool usage -// - Architecture principles +// NewDocStoreConfig: Creates configuration for the Genkit PostgreSQL DocStore. +// +// DeleteByIDs: Removes documents by ID for UPSERT emulation. // // # Source Types // @@ -51,5 +44,4 @@ // # Thread Safety // // DocStore handles concurrent operations safely. -// IndexSystemKnowledge is called once at startup. package rag diff --git a/internal/rag/fuzz_test.go b/internal/rag/fuzz_test.go index d5bac97..301026b 100644 --- a/internal/rag/fuzz_test.go +++ b/internal/rag/fuzz_test.go @@ -33,8 +33,7 @@ func FuzzDeleteByIDs_SQLInjection(f *testing.F) { } // Setup test database - dbContainer, cleanup := testutil.SetupTestDB(t) - defer cleanup() + dbContainer := testutil.SetupTestDB(t) ctx := context.Background() pool := dbContainer.Pool @@ -70,15 +69,14 @@ func FuzzDeleteByIDs_SQLInjection(f *testing.F) { func TestDeleteByIDs_EmptySlice(t *testing.T) { t.Parallel() - dbContainer, cleanup := testutil.SetupTestDB(t) - defer cleanup() + dbContainer := testutil.SetupTestDB(t) ctx := context.Background() // Empty slice should return nil without executing query err := rag.DeleteByIDs(ctx, dbContainer.Pool, []string{}) if err != nil { - t.Errorf("deleteByIDs with empty slice should return nil, got: %v", err) + t.Errorf("DeleteByIDs(empty slice) unexpected error: %v", err) } } @@ -86,8 +84,7 @@ func TestDeleteByIDs_EmptySlice(t *testing.T) { func TestDeleteByIDs_ValidUUIDs(t *testing.T) { t.Parallel() - dbContainer, cleanup := testutil.SetupTestDB(t) - defer cleanup() + dbContainer := testutil.SetupTestDB(t) ctx := context.Background() @@ -99,6 +96,6 @@ func TestDeleteByIDs_ValidUUIDs(t *testing.T) { err := rag.DeleteByIDs(ctx, dbContainer.Pool, validIDs) if err != nil { - t.Errorf("deleteByIDs with valid UUIDs should succeed, got: %v", err) + t.Errorf("DeleteByIDs(valid UUIDs) unexpected error: %v", err) } } diff --git a/internal/rag/system.go b/internal/rag/system.go index db0cdbd..3e3d285 100644 --- a/internal/rag/system.go +++ b/internal/rag/system.go @@ -1,59 +1,12 @@ -// Package rag provides RAG (Retrieval-Augmented Generation) functionality. -// This file implements IndexSystemKnowledge for managing built-in knowledge -// about Agent capabilities, Golang best practices, and architecture principles. package rag import ( "context" "fmt" - "log/slog" - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/plugins/postgresql" "github.com/jackc/pgx/v5/pgxpool" ) -// IndexSystemKnowledge indexes all built-in system knowledge documents. -// Called once during application startup. -// -// Features: -// - Uses fixed document IDs (e.g., "system:golang-errors") -// - UPSERT behavior via delete-then-insert (Genkit DocStore doesn't support UPSERT) -// - Returns count of successfully indexed documents -// -// Returns: (indexedCount, error) -// Error: returns error if indexing fails -func IndexSystemKnowledge(ctx context.Context, store *postgresql.DocStore, pool *pgxpool.Pool) (int, error) { - docs := buildSystemKnowledgeDocs() - - // Extract document IDs for deletion (UPSERT emulation) - ids := make([]string, 0, len(docs)) - for _, doc := range docs { - if id, ok := doc.Metadata["id"].(string); ok { - ids = append(ids, id) - } - } - - // Delete existing documents first (UPSERT emulation). - // Genkit DocStore.Index() only does INSERT, so we must delete first. - // NOTE: Not fully atomic — Genkit DocStore manages its own connections, - // so delete (via pool) and insert (via DocStore) cannot share a transaction. - // This is acceptable because IndexSystemKnowledge runs only at startup. - if err := DeleteByIDs(ctx, pool, ids); err != nil { - // DELETE with non-existent IDs returns 0 rows (not an error). - // A real error here indicates a connection or SQL problem. - slog.Warn("failed to delete existing system knowledge", "error", err, "ids", ids) - } - - if err := store.Index(ctx, docs); err != nil { - return 0, fmt.Errorf("failed to index system knowledge: %w", err) - } - - slog.Debug("system knowledge indexed", "count", len(docs)) - - return len(docs), nil -} - // DeleteByIDs deletes documents by their IDs. // Used for UPSERT emulation since Genkit DocStore only supports INSERT. // Exported for testing (fuzz tests in rag_test package). @@ -69,235 +22,3 @@ func DeleteByIDs(ctx context.Context, pool *pgxpool.Pool, ids []string) error { } return nil } - -// buildSystemKnowledgeDocs constructs all system knowledge documents. -func buildSystemKnowledgeDocs() []*ai.Document { - var docs []*ai.Document - - // 1. Golang Style Guide - docs = append(docs, buildGolangStyleDocs()...) - - // 2. Agent Capabilities - docs = append(docs, buildCapabilitiesDocs()...) - - // 3. Architecture Principles - docs = append(docs, buildArchitectureDocs()...) - - return docs -} - -// buildGolangStyleDocs creates Golang best practices documents. -func buildGolangStyleDocs() []*ai.Document { - return []*ai.Document{ - // Document 1: Error Handling - ai.DocumentFromText(`# Golang Error Handling Best Practices - -## Core Principles -- Always check errors immediately after function calls -- Use fmt.Errorf with %w for error wrapping (enables errors.Is/As) -- Avoid naked returns in error paths -- Return errors to callers, don't panic unless truly exceptional - -## Examples -Good: - result, err := doSomething() - if err != nil { - return fmt.Errorf("failed to do something: %w", err) - } - -Bad: - result, _ := doSomething() // Ignoring errors - -## Security -- Never expose internal error details to users -- Log full errors, return sanitized messages`, - map[string]any{ - "id": "system:golang-errors", - "source_type": SourceTypeSystem, - "category": "golang", - "topic": "error-handling", - "version": "1.0", - }), - - // Document 2: Concurrency Patterns - ai.DocumentFromText(`# Golang Concurrency Best Practices - -## Goroutines -- Always have a way to stop goroutines (context, done channel) -- Use WaitGroups for coordinating multiple goroutines -- Avoid goroutine leaks by ensuring all goroutines eventually exit - -## Channels -- Close channels from sender side only -- Use select with context.Done() for cancellation -- Buffered channels for non-blocking sends - -## Context -- Pass context as first parameter -- Use context.WithTimeout for operations with deadlines -- Never store context in struct fields (exception: short-lived request-scoped structs) - -## Mutexes -- Keep critical sections small -- Use RWMutex when read-heavy workload -- Prefer channels over shared memory when possible`, - map[string]any{ - "id": "system:golang-concurrency", - "source_type": SourceTypeSystem, - "category": "golang", - "topic": "concurrency", - "version": "1.0", - }), - - // Document 3: Naming Conventions - ai.DocumentFromText(`# Golang Naming Conventions - -## Packages -- Short, lowercase, no underscores (e.g., httputil, not http_util) -- Singular form (e.g., encoding, not encodings) - -## Interfaces -- One-method interfaces: name with -er suffix (Reader, Writer, Closer) -- Avoid "I" prefix (use Reader, not IReader) - -## Getters/Setters -- No "Get" prefix for getters (use Owner(), not GetOwner()) -- Use "Set" prefix for setters (SetOwner()) - -## Acronyms -- Keep consistent casing: URL, HTTP, ID (not Url, Http, Id) -- In names: use URLParser, not UrlParser - -## Exported vs Unexported -- Exported: PascalCase (MyFunction, MyStruct) -- Unexported: camelCase (myFunction, myStruct)`, - map[string]any{ - "id": "system:golang-naming", - "source_type": SourceTypeSystem, - "category": "golang", - "topic": "naming", - "version": "1.0", - }), - } -} - -// buildCapabilitiesDocs creates Agent capabilities documents. -func buildCapabilitiesDocs() []*ai.Document { - return []*ai.Document{ - // Document 4: Available Tools - ai.DocumentFromText(`# Agent Available Tools - -## File Operations -- read_file: Read file contents -- write_file: Create or update file -- list_files: List directory contents with glob patterns -- delete_file: Remove file (requires confirmation) -- get_file_info: Get file metadata - -## System Operations -- current_time: Get current timestamp -- execute_command: Run shell commands (requires confirmation for destructive ops) -- get_env: Read environment variables - -## Network Operations -- web_search: Search the web using SearXNG metasearch engine -- web_fetch: Fetch and extract content from a specific URL - -## Knowledge Operations -- search_history: Search conversation history -- search_documents: Search user-indexed documents (Notion pages, local files) -- search_system_knowledge: Search Agent's built-in knowledge -- knowledge_store: Store new knowledge documents for later retrieval - -## Limitations -- File operations limited to current working directory -- Commands requiring sudo are blocked -- Cannot access system files (/etc, /sys, etc.)`, - map[string]any{ - "id": "system:agent-tools", - "source_type": SourceTypeSystem, - "category": "capabilities", - "topic": "available-tools", - "version": "1.0", - }), - - // Document 5: Best Practices - ai.DocumentFromText(`# Agent Best Practices - -## When to Use Tools -- search_history: When user asks "what did I say about X?" -- search_documents: When user asks about their notes/documents -- search_system_knowledge: When unsure about Golang conventions or Agent capabilities -- read_file before write_file: Always read first to understand context - -## Communication -- Be concise but informative -- Use code blocks for code snippets -- Explain what you're about to do before using destructive tools -- Ask for confirmation when ambiguous - -## Error Handling -- If a tool fails, try alternative approaches -- Explain errors in user-friendly language -- Don't give up after first failure - retry with different parameters - -## Security -- Never execute user-provided code without understanding it -- Always validate file paths before operations -- Sanitize command inputs to prevent injection`, - map[string]any{ - "id": "system:agent-best-practices", - "source_type": SourceTypeSystem, - "category": "capabilities", - "topic": "best-practices", - "version": "1.0", - }), - } -} - -// buildArchitectureDocs creates architecture principles documents. -func buildArchitectureDocs() []*ai.Document { - return []*ai.Document{ - // Document 6: Design Principles - ai.DocumentFromText(`# Koopa CLI Architecture Principles - -## Dependency Injection -- Use struct-based DI -- Define interfaces in consumer packages (not provider packages) -- Accept interfaces, return structs - -## Package Structure -- internal/agent: Core AI interaction logic -- internal/tools: Tool definitions and implementations -- internal/rag: Knowledge retrieval and indexing -- internal/session: Session persistence -- cmd: CLI commands and user interaction - -## Error Handling -- Errors propagate up, logged at boundaries -- Use error wrapping (fmt.Errorf with %w) -- Graceful degradation when non-critical services fail - -## Testing -- Unit tests for business logic -- Integration tests for cross-package interactions -- Use interfaces for mocking - -## Security -- Principle of least privilege -- Input validation at boundaries -- Explicit user confirmation for destructive operations - -## Concurrency -- Use context for cancellation -- Protect shared state with mutexes -- Goroutines must have cleanup mechanism`, - map[string]any{ - "id": "system:architecture-principles", - "source_type": SourceTypeSystem, - "category": "architecture", - "topic": "design-principles", - "version": "1.0", - }), - } -} diff --git a/internal/rag/system_test.go b/internal/rag/system_test.go deleted file mode 100644 index a84a046..0000000 --- a/internal/rag/system_test.go +++ /dev/null @@ -1,206 +0,0 @@ -//go:build integration - -// Package rag_test provides integration tests for RAG system knowledge functions. -package rag_test - -import ( - "context" - "testing" - - "github.com/koopa0/koopa/internal/rag" - "github.com/koopa0/koopa/internal/testutil" -) - -// TestIndexSystemKnowledge_FirstTime verifies first-time indexing works correctly. -func TestIndexSystemKnowledge_FirstTime(t *testing.T) { - t.Parallel() - - dbContainer, cleanup := testutil.SetupTestDB(t) - defer cleanup() - - ragSetup := testutil.SetupRAG(t, dbContainer.Pool) - ctx := context.Background() - - // Index system knowledge for the first time - count, err := rag.IndexSystemKnowledge(ctx, ragSetup.DocStore, dbContainer.Pool) - if err != nil { - t.Fatalf("IndexSystemKnowledge() first-time indexing unexpected error: %v (should succeed)", err) - } - if count <= 0 { - t.Errorf("IndexSystemKnowledge() count = %d, want > 0 (should index at least one document)", count) - } - - t.Logf("Indexed %d system knowledge documents", count) -} - -// TestIndexSystemKnowledge_Reindexing verifies UPSERT behavior (no duplicates on re-index). -func TestIndexSystemKnowledge_Reindexing(t *testing.T) { - t.Parallel() - - dbContainer, cleanup := testutil.SetupTestDB(t) - defer cleanup() - - ragSetup := testutil.SetupRAG(t, dbContainer.Pool) - ctx := context.Background() - - // First indexing - count1, err := rag.IndexSystemKnowledge(ctx, ragSetup.DocStore, dbContainer.Pool) - if err != nil { - t.Fatalf("IndexSystemKnowledge() first indexing unexpected error: %v (should succeed)", err) - } - - // Second indexing (should not create duplicates) - count2, err := rag.IndexSystemKnowledge(ctx, ragSetup.DocStore, dbContainer.Pool) - if err != nil { - t.Fatalf("IndexSystemKnowledge() re-indexing unexpected error: %v (should succeed)", err) - } - - if count1 != count2 { - t.Errorf("IndexSystemKnowledge() re-indexing count = %d, want %d (should index same number of documents)", count2, count1) - } - - // Verify no duplicates by counting documents with system source type - var totalCount int - err = dbContainer.Pool.QueryRow(ctx, - `SELECT COUNT(*) FROM documents WHERE metadata->>'source_type' = $1`, - rag.SourceTypeSystem, - ).Scan(&totalCount) - if err != nil { - t.Fatalf("QueryRow() counting documents unexpected error: %v (should succeed)", err) - } - - if totalCount != count1 { - t.Errorf("total system documents after re-indexing = %d, want %d (should have no duplicate documents)", totalCount, count1) - } - t.Logf("Total system documents after re-indexing: %d (expected: %d)", totalCount, count1) -} - -// TestIndexSystemKnowledge_DocumentMetadata verifies all documents have required metadata. -func TestIndexSystemKnowledge_DocumentMetadata(t *testing.T) { - t.Parallel() - - dbContainer, cleanup := testutil.SetupTestDB(t) - defer cleanup() - - ragSetup := testutil.SetupRAG(t, dbContainer.Pool) - ctx := context.Background() - - // Index documents - if _, err := rag.IndexSystemKnowledge(ctx, ragSetup.DocStore, dbContainer.Pool); err != nil { - t.Fatalf("IndexSystemKnowledge() unexpected error: %v", err) - } - - // Query documents and verify metadata - rows, err := dbContainer.Pool.Query(ctx, - `SELECT id, metadata FROM documents WHERE metadata->>'source_type' = $1`, - rag.SourceTypeSystem, - ) - if err != nil { - t.Fatalf("Query() unexpected error: %v", err) - } - defer rows.Close() - - docCount := 0 - for rows.Next() { - var id string - var metadata map[string]any - if err := rows.Scan(&id, &metadata); err != nil { - t.Fatalf("rows.Scan() unexpected error: %v", err) - } - - // Verify required metadata fields - if metadata["id"] == "" || metadata["id"] == nil { - t.Errorf("document %q metadata[id] = empty, want non-empty", id) - } - if metadata["source_type"] == "" || metadata["source_type"] == nil { - t.Errorf("document %q metadata[source_type] = empty, want non-empty", id) - } - if metadata["category"] == "" || metadata["category"] == nil { - t.Errorf("document %q metadata[category] = empty, want non-empty", id) - } - if metadata["topic"] == "" || metadata["topic"] == nil { - t.Errorf("document %q metadata[topic] = empty, want non-empty", id) - } - - docCount++ - } - - if err := rows.Err(); err != nil { - t.Fatalf("rows.Err() unexpected error: %v", err) - } - if docCount <= 0 { - t.Errorf("verified documents count = %d, want > 0 (should have at least one system document)", docCount) - } - t.Logf("Verified metadata for %d documents", docCount) -} - -// TestIndexSystemKnowledge_UniqueIDs verifies all documents have unique IDs. -func TestIndexSystemKnowledge_UniqueIDs(t *testing.T) { - t.Parallel() - - dbContainer, cleanup := testutil.SetupTestDB(t) - defer cleanup() - - ragSetup := testutil.SetupRAG(t, dbContainer.Pool) - ctx := context.Background() - - // Index documents - if _, err := rag.IndexSystemKnowledge(ctx, ragSetup.DocStore, dbContainer.Pool); err != nil { - t.Fatalf("IndexSystemKnowledge() unexpected error: %v", err) - } - - // Query for duplicate IDs - rows, err := dbContainer.Pool.Query(ctx, - `SELECT metadata->>'id' as doc_id, COUNT(*) as cnt - FROM documents - WHERE metadata->>'source_type' = $1 - GROUP BY metadata->>'id' - HAVING COUNT(*) > 1`, - rag.SourceTypeSystem, - ) - if err != nil { - t.Fatalf("Query() unexpected error: %v", err) - } - defer rows.Close() - - duplicates := 0 - for rows.Next() { - var docID string - var count int - if err := rows.Scan(&docID, &count); err != nil { - t.Fatalf("rows.Scan() unexpected error: %v", err) - } - t.Errorf("duplicate document ID found: %s (count: %d)", docID, count) - duplicates++ - } - - if err := rows.Err(); err != nil { - t.Fatalf("rows.Err() unexpected error: %v", err) - } - if duplicates != 0 { - t.Errorf("duplicate document IDs = %d, want 0", duplicates) - } -} - -// TestIndexSystemKnowledge_CanceledContext verifies graceful handling of canceled context. -func TestIndexSystemKnowledge_CanceledContext(t *testing.T) { - t.Parallel() - - dbContainer, cleanup := testutil.SetupTestDB(t) - defer cleanup() - - ragSetup := testutil.SetupRAG(t, dbContainer.Pool) - - // Create already-canceled context - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - // Should fail gracefully with canceled context - _, err := rag.IndexSystemKnowledge(ctx, ragSetup.DocStore, dbContainer.Pool) - - // Error is expected (context canceled) - if err == nil { - t.Error("IndexSystemKnowledge(canceled context) error = nil, want non-nil (should fail with canceled context)") - } - t.Logf("Expected error: %v", err) -} diff --git a/internal/security/command.go b/internal/security/command.go index 07a59dc..420a4a2 100644 --- a/internal/security/command.go +++ b/internal/security/command.go @@ -3,14 +3,14 @@ package security import ( "fmt" "log/slog" + "slices" "strings" ) // Command validates commands to prevent injection attacks. // Used to prevent command injection attacks (CWE-78). type Command struct { - blacklist []string - whitelist []string // If non-empty, only allow commands in the whitelist + whitelist []string // Only allow commands in this list blockedSubcommands map[string][]string // cmd → blocked first-arg subcommands blockedArgPatterns map[string][]string // cmd → blocked argument patterns (any position) } @@ -31,7 +31,6 @@ type Command struct { // make and mkdir are NOT whitelisted — make can execute arbitrary Makefile targets. func NewCommand() *Command { return &Command{ - blacklist: []string{}, // Whitelist mode doesn't need blacklist whitelist: []string{ // File listing (metadata only — no content reading) "ls", "wc", "sort", "uniq", @@ -68,7 +67,7 @@ func NewCommand() *Command { } } -// ValidateCommand validates whether a command is safe. +// Validate validates whether a command is safe to execute. // // SECURITY NOTE: This validator is designed for use with exec.Command(cmd, args...), // which does NOT pass arguments through a shell. Therefore: @@ -79,14 +78,14 @@ func NewCommand() *Command { // Parameters: // - cmd: command name (executable) // - args: command arguments (passed directly to exec.Command, not shell-interpreted) -func (v *Command) ValidateCommand(cmd string, args []string) error { +func (v *Command) Validate(cmd string, args []string) error { // 1. Check for empty command if strings.TrimSpace(cmd) == "" { return fmt.Errorf("command cannot be empty") } // 2. Validate command name only (no args yet) - if err := v.validateCommandName(cmd); err != nil { + if err := validateCommandName(cmd); err != nil { return fmt.Errorf("validating command name: %w", err) } @@ -111,7 +110,7 @@ func (v *Command) ValidateCommand(cmd string, args []string) error { // NOTE: We do NOT check for shell metacharacters (|, $, >, etc.) because // exec.Command treats them as literal strings, not shell operators for i, arg := range args { - if err := v.validateArgument(arg); err != nil { + if err := validateArgument(arg); err != nil { slog.Warn("dangerous argument detected", "command", cmd, "arg_index", i, @@ -122,52 +121,27 @@ func (v *Command) ValidateCommand(cmd string, args []string) error { } } - // 6. Check full command string (cmd + args) for dangerous patterns - // Some dangerous patterns span command and arguments (e.g., "rm -rf /") - fullCmd := strings.ToLower(cmd + " " + strings.Join(args, " ")) - for _, pattern := range v.blacklist { - if strings.Contains(fullCmd, strings.ToLower(pattern)) { - slog.Warn("command+args match dangerous pattern", - "command", cmd, - "args", args, - "full_command", fullCmd, - "dangerous_pattern", pattern, - "security_event", "dangerous_command_combination") - return fmt.Errorf("command contains dangerous pattern: '%s'", pattern) - } - } - return nil } +// shellMetachars lists characters that indicate shell injection in a command name. +const shellMetachars = ";|&`\n><$()" + // validateCommandName validates the command name (executable) only. -// Checks blacklist patterns and shell injection attempts in the command name itself. -func (v *Command) validateCommandName(cmd string) error { +// Checks for shell injection attempts in the command name itself. +func validateCommandName(cmd string) error { // Normalize command name cmd = strings.TrimSpace(strings.ToLower(cmd)) - // Check blacklisted command patterns - for _, pattern := range v.blacklist { - if strings.Contains(cmd, strings.ToLower(pattern)) { - slog.Warn("command matches blacklisted pattern", - "command", cmd, - "dangerous_pattern", pattern, - "security_event", "command_blacklist_violation") - return fmt.Errorf("command contains dangerous pattern: '%s'", pattern) - } - } - // Check for shell metacharacters in command name itself // (These would indicate shell injection attempt) - shellMetachars := []string{";", "|", "&", "`", "\n", ">", "<", "$", "(", ")"} - for _, char := range shellMetachars { - if strings.Contains(cmd, char) { - slog.Warn("command name contains shell metacharacter", - "command", cmd, - "character", char, - "security_event", "shell_injection_in_command_name") - return fmt.Errorf("command name contains shell metacharacter: '%s'", char) - } + if i := strings.IndexAny(cmd, shellMetachars); i >= 0 { + char := string(cmd[i]) + slog.Warn("command name contains shell metacharacter", + "command", cmd, + "character", char, + "security_event", "shell_injection_in_command_name") + return fmt.Errorf("command name contains shell metacharacter: %q", char) } return nil @@ -193,14 +167,12 @@ func (v *Command) validateSubcommands(cmd string, args []string) error { // Check blocked subcommands (first argument) if blocked, ok := v.blockedSubcommands[cmdLower]; ok && len(args) > 0 { firstArg := strings.ToLower(strings.TrimSpace(args[0])) - for _, sub := range blocked { - if firstArg == sub { - slog.Warn("blocked subcommand", - "command", cmd, - "subcommand", args[0], - "security_event", "blocked_subcommand") - return fmt.Errorf("subcommand '%s %s' is not allowed (can execute arbitrary code)", cmd, args[0]) - } + if slices.Contains(blocked, firstArg) { + slog.Warn("blocked subcommand", + "command", cmd, + "subcommand", args[0], + "security_event", "blocked_subcommand") + return fmt.Errorf("subcommand '%s %s' is not allowed (can execute arbitrary code)", cmd, args[0]) } } @@ -224,6 +196,20 @@ func (v *Command) validateSubcommands(cmd string, args []string) error { return nil } +// dangerousArgPatterns lists embedded command patterns that are dangerous +// even when passed as arguments via exec.Command. +var dangerousArgPatterns = []string{ + "rm -rf /", + "rm -rf /*", + "rm -rf ~", + "mkfs", + "dd if=/dev/zero", + "dd if=/dev/urandom", + "shutdown", + "reboot", + "sudo su", +} + // validateArgument checks if an argument contains obviously malicious patterns. // // IMPORTANT: This function does NOT check for shell metacharacters like $, |, >, < @@ -232,7 +218,7 @@ func (v *Command) validateSubcommands(cmd string, args []string) error { // - Embedded dangerous commands (e.g., "rm -rf /") // - Null bytes // - Extremely long arguments (possible buffer overflow) -func (*Command) validateArgument(arg string) error { +func validateArgument(arg string) error { // Check for null bytes (often used in injection attacks) if strings.Contains(arg, "\x00") { return fmt.Errorf("argument contains null byte") @@ -246,19 +232,7 @@ func (*Command) validateArgument(arg string) error { // Check for embedded dangerous command patterns // These are suspicious even in arguments argLower := strings.ToLower(arg) - dangerousPatterns := []string{ - "rm -rf /", - "rm -rf /*", - "rm -rf ~", - "mkfs", - "dd if=/dev/zero", - "dd if=/dev/urandom", - "shutdown", - "reboot", - "sudo su", - } - - for _, pattern := range dangerousPatterns { + for _, pattern := range dangerousArgPatterns { if strings.Contains(argLower, pattern) { return fmt.Errorf("argument contains dangerous pattern: %s", pattern) } diff --git a/internal/security/command_test.go b/internal/security/command_test.go index 95efef5..0f0d6da 100644 --- a/internal/security/command_test.go +++ b/internal/security/command_test.go @@ -89,12 +89,12 @@ func TestCommandValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := cmdValidator.ValidateCommand(tt.command, tt.args) + err := cmdValidator.Validate(tt.command, tt.args) if tt.shouldErr && err == nil { - t.Errorf("expected error for %q, but got none: %s", tt.command, tt.reason) + t.Errorf("Validate(%q, %v) = nil, want error: %s", tt.command, tt.args, tt.reason) } if !tt.shouldErr && err != nil { - t.Errorf("unexpected error for %q: %v (%s)", tt.command, err, tt.reason) + t.Errorf("Validate(%q, %v) = %v, want nil (%s)", tt.command, tt.args, err, tt.reason) } }) } @@ -133,12 +133,12 @@ func TestStrictCommandValidator(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validator.ValidateCommand(tt.command, tt.args) + err := validator.Validate(tt.command, tt.args) if tt.shouldErr && err == nil { - t.Errorf("expected error for %q", tt.command) + t.Errorf("Validate(%q, %v) = nil, want error", tt.command, tt.args) } if !tt.shouldErr && err != nil { - t.Errorf("unexpected error for %q: %v", tt.command, err) + t.Errorf("Validate(%q, %v) = %v, want nil", tt.command, tt.args, err) } }) } @@ -193,18 +193,33 @@ func TestBlockedSubcommands(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := v.ValidateCommand(tt.command, tt.args) + err := v.Validate(tt.command, tt.args) if tt.shouldErr && err == nil { - t.Errorf("ValidateCommand(%q, %v) = nil, want error", tt.command, tt.args) + t.Errorf("Validate(%q, %v) = nil, want error", tt.command, tt.args) } if !tt.shouldErr && err != nil { - t.Errorf("ValidateCommand(%q, %v) = %v, want nil", tt.command, tt.args, err) + t.Errorf("Validate(%q, %v) = %v, want nil", tt.command, tt.args, err) } }) } } // TestCommandValidationEdgeCases tests edge cases in command validation +// TestAllShellMetacharsBlocked verifies every shell metacharacter in the const +// is blocked when it appears in a command name. This prevents regressions if +// shellMetachars is modified. +func TestAllShellMetacharsBlocked(t *testing.T) { + v := NewCommand() + metachars := []string{";", "|", "&", "`", "\n", ">", "<", "$", "(", ")"} + + for _, char := range metachars { + cmd := "ls" + char + "cat" + if err := v.Validate(cmd, nil); err == nil { + t.Errorf("Validate(%q, nil) = nil, want error for metachar %q", cmd, char) + } + } +} + func TestCommandValidationEdgeCases(t *testing.T) { cmdValidator := NewCommand() @@ -275,12 +290,12 @@ func TestCommandValidationEdgeCases(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := cmdValidator.ValidateCommand(tt.command, tt.args) + err := cmdValidator.Validate(tt.command, tt.args) if tt.shouldErr && err == nil { - t.Errorf("expected error for %q, but got none: %s", tt.name, tt.reason) + t.Errorf("Validate(%q, %v) = nil, want error: %s", tt.command, tt.args, tt.reason) } if !tt.shouldErr && err != nil { - t.Errorf("unexpected error for %q: %v (%s)", tt.name, err, tt.reason) + t.Errorf("Validate(%q, %v) = %v, want nil (%s)", tt.command, tt.args, err, tt.reason) } }) } diff --git a/internal/security/doc.go b/internal/security/doc.go index 02d203f..8e17ca3 100644 --- a/internal/security/doc.go +++ b/internal/security/doc.go @@ -16,30 +16,30 @@ // Path Validator: Prevents directory traversal and ensures file operations // stay within allowed boundaries. // -// pathValidator := security.NewPath() -// if err := pathValidator.ValidatePath(userInput); err != nil { +// pathValidator, err := security.NewPath([]string{"/safe/dir"}) +// if _, err := pathValidator.Validate(userInput); err != nil { // return fmt.Errorf("invalid path: %w", err) // } // // Command Validator: Blocks dangerous shell commands and prevents command injection. // // cmdValidator := security.NewCommand() -// if err := cmdValidator.ValidateCommand(cmd, args); err != nil { +// if err := cmdValidator.Validate(cmd, args); err != nil { // return fmt.Errorf("dangerous command: %w", err) // } // // Dangerous commands blocked include: rm -rf, sudo, shutdown, dd, mkfs, // format, and other destructive operations. // -// HTTP Validator: Prevents SSRF attacks by blocking requests to private networks +// URL Validator: Prevents SSRF attacks by blocking requests to private networks // and cloud metadata endpoints. // -// httpValidator := security.NewHTTP() -// if err := httpValidator.ValidateURL(url); err != nil { +// urlValidator := security.NewURL() +// if err := urlValidator.Validate(rawURL); err != nil { // return fmt.Errorf("SSRF attempt blocked: %w", err) // } -// // Use the validator's HTTP client for safe requests -// resp, err := httpValidator.Client().Get(url) +// // Use SafeTransport for DNS-rebinding protection +// client := &http.Client{Transport: urlValidator.SafeTransport()} // // Blocked targets include: // - Private IP ranges (127.0.0.1, 192.168.x.x, 10.x.x.x) @@ -50,7 +50,7 @@ // from unauthorized access. // // envValidator := security.NewEnv() -// if err := envValidator.ValidateEnvAccess(key); err != nil { +// if err := envValidator.Validate(key); err != nil { // return fmt.Errorf("access to sensitive variable blocked: %w", err) // } // @@ -67,27 +67,32 @@ // # Integration Example // // // Create validators -// pathVal := security.NewPath() +// pathVal, _ := security.NewPath([]string{workDir}) // cmdVal := security.NewCommand() -// httpVal := security.NewHTTP() +// urlVal := security.NewURL() // envVal := security.NewEnv() // -// // Pass to toolsets during initialization -// fileToolset := tools.NewFileToolset(pathVal) -// systemToolset := tools.NewSystemToolset(cmdVal, envVal) -// networkToolset := tools.NewNetworkToolset(httpVal) +// // Pass to tool constructors during initialization +// fileTools, _ := tools.NewFile(pathVal, logger) +// systemTools, _ := tools.NewSystem(cmdVal, envVal, logger) +// networkTools, _ := tools.NewNetwork(urlVal, logger) // // # Configuration // -// The HTTP validator supports configuration for response size limits, -// timeouts, and redirect limits: +// The URL validator uses SafeTransport for DNS-resolution-level SSRF protection: // -// httpValidator := security.NewHTTP() -// // Default: 10MB max response, 30s timeout, 10 redirects -// client := httpValidator.Client() +// urlValidator := security.NewURL() +// client := &http.Client{Transport: urlValidator.SafeTransport()} // // Other validators use secure defaults and require no configuration. // +// # Error Handling +// +// Validators intentionally both log and return errors. This is a deliberate +// exception to the "handle errors once" rule: security events require an +// audit trail (via logging) AND must propagate the error to callers so they +// can deny the operation. Removing either side would create a security gap. +// // # Testing // // Each validator includes comprehensive tests covering: diff --git a/internal/security/env.go b/internal/security/env.go index 6582c5f..c869dd9 100644 --- a/internal/security/env.go +++ b/internal/security/env.go @@ -22,7 +22,6 @@ func NewEnv() *Env { "SECRET", "PASSWORD", "PASSWD", - "PWD", "TOKEN", "ACCESS_TOKEN", "REFRESH_TOKEN", @@ -34,8 +33,8 @@ func NewEnv() *Env { // Cloud services related "AWS_SECRET", "AWS_ACCESS_KEY", - "AZURE_", - "GCP_", + "AZURE", + "GCP", "GOOGLE_API", "GOOGLE_APPLICATION_CREDENTIALS", @@ -88,13 +87,14 @@ func NewEnv() *Env { } } -// ValidateEnvAccess validates whether access to the specified environment variable is allowed -func (v *Env) ValidateEnvAccess(envName string) error { +// Validate validates whether access to the specified environment variable is allowed. +func (v *Env) Validate(envName string) error { envUpper := strings.ToUpper(envName) - // Check if it matches sensitive patterns + // Check if it matches sensitive patterns using word-boundary matching. + // Splits on "_" to avoid false positives like PWD matching PASSWORD. for _, pattern := range v.sensitivePatterns { - if strings.Contains(envUpper, pattern) { + if isSensitivePattern(envUpper, pattern) { slog.Warn("sensitive environment variable access attempt", "env_name", envName, "matched_pattern", pattern, @@ -105,3 +105,27 @@ func (v *Env) ValidateEnvAccess(envName string) error { return nil } + +// isSensitivePattern checks if envName matches pattern. +// +// For composite patterns (containing "_" like "API_KEY"), it uses substring matching +// to catch variables like "MY_API_KEY". +// +// For single-word patterns (like "SECRET"), it uses word-boundary matching by splitting +// on "_" to avoid false positives (e.g., "GOPATH" should not match "PATH"). +func isSensitivePattern(envName, pattern string) bool { + if envName == pattern { + return true + } + // Composite patterns: substring matching (e.g., "API_KEY" in "MY_API_KEY") + if strings.Contains(pattern, "_") { + return strings.Contains(envName, pattern) + } + // Single-word patterns: word-boundary matching (e.g., "SECRET" in "MY_SECRET_VAR") + for _, segment := range strings.Split(envName, "_") { + if segment == pattern { + return true + } + } + return false +} diff --git a/internal/security/env_test.go b/internal/security/env_test.go index 74a2d9a..b34bc3e 100644 --- a/internal/security/env_test.go +++ b/internal/security/env_test.go @@ -4,44 +4,95 @@ import ( "testing" ) -// TestEnvValidator tests environment variable validation +// TestEnvValidator tests environment variable validation with word-boundary matching. func TestEnvValidator(t *testing.T) { envValidator := NewEnv() tests := []struct { - name string - key string - shouldErr bool - reason string + name string + key string + wantErr bool }{ - { - name: "valid env key", - key: "MY_VAR", - shouldErr: false, - reason: "valid env key should be allowed", - }, - { - name: "API_KEY should be blocked", - key: "API_KEY", - shouldErr: true, - reason: "API_KEY is sensitive and should be blocked", - }, - { - name: "PASSWORD should be blocked", - key: "PASSWORD", - shouldErr: true, - reason: "PASSWORD is sensitive and should be blocked", - }, + // Allowed — no sensitive segment + {name: "generic var", key: "MY_VAR", wantErr: false}, + {name: "HOME", key: "HOME", wantErr: false}, + {name: "SHELL", key: "SHELL", wantErr: false}, + {name: "TERM", key: "TERM", wantErr: false}, + + // Word-boundary: previously false positives, now allowed + {name: "PWD allowed", key: "PWD", wantErr: false}, + {name: "GOPATH allowed", key: "GOPATH", wantErr: false}, + {name: "MANPATH allowed", key: "MANPATH", wantErr: false}, + {name: "PYTHONPATH allowed", key: "PYTHONPATH", wantErr: false}, + + // Blocked — exact match + {name: "API_KEY blocked", key: "API_KEY", wantErr: true}, + {name: "PASSWORD blocked", key: "PASSWORD", wantErr: true}, + {name: "SECRET blocked", key: "SECRET", wantErr: true}, + {name: "TOKEN blocked", key: "TOKEN", wantErr: true}, + {name: "DATABASE_URL blocked", key: "DATABASE_URL", wantErr: true}, + + // Blocked — segment match + {name: "DB_PASSWORD blocked", key: "DB_PASSWORD", wantErr: true}, + {name: "MY_API_KEY blocked", key: "MY_API_KEY", wantErr: true}, + {name: "MY_SECRET_KEY blocked", key: "MY_SECRET_KEY", wantErr: true}, + {name: "APP_TOKEN blocked", key: "APP_TOKEN", wantErr: true}, + + // Cloud provider segment matching + {name: "AZURE_CLIENT_ID blocked", key: "AZURE_CLIENT_ID", wantErr: true}, + {name: "AZURE_TENANT_ID blocked", key: "AZURE_TENANT_ID", wantErr: true}, + {name: "GCP_PROJECT blocked", key: "GCP_PROJECT", wantErr: true}, + {name: "GCP_REGION blocked", key: "GCP_REGION", wantErr: true}, + + // Case insensitive + {name: "lowercase password blocked", key: "db_password", wantErr: true}, + {name: "mixed case api_key blocked", key: "My_Api_Key", wantErr: true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := envValidator.ValidateEnvAccess(tt.key) - if tt.shouldErr && err == nil { - t.Errorf("expected error for %q, but got none: %s", tt.key, tt.reason) + err := envValidator.Validate(tt.key) + if tt.wantErr && err == nil { + t.Errorf("Validate(%q) = nil, want error", tt.key) } - if !tt.shouldErr && err != nil { - t.Errorf("unexpected error for %q: %v (%s)", tt.key, err, tt.reason) + if !tt.wantErr && err != nil { + t.Errorf("Validate(%q) = %v, want nil", tt.key, err) + } + }) + } +} + +// TestIsSensitivePattern tests the word-boundary matching logic directly. +func TestIsSensitivePattern(t *testing.T) { + tests := []struct { + name string + envName string + pattern string + want bool + }{ + // Single-word patterns: word-boundary matching + {name: "exact match", envName: "PASSWORD", pattern: "PASSWORD", want: true}, + {name: "segment match", envName: "DB_PASSWORD", pattern: "PASSWORD", want: true}, + {name: "no match substring", envName: "PWD", pattern: "PASSWORD", want: false}, + {name: "no match compound", envName: "GOPATH", pattern: "PATH", want: false}, + {name: "segment at start", envName: "SECRET_KEY", pattern: "SECRET", want: true}, + {name: "segment at end", envName: "MY_TOKEN", pattern: "TOKEN", want: true}, + {name: "segment in middle", envName: "MY_SECRET_KEY", pattern: "SECRET", want: true}, + {name: "no partial segment", envName: "MYPASSWORD", pattern: "PASSWORD", want: false}, + + // Composite patterns: substring matching + {name: "composite exact", envName: "API_KEY", pattern: "API_KEY", want: true}, + {name: "composite prefix", envName: "MY_API_KEY", pattern: "API_KEY", want: true}, + {name: "composite suffix", envName: "API_KEY_OLD", pattern: "API_KEY", want: true}, + {name: "composite in middle", envName: "MY_API_KEY_OLD", pattern: "API_KEY", want: true}, + {name: "composite no match", envName: "APIKEY", pattern: "API_KEY", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isSensitivePattern(tt.envName, tt.pattern) + if got != tt.want { + t.Errorf("isSensitivePattern(%q, %q) = %v, want %v", tt.envName, tt.pattern, got, tt.want) } }) } diff --git a/internal/security/fuzz_test.go b/internal/security/fuzz_test.go index fc278d3..226487f 100644 --- a/internal/security/fuzz_test.go +++ b/internal/security/fuzz_test.go @@ -72,7 +72,7 @@ func FuzzPathValidation(f *testing.F) { tmpDir := f.TempDir() validator, err := NewPath([]string{tmpDir}) if err != nil { - f.Fatalf("failed to create validator: %v", err) + f.Fatalf("creating validator: %v", err) } f.Fuzz(func(t *testing.T, input string) { @@ -132,14 +132,14 @@ func FuzzPathValidationWithSymlinks(f *testing.F) { tmpDir := t.TempDir() validator, err := NewPath([]string{tmpDir}) if err != nil { - t.Skipf("failed to create validator: %v", err) + t.Skipf("creating validator: %v", err) } // Create a symlink pointing outside allowed directories linkPath := filepath.Join(tmpDir, linkName) err = os.Symlink("/etc/passwd", linkPath) if err != nil { - t.Skipf("failed to create symlink: %v", err) + t.Skipf("creating symlink: %v", err) } // Validation should fail because symlink points outside allowed dirs @@ -208,7 +208,7 @@ func FuzzCommandValidation(f *testing.F) { f.Fuzz(func(t *testing.T, cmd, args string) { argSlice := strings.Fields(args) - err := validator.ValidateCommand(cmd, argSlice) + err := validator.Validate(cmd, argSlice) // Property 1: Commands with shell metacharacters in name must be rejected shellMetachars := []string{";", "|", "&", "`", "$", "(", ")", "\n", ">", "<"} @@ -268,3 +268,68 @@ func FuzzCommandValidation(f *testing.F) { } }) } + +// ============================================================================= +// URL Fuzzing Tests +// ============================================================================= + +// FuzzURLValidation tests URL validation against SSRF bypass attempts. +// Run with: go test -fuzz=FuzzURLValidation -fuzztime=30s ./internal/security/ +func FuzzURLValidation(f *testing.F) { + seeds := []string{ + // Valid public URLs + "https://example.com", + "http://example.com/path?q=1", + + // Blocked schemes + "ftp://example.com", + "file:///etc/passwd", + "javascript:alert(1)", + "gopher://evil.com", + + // Loopback + "http://127.0.0.1", + "http://127.0.0.1:8080", + "http://[::1]", + + // Private IPs + "http://10.0.0.1", + "http://172.16.0.1", + "http://192.168.1.1", + + // Cloud metadata + "http://169.254.169.254/latest/meta-data/", + "http://metadata.google.internal", + + // Blocked hosts + "http://localhost", + "http://localhost:3000", + + // Edge cases + "", + "://", + "http://", + "http://0.0.0.0", + "http://[::ffff:127.0.0.1]", + + // Encoding tricks + "http://0x7f000001", // 127.0.0.1 as hex + "http://2130706433", // 127.0.0.1 as decimal + "http://017700000001", // 127.0.0.1 as octal + "http://[::ffff:7f00:1]", // IPv6-mapped IPv4 loopback + "http://127.1", // short form loopback + "http://0x7f.0.0.1", // partial hex loopback + "http://0177.0.0.1", // octal first octet + } + + for _, seed := range seeds { + f.Add(seed) + } + + validator := NewURL() + + f.Fuzz(func(t *testing.T, rawURL string) { + // Must not panic + _ = validator.Validate(rawURL) + }) +} diff --git a/internal/security/path_test.go b/internal/security/path_test.go index efbfa50..be8e75f 100644 --- a/internal/security/path_test.go +++ b/internal/security/path_test.go @@ -14,18 +14,18 @@ func TestPathValidation(t *testing.T) { tmpDir := t.TempDir() workDir, err := os.Getwd() if err != nil { - t.Fatalf("failed to get working directory: %v", err) + t.Fatalf("getting working directory: %v", err) } // Change to temp directory for testing if err := os.Chdir(tmpDir); err != nil { - t.Fatalf("failed to change to temp directory: %v", err) + t.Fatalf("changing to temp directory: %v", err) } defer func() { _ = os.Chdir(workDir) }() // Restore original directory validator, err := NewPath([]string{tmpDir}) if err != nil { - t.Fatalf("failed to create path validator: %v", err) + t.Fatalf("creating path validator: %v", err) } tests := []struct { @@ -64,10 +64,10 @@ func TestPathValidation(t *testing.T) { t.Run(tt.name, func(t *testing.T) { _, err := validator.Validate(tt.path) if tt.shouldErr && err == nil { - t.Errorf("expected error for %s, but got none: %s", tt.path, tt.reason) + t.Errorf("Validate(%q) = nil, want error: %s", tt.path, tt.reason) } if !tt.shouldErr && err != nil { - t.Errorf("unexpected error for %s: %v (%s)", tt.path, err, tt.reason) + t.Errorf("Validate(%q) unexpected error: %v (%s)", tt.path, err, tt.reason) } }) } @@ -77,7 +77,7 @@ func TestPathValidation(t *testing.T) { func TestPathErrorSanitization(t *testing.T) { validator, err := NewPath([]string{}) if err != nil { - t.Fatalf("failed to create path validator: %v", err) + t.Fatalf("creating path validator: %v", err) } // Try to access a path outside allowed directories @@ -103,24 +103,24 @@ func TestSymlinkValidation(t *testing.T) { tmpDir := t.TempDir() workDir, err := os.Getwd() if err != nil { - t.Fatalf("failed to get working directory: %v", err) + t.Fatalf("getting working directory: %v", err) } // Change to temp directory for testing if err := os.Chdir(tmpDir); err != nil { - t.Fatalf("failed to change to temp directory: %v", err) + t.Fatalf("changing to temp directory: %v", err) } defer func() { _ = os.Chdir(workDir) }() // Restore original directory validator, err := NewPath([]string{tmpDir}) if err != nil { - t.Fatalf("failed to create path validator: %v", err) + t.Fatalf("creating path validator: %v", err) } // Create a file targetFile := filepath.Join(tmpDir, "target.txt") if err := os.WriteFile(targetFile, []byte("test"), 0o600); err != nil { - t.Fatalf("failed to create target file: %v", err) + t.Fatalf("creating target file: %v", err) } // Create a symlink to the file @@ -141,7 +141,7 @@ func TestSymlinkValidation(t *testing.T) { expectedPath = targetFile } if resolvedPath != expectedPath { - t.Errorf("expected resolved path %s, got %s", expectedPath, resolvedPath) + t.Errorf("Validate() resolved path = %q, want %q", resolvedPath, expectedPath) } } @@ -150,17 +150,17 @@ func TestPathValidationWithNonExistentFile(t *testing.T) { tmpDir := t.TempDir() workDir, err := os.Getwd() if err != nil { - t.Fatalf("failed to get working directory: %v", err) + t.Fatalf("getting working directory: %v", err) } if err := os.Chdir(tmpDir); err != nil { - t.Fatalf("failed to change to temp directory: %v", err) + t.Fatalf("changing to temp directory: %v", err) } defer func() { _ = os.Chdir(workDir) }() validator, err := NewPath([]string{tmpDir}) if err != nil { - t.Fatalf("failed to create path validator: %v", err) + t.Fatalf("creating path validator: %v", err) } // Test with non-existent file (should be allowed for creating new files) @@ -170,7 +170,7 @@ func TestPathValidationWithNonExistentFile(t *testing.T) { t.Errorf("validation of non-existent file failed: %v", err) } if validatedPath != nonExistentPath { - t.Errorf("expected path %s, got %s", nonExistentPath, validatedPath) + t.Errorf("Validate(%q) = %q, want %q", nonExistentPath, validatedPath, nonExistentPath) } } @@ -179,24 +179,24 @@ func TestSymlinkBypassAttempt(t *testing.T) { tmpDir := t.TempDir() workDir, err := os.Getwd() if err != nil { - t.Fatalf("failed to get working directory: %v", err) + t.Fatalf("getting working directory: %v", err) } // Create another temp directory outside the allowed directory outsideDir := t.TempDir() outsideFile := filepath.Join(outsideDir, "secret.txt") if err := os.WriteFile(outsideFile, []byte("secret data"), 0o600); err != nil { - t.Fatalf("failed to create outside file: %v", err) + t.Fatalf("creating outside file: %v", err) } if err := os.Chdir(tmpDir); err != nil { - t.Fatalf("failed to change to temp directory: %v", err) + t.Fatalf("changing to temp directory: %v", err) } defer func() { _ = os.Chdir(workDir) }() validator, err := NewPath([]string{tmpDir}) if err != nil { - t.Fatalf("failed to create path validator: %v", err) + t.Fatalf("creating path validator: %v", err) } // Create symlink inside allowed dir pointing to file outside @@ -212,7 +212,7 @@ func TestSymlinkBypassAttempt(t *testing.T) { } if err != nil && !errors.Is(err, ErrSymlinkOutsideAllowed) { - t.Errorf("expected ErrSymlinkOutsideAllowed, got: %v", err) + t.Errorf("Validate(symlink) error = %v, want ErrSymlinkOutsideAllowed", err) } } @@ -221,31 +221,32 @@ func TestPathValidationErrors(t *testing.T) { tmpDir := t.TempDir() workDir, err := os.Getwd() if err != nil { - t.Fatalf("failed to get working directory: %v", err) + t.Fatalf("getting working directory: %v", err) } if err := os.Chdir(tmpDir); err != nil { - t.Fatalf("failed to change to temp directory: %v", err) + t.Fatalf("changing to temp directory: %v", err) } defer func() { _ = os.Chdir(workDir) }() validator, err := NewPath([]string{tmpDir}) if err != nil { - t.Fatalf("failed to create path validator: %v", err) + t.Fatalf("creating path validator: %v", err) } // Test with extremely long path (should be handled gracefully) longPath := filepath.Join(tmpDir, string(make([]byte, 1000))) _, err = validator.Validate(longPath) - // Should not panic, error is acceptable - _ = err + if err == nil { + t.Error("Validate(longPath) = nil, want error for 1000-byte filename") + } } // BenchmarkPathValidation benchmarks path validation performance func BenchmarkPathValidation(b *testing.B) { validator, err := NewPath([]string{}) if err != nil { - b.Fatalf("failed to create path validator: %v", err) + b.Fatalf("creating path validator: %v", err) } b.ResetTimer() diff --git a/internal/security/url.go b/internal/security/url.go index 3a3b0dc..aef0122 100644 --- a/internal/security/url.go +++ b/internal/security/url.go @@ -1,7 +1,3 @@ -// Package security provides security validators for Koopa. -// -// URL validator prevents SSRF (Server-Side Request Forgery) attacks by blocking -// requests to private networks, cloud metadata endpoints, and other dangerous targets. package security import ( @@ -173,7 +169,7 @@ func (v *URL) safeDialContext(ctx context.Context, network, addr string) (net.Co } conn, dialErr := (&net.Dialer{}).DialContext(ctx, network, addr) if dialErr != nil { - return nil, fmt.Errorf("dial failed: %w", dialErr) + return nil, fmt.Errorf("dialing: %w", dialErr) } return conn, nil } @@ -181,7 +177,7 @@ func (v *URL) safeDialContext(ctx context.Context, network, addr string) (net.Co // Resolve DNS and check all returned IPs ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host) if err != nil { - return nil, fmt.Errorf("DNS lookup failed: %w", err) + return nil, fmt.Errorf("resolving DNS: %w", err) } // Check all resolved IPs @@ -200,7 +196,7 @@ func (v *URL) safeDialContext(ctx context.Context, network, addr string) (net.Co } conn, err := (&net.Dialer{}).DialContext(ctx, network, targetAddr) if err != nil { - return nil, fmt.Errorf("dial to %s failed: %w", targetAddr, err) + return nil, fmt.Errorf("dialing %s: %w", targetAddr, err) } return conn, nil } diff --git a/internal/security/url_test.go b/internal/security/url_test.go index af9e650..20455e6 100644 --- a/internal/security/url_test.go +++ b/internal/security/url_test.go @@ -2,6 +2,7 @@ package security import ( "net" + "strings" "testing" ) @@ -164,7 +165,7 @@ func TestURL_Validate(t *testing.T) { t.Errorf("Validate(%q) expected error, got nil", tt.url) return } - if tt.errMsg != "" && !urlContains(err.Error(), tt.errMsg) { + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { t.Errorf("Validate(%q) error = %q, want error containing %q", tt.url, err.Error(), tt.errMsg) } } else if err != nil { @@ -205,7 +206,7 @@ func TestURL_checkIP(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ip := net.ParseIP(tt.ip) if ip == nil { - t.Fatalf("failed to parse IP: %s", tt.ip) + t.Fatalf("parsing IP: %s", tt.ip) } err := v.checkIP(ip) if tt.wantErr && err == nil { @@ -229,20 +230,32 @@ func TestURL_SafeTransport(t *testing.T) { if transport.DialContext == nil { t.Error("SafeTransport() DialContext is nil") } -} - -// Helper functions -func urlContains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || substr == "" || - (s != "" && substr != "" && urlContainsSubstring(s, substr))) -} + // Verify SafeTransport blocks dangerous IPs at the dial level. + // This tests DNS-rebinding protection: even if DNS resolves to a blocked IP, + // the custom DialContext must reject the connection. + tests := []struct { + name string + addr string + wantSub string // expected substring in error message + }{ + {name: "loopback", addr: "127.0.0.1:80", wantSub: "loopback"}, + {name: "private 10.x", addr: "10.0.0.1:80", wantSub: "private"}, + {name: "private 192.168.x", addr: "192.168.1.1:80", wantSub: "private"}, + {name: "link-local metadata", addr: "169.254.169.254:80", wantSub: "link-local"}, + {name: "IPv6 loopback", addr: "[::1]:80", wantSub: "loopback"}, + } -func urlContainsSubstring(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := transport.DialContext(t.Context(), "tcp", tt.addr) + if err == nil { + t.Errorf("SafeTransport().DialContext(%q) = nil, want error", tt.addr) + return + } + if !strings.Contains(err.Error(), tt.wantSub) { + t.Errorf("SafeTransport().DialContext(%q) error = %q, want error containing %q", tt.addr, err.Error(), tt.wantSub) + } + }) } - return false } diff --git a/internal/session/benchmark_test.go b/internal/session/benchmark_test.go index 42ebff7..bf7f20b 100644 --- a/internal/session/benchmark_test.go +++ b/internal/session/benchmark_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/firebase/genkit/go/ai" + "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" "github.com/koopa0/koopa/internal/sqlc" ) @@ -34,7 +35,7 @@ func BenchmarkStore_GetHistory(b *testing.B) { for b.Loop() { _, err := store.History(ctx, sessionID) if err != nil { - b.Fatalf("GetHistory failed: %v", err) + b.Fatalf("History(): %v", err) } } } @@ -51,7 +52,7 @@ func BenchmarkStore_GetHistory_SmallSession(b *testing.B) { b.ResetTimer() for b.Loop() { if _, err := store.History(ctx, sessionID); err != nil { - b.Fatalf("GetHistory failed: %v", err) + b.Fatalf("History(): %v", err) } } } @@ -68,7 +69,7 @@ func BenchmarkStore_GetHistory_LargeSession(b *testing.B) { b.ResetTimer() for b.Loop() { if _, err := store.History(ctx, sessionID); err != nil { - b.Fatalf("GetHistory failed: %v", err) + b.Fatalf("History(): %v", err) } } } @@ -83,9 +84,9 @@ func BenchmarkStore_AddMessages(b *testing.B) { store := New(sqlc.New(pool), pool, logger) // Create a test session - session, err := store.CreateSession(ctx, "Benchmark-AddMessages", "", "") + session, err := store.CreateSession(ctx, "Benchmark-AddMessages") if err != nil { - b.Fatalf("Failed to create session: %v", err) + b.Fatalf("creating session: %v", err) } defer func() { _ = store.DeleteSession(ctx, session.ID) }() @@ -124,9 +125,9 @@ func BenchmarkStore_AppendMessages(b *testing.B) { store := New(sqlc.New(pool), pool, logger) // Create a test session - session, err := store.CreateSession(ctx, "Benchmark-AppendMessages", "", "") + session, err := store.CreateSession(ctx, "Benchmark-AppendMessages") if err != nil { - b.Fatalf("Failed to create session: %v", err) + b.Fatalf("creating session: %v", err) } defer func() { _ = store.DeleteSession(ctx, session.ID) }() @@ -158,23 +159,21 @@ func BenchmarkStore_CreateSession(b *testing.B) { logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError})) store := New(sqlc.New(pool), pool, logger) - createdSessionIDs := make([]string, 0, b.N) + createdSessionIDs := make([]uuid.UUID, 0, b.N) defer func() { for _, id := range createdSessionIDs { - if parsed, err := parseBenchUUID(id); err == nil { - _ = store.DeleteSession(context.Background(), parsed) - } + _ = store.DeleteSession(context.Background(), id) } }() b.ReportAllocs() b.ResetTimer() for i := range b.N { - session, err := store.CreateSession(ctx, fmt.Sprintf("Benchmark-Session-%d", i), "test-model", "test-prompt") + session, err := store.CreateSession(ctx, fmt.Sprintf("Benchmark-Session-%d", i)) if err != nil { b.Fatalf("CreateSession failed at iteration %d: %v", i, err) } - createdSessionIDs = append(createdSessionIDs, session.ID.String()) + createdSessionIDs = append(createdSessionIDs, session.ID) } } @@ -188,9 +187,9 @@ func BenchmarkStore_GetSession(b *testing.B) { store := New(sqlc.New(pool), pool, logger) // Create a test session - session, err := store.CreateSession(ctx, "Benchmark-GetSession", "", "") + session, err := store.CreateSession(ctx, "Benchmark-GetSession") if err != nil { - b.Fatalf("Failed to create session: %v", err) + b.Fatalf("creating session: %v", err) } defer func() { _ = store.DeleteSession(ctx, session.ID) }() @@ -199,13 +198,13 @@ func BenchmarkStore_GetSession(b *testing.B) { for b.Loop() { _, err := store.Session(ctx, session.ID) if err != nil { - b.Fatalf("GetSession failed: %v", err) + b.Fatalf("Session(): %v", err) } } } -// BenchmarkStore_ListSessions benchmarks listing sessions. -func BenchmarkStore_ListSessions(b *testing.B) { +// BenchmarkStore_Sessions benchmarks listing sessions. +func BenchmarkStore_Sessions(b *testing.B) { ctx := context.Background() pool, cleanup := setupBenchmarkDB(b, ctx) defer cleanup() @@ -215,9 +214,9 @@ func BenchmarkStore_ListSessions(b *testing.B) { // Create some test sessions for i := 0; i < 20; i++ { - session, err := store.CreateSession(ctx, fmt.Sprintf("Benchmark-List-%d", i), "", "") + session, err := store.CreateSession(ctx, fmt.Sprintf("Benchmark-List-%d", i)) if err != nil { - b.Fatalf("Failed to create session: %v", err) + b.Fatalf("creating session: %v", err) } defer func(s *Session) { _ = store.DeleteSession(context.Background(), s.ID) }(session) } @@ -225,9 +224,9 @@ func BenchmarkStore_ListSessions(b *testing.B) { b.ReportAllocs() b.ResetTimer() for b.Loop() { - _, err := store.ListSessions(ctx, 100, 0) + _, err := store.Sessions(ctx, 100, 0) if err != nil { - b.Fatalf("ListSessions failed: %v", err) + b.Fatalf("Sessions() unexpected error: %v", err) } } } @@ -243,7 +242,7 @@ func BenchmarkStore_GetMessages(b *testing.B) { for b.Loop() { _, err := store.Messages(ctx, session.ID, 100, 0) if err != nil { - b.Fatalf("GetMessages failed: %v", err) + b.Fatalf("Messages(): %v", err) } } } @@ -257,10 +256,10 @@ func setupBenchmarkSession(b *testing.B, ctx context.Context, numMessages int) ( store := New(sqlc.New(pool), pool, logger) // Create a test session - session, err := store.CreateSession(ctx, "Benchmark-Session", "", "") + session, err := store.CreateSession(ctx, "Benchmark-Session") if err != nil { cleanup() - b.Fatalf("Failed to create session: %v", err) + b.Fatalf("creating session: %v", err) } // Pre-load messages (in batches for efficiency) @@ -285,7 +284,7 @@ func setupBenchmarkSession(b *testing.B, ctx context.Context, numMessages int) ( if err := store.AddMessages(ctx, session.ID, messages); err != nil { cleanup() - b.Fatalf("Failed to add messages: %v", err) + b.Fatalf("adding messages: %v", err) } } @@ -311,12 +310,12 @@ func setupBenchmarkDB(b *testing.B, ctx context.Context) (*pgxpool.Pool, func()) pool, err := pgxpool.New(ctx, dbURL) if err != nil { - b.Fatalf("Failed to connect to database: %v", err) + b.Fatalf("connecting to database: %v", err) } if err := pool.Ping(ctx); err != nil { pool.Close() - b.Fatalf("Failed to ping database: %v", err) + b.Fatalf("pinging database: %v", err) } cleanup := func() { @@ -327,34 +326,3 @@ func setupBenchmarkDB(b *testing.B, ctx context.Context) (*pgxpool.Pool, func()) return pool, cleanup } - -// parseBenchUUID is a helper to parse UUID strings for benchmarks. -func parseBenchUUID(s string) ([16]byte, error) { - var u [16]byte - if len(s) != 36 { - return u, fmt.Errorf("invalid UUID length") - } - // Simple UUID parsing - parse each segment into temporary variables - var a, b, c, d, e uint64 - if _, err := fmt.Sscanf(s, "%08x-%04x-%04x-%04x-%012x", &a, &b, &c, &d, &e); err != nil { - return u, err - } - // Pack into [16]byte - u[0] = byte(a >> 24) - u[1] = byte(a >> 16) - u[2] = byte(a >> 8) - u[3] = byte(a) - u[4] = byte(b >> 8) - u[5] = byte(b) - u[6] = byte(c >> 8) - u[7] = byte(c) - u[8] = byte(d >> 8) - u[9] = byte(d) - u[10] = byte(e >> 40) - u[11] = byte(e >> 32) - u[12] = byte(e >> 24) - u[13] = byte(e >> 16) - u[14] = byte(e >> 8) - u[15] = byte(e) - return u, nil -} diff --git a/internal/session/doc.go b/internal/session/doc.go index 7c8a3a0..17dfeb2 100644 --- a/internal/session/doc.go +++ b/internal/session/doc.go @@ -1,225 +1,30 @@ -// Package session provides conversation history persistence. +// Package session provides conversation history persistence with PostgreSQL. // -// The session package manages conversation sessions and messages with PostgreSQL backend. -// It provides thread-safe storage for conversation history with concurrent message insertion -// and transaction-safe operations. +// A session represents a conversation context containing ordered messages +// exchanged between user and model. The [Store] handles persistence while +// the agent handles conversation logic. // -// # Overview +// Key operations: // -// A session represents a conversation context. Each session can contain multiple messages -// exchanged between user and model. The package handles persistence while the agent handles -// the conversation logic. -// -// Key responsibilities: -// -// - Session lifecycle (create, retrieve, list, delete) -// - Message persistence with sequential ordering -// - Transaction-safe batch message insertion -// - Concurrent access safety -// -// # Architecture -// -// Session and Message Organization: -// -// Session (conversation context) -// | -// +-- Metadata (ID, title, model_name, created_at, updated_at) -// | -// v -// Messages (ordered conversation) -// | -// +-- Message 1 (role: "user") -// +-- Message 2 (role: "model") -// +-- Message 3 (role: "user") -// +-- ... -// -// Messages are ordered by sequence number within a session, ensuring -// consistent retrieval and reconstruction of conversation context. -// -// # Session Management -// -// The Store type provides session operations: -// -// CreateSession(ctx, title, modelName, systemPrompt) - Create new session -// Session(ctx, sessionID) - Retrieve session -// ListSessions(ctx, limit, offset) - List sessions with pagination -// DeleteSession(ctx, sessionID) - Delete session and messages -// -// Sessions store optional metadata: -// -// - Title: User-friendly name -// - ModelName: LLM model used for this session -// - SystemPrompt: Custom system instructions -// - MessageCount: Total number of messages (cached for efficiency) -// -// # Message Persistence -// -// The Store provides message operations: -// -// AddMessages(ctx, sessionID, messages) - Batch insert with transaction safety -// Messages(ctx, sessionID, limit, offset) - Retrieve messages with pagination -// -// Messages are stored with: -// -// - Role: "user", "model", or "tool" -// - Content: Array of ai.Part (serialized as JSONB) -// - SequenceNumber: Sequential ordering within session -// - CreatedAt: Timestamp +// - Session lifecycle: [Store.CreateSession], [Store.Session], [Store.Sessions], [Store.DeleteSession], [Store.ResolveCurrentSession] +// - Message persistence: [Store.AddMessages], [Store.Messages] (transaction-safe batch insertion) +// - Agent integration: [Store.History], [Store.AppendMessages] // // # Transaction Safety // -// AddMessages provides ACID guarantees for batch message insertion: -// -// 1. Lock session row (SELECT ... FOR UPDATE) -// 2. Get current max sequence number -// 3. Insert messages in batch with next sequence numbers -// 4. Update session metadata (message_count, updated_at) -// 5. Commit transaction atomically -// -// If any step fails, the entire transaction rolls back, ensuring consistency. -// Session locking prevents race conditions in concurrent scenarios. -// -// # Chat Agent Integration -// -// The Store provides methods for integration with the Chat agent: -// -// History(ctx, sessionID) - Get conversation history for agent -// AppendMessages(ctx, sessionID, messages) - Persist conversation messages -// -// Following Go standard library conventions (similar to database/sql returning *sql.DB), -// consumers use *session.Store directly instead of defining separate interfaces. -// Testability is achieved via the internal Querier interface (for mocking database operations). -// -// # Database Backend -// -// The session store requires PostgreSQL with the following schema: -// -// sessions table: -// id UUID PRIMARY KEY -// title TEXT -// model_name TEXT -// system_prompt TEXT -// message_count INT32 -// created_at TIMESTAMPTZ -// updated_at TIMESTAMPTZ -// -// session_messages table: -// id UUID PRIMARY KEY -// session_id UUID FOREIGN KEY (CASCADE) -// role TEXT (user|model|tool) -// content JSONB (ai.Part array) -// sequence_number INT32 -// created_at TIMESTAMPTZ -// -// # Example Usage -// -// package main -// -// import ( -// "context" -// "github.com/firebase/genkit/go/ai" -// "github.com/jackc/pgx/v5/pgxpool" -// "github.com/koopa0/koopa/internal/session" -// "log/slog" -// ) -// -// func main() { -// ctx := context.Background() -// -// // Connect to PostgreSQL -// dbPool, _ := pgxpool.New(ctx, "postgresql://...") -// defer dbPool.Close() -// -// // Create session store -// store := session.New(sqlc.New(dbPool), dbPool, slog.Default()) -// -// // Create a new session -// sess, err := store.CreateSession(ctx, "My Conversation", "gemini-pro", "") -// if err != nil { -// panic(err) -// } -// println("Session ID:", sess.ID) -// -// // Add messages to session -// messages := []*session.Message{ -// { -// Role: session.RoleUser, -// Content: []*ai.Part{ai.NewTextPart("Hello!")}, -// }, -// { -// Role: session.RoleModel, -// Content: []*ai.Part{ai.NewTextPart("Hi there!")}, -// }, -// } -// -// err = store.AddMessages(ctx, sess.ID, messages) -// if err != nil { -// panic(err) -// } -// -// // Retrieve messages -// retrieved, _ := store.Messages(ctx, sess.ID, 100, 0) -// println("Retrieved", len(retrieved), "messages") -// -// // Load history for agent -// history, _ := store.History(ctx, sess.ID) -// println("History messages:", len(history.Messages())) -// -// // List all sessions -// sessions, _ := store.ListSessions(ctx, 10, 0) -// println("Total sessions:", len(sessions)) -// } -// -// # Concurrency and Race Conditions -// -// The Store is designed to handle concurrent access safely: -// -// - Session locking (SELECT ... FOR UPDATE) prevents concurrent modifications -// - Transactions ensure atomicity of batch operations -// - PostgreSQL isolation levels handle concurrent reads -// - No shared state in Go code (all state in database) -// -// However, callers should avoid concurrent modifications to the same session -// to prevent transaction conflicts and retries. -// -// # Pagination -// -// All list operations support limit/offset pagination: -// -// - ListSessions(ctx, limit=10, offset=0) // First 10 sessions -// - GetMessages(ctx, sessionID, limit=50, offset=0) // First 50 messages -// -// Sessions are ordered by updated_at descending (most recent first). -// Messages are ordered by sequence_number ascending (earliest first). -// -// # Error Handling -// -// The Store propagates database errors with context: -// -// - "failed to create session: ..." - Creation failures -// - "failed to get session ...: ..." - Retrieval failures -// - "failed to lock session: ..." - Concurrency issues -// - "failed to insert message ...: ..." - Message insertion failures -// - "failed to unmarshal message content: ..." - Deserialization errors (skipped) -// -// Malformed messages are skipped during retrieval, allowing resilience -// to schema changes. -// -// # Testing -// -// The session package is designed for testability: +// [Store.AddMessages] uses SELECT ... FOR UPDATE to lock the session row, +// preventing race conditions on sequence numbers during concurrent writes. +// If any step fails, the entire transaction rolls back. // -// - Store accepts Querier interface for mock database -// - New() accepts interface, pass mock querier directly for tests -// - Integration tests use real PostgreSQL database -// - Supports non-transactional mode for mock testing +// # Concurrency // -// # Thread Safety +// Store is safe for concurrent use. All state lives in PostgreSQL; +// no shared Go-side state exists. Session locking and transaction +// isolation handle concurrent access. // -// The Store is thread-safe for concurrent use: +// # Local State // -// - All database operations use connection pool -// - PostgreSQL handles concurrent access safely -// - No shared state in Go code -// - Transactions provide isolation +// [SaveCurrentSessionID] and [LoadCurrentSessionID] persist the active session +// to ~/.koopa/current_session using atomic writes (temp file + rename) with +// file locking via [github.com/gofrs/flock]. package session diff --git a/internal/session/errors.go b/internal/session/errors.go deleted file mode 100644 index 94ead7e..0000000 --- a/internal/session/errors.go +++ /dev/null @@ -1,56 +0,0 @@ -package session - -import "errors" - -// Message status constants for streaming lifecycle. -const ( - // StatusStreaming indicates the message is being streamed (AI generating). - StatusStreaming = "streaming" - - // StatusCompleted indicates the message has been fully generated. - StatusCompleted = "completed" - - // StatusFailed indicates the message generation failed. - StatusFailed = "failed" -) - -// History limit constants. -const ( - // DefaultHistoryLimit is the default number of messages to load. - DefaultHistoryLimit int32 = 100 - - // MaxHistoryLimit is the absolute maximum to prevent OOM. - MaxHistoryLimit int32 = 10000 - - // MinHistoryLimit is the minimum allowed value for history limit. - MinHistoryLimit int32 = 10 -) - -// Sentinel errors for session operations. -// These errors are part of the Store's public API and should be checked using errors.Is(). -// -// Example: -// -// sess, err := store.Session(ctx, id) -// if errors.Is(err, session.ErrSessionNotFound) { -// // Handle missing session -// } -// -// ErrSessionNotFound indicates the requested session does not exist in the database. -var ErrSessionNotFound = errors.New("session not found") - -// NormalizeHistoryLimit normalizes the history limit value. -// Returns DefaultHistoryLimit for zero/negative values. -// Clamps to MinHistoryLimit/MaxHistoryLimit as bounds. -func NormalizeHistoryLimit(limit int32) int32 { - if limit <= 0 { - return DefaultHistoryLimit - } - if limit < MinHistoryLimit { - return MinHistoryLimit - } - if limit > MaxHistoryLimit { - return MaxHistoryLimit - } - return limit -} diff --git a/internal/session/errors_test.go b/internal/session/errors_test.go deleted file mode 100644 index e9592ce..0000000 --- a/internal/session/errors_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package session - -import ( - "testing" -) - -func TestNormalizeHistoryLimit(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input int32 - want int32 - }{ - // Default cases - {"zero defaults", 0, DefaultHistoryLimit}, - {"negative defaults", -1, DefaultHistoryLimit}, - {"large negative defaults", -999, DefaultHistoryLimit}, - - // Clamping to minimum - {"below min clamped", MinHistoryLimit - 1, MinHistoryLimit}, - {"exactly min", MinHistoryLimit, MinHistoryLimit}, - - // Valid middle values - {"valid 50", 50, 50}, - {"valid 100", 100, 100}, - {"valid 500", 500, 500}, - {"valid 5000", 5000, 5000}, - - // Clamping to maximum - {"exactly max", MaxHistoryLimit, MaxHistoryLimit}, - {"above max clamped", MaxHistoryLimit + 1, MaxHistoryLimit}, - {"large above max", MaxHistoryLimit * 2, MaxHistoryLimit}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NormalizeHistoryLimit(tt.input) - if got != tt.want { - t.Errorf("NormalizeHistoryLimit(%d) = %d, want %d", tt.input, got, tt.want) - } - }) - } -} - -func TestConstants(t *testing.T) { - t.Parallel() - - t.Run("DefaultHistoryLimit", func(t *testing.T) { - if DefaultHistoryLimit != 100 { - t.Errorf("DefaultHistoryLimit = %d, want %d", DefaultHistoryLimit, 100) - } - }) - - t.Run("MaxHistoryLimit", func(t *testing.T) { - if MaxHistoryLimit != 10000 { - t.Errorf("MaxHistoryLimit = %d, want %d", MaxHistoryLimit, 10000) - } - }) - - t.Run("MinHistoryLimit", func(t *testing.T) { - if MinHistoryLimit != 10 { - t.Errorf("MinHistoryLimit = %d, want %d", MinHistoryLimit, 10) - } - }) - - t.Run("StatusConstants", func(t *testing.T) { - if StatusStreaming != "streaming" { - t.Errorf("StatusStreaming = %q, want %q", StatusStreaming, "streaming") - } - if StatusCompleted != "completed" { - t.Errorf("StatusCompleted = %q, want %q", StatusCompleted, "completed") - } - if StatusFailed != "failed" { - t.Errorf("StatusFailed = %q, want %q", StatusFailed, "failed") - } - }) -} - -// BenchmarkNormalizeHistoryLimit benchmarks limit normalization. -func BenchmarkNormalizeHistoryLimit(b *testing.B) { - limits := []int32{0, -1, 50, 100, 10001} - - b.ResetTimer() - for b.Loop() { - for _, limit := range limits { - _ = NormalizeHistoryLimit(limit) - } - } -} - -// TestNormalizeRole tests the Genkit role normalization function. -// Genkit uses "model" for AI responses, but we store "assistant" in the database -// for consistency with the CHECK constraint. -func TestNormalizeRole(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input string - expected string - }{ - {"model to assistant", "model", "assistant"}, - {"user unchanged", "user", "user"}, - {"assistant unchanged", "assistant", "assistant"}, - {"system unchanged", "system", "system"}, - {"tool unchanged", "tool", "tool"}, - {"empty passthrough", "", ""}, - {"unknown passthrough", "unknown", "unknown"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got := normalizeRole(tt.input) - if got != tt.expected { - t.Errorf("normalizeRole(%q) = %q, want %q", tt.input, got, tt.expected) - } - }) - } -} diff --git a/internal/session/integration_test.go b/internal/session/integration_test.go index fbad17e..35593f5 100644 --- a/internal/session/integration_test.go +++ b/internal/session/integration_test.go @@ -20,31 +20,22 @@ import ( "github.com/koopa0/koopa/internal/testutil" ) -// ============================================================================= -// Test Setup Helper -// ============================================================================= - // setupIntegrationTest creates a Store with test database connection. // All integration tests should use this unified setup. -func setupIntegrationTest(t *testing.T) (*Store, func()) { +// Cleanup is automatic via tb.Cleanup — no manual cleanup needed. +func setupIntegrationTest(t *testing.T) *Store { t.Helper() - dbContainer, cleanup := testutil.SetupTestDB(t) - store := New(sqlc.New(dbContainer.Pool), dbContainer.Pool, slog.Default()) - return store, cleanup + dbContainer := testutil.SetupTestDB(t) + return New(sqlc.New(dbContainer.Pool), dbContainer.Pool, slog.Default()) } -// ============================================================================= -// Basic CRUD Tests -// ============================================================================= - // TestStore_CreateAndGet tests creating and retrieving a session func TestStore_CreateAndGet(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Create a session - session, err := store.CreateSession(ctx, "Test Session", "gemini-2.5-flash", "You are a helpful assistant") + session, err := store.CreateSession(ctx, "Test Session") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -57,12 +48,6 @@ func TestStore_CreateAndGet(t *testing.T) { if session.Title != "Test Session" { t.Errorf("CreateSession() Title = %q, want %q", session.Title, "Test Session") } - if session.ModelName != "gemini-2.5-flash" { - t.Errorf("CreateSession() ModelName = %q, want %q", session.ModelName, "gemini-2.5-flash") - } - if session.SystemPrompt != "You are a helpful assistant" { - t.Errorf("CreateSession() SystemPrompt = %q, want %q", session.SystemPrompt, "You are a helpful assistant") - } if session.CreatedAt.IsZero() { t.Error("CreateSession() CreatedAt should be set") } @@ -73,33 +58,26 @@ func TestStore_CreateAndGet(t *testing.T) { // Retrieve the session retrieved, err := store.Session(ctx, session.ID) if err != nil { - t.Fatalf("GetSession(%v) unexpected error: %v", session.ID, err) + t.Fatalf("Session(%v) unexpected error: %v", session.ID, err) } if retrieved == nil { - t.Fatal("GetSession() returned nil session") + t.Fatal("Session() returned nil session") } if retrieved.ID != session.ID { - t.Errorf("GetSession() ID = %v, want %v", retrieved.ID, session.ID) + t.Errorf("Session() ID = %v, want %v", retrieved.ID, session.ID) } if retrieved.Title != session.Title { - t.Errorf("GetSession() Title = %q, want %q", retrieved.Title, session.Title) - } - if retrieved.ModelName != session.ModelName { - t.Errorf("GetSession() ModelName = %q, want %q", retrieved.ModelName, session.ModelName) - } - if retrieved.SystemPrompt != session.SystemPrompt { - t.Errorf("GetSession() SystemPrompt = %q, want %q", retrieved.SystemPrompt, session.SystemPrompt) + t.Errorf("Session() Title = %q, want %q", retrieved.Title, session.Title) } } // TestStore_CreateWithEmptyFields tests creating session with empty optional fields func TestStore_CreateWithEmptyFields(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() - // Create session with empty title and system prompt - session, err := store.CreateSession(ctx, "", "gemini-2.5-flash", "") + // Create session with empty title + session, err := store.CreateSession(ctx, "") if err != nil { t.Fatalf("CreateSession() with empty fields unexpected error: %v", err) } @@ -109,79 +87,65 @@ func TestStore_CreateWithEmptyFields(t *testing.T) { if session.Title != "" { t.Errorf("CreateSession() Title = %q, want empty string", session.Title) } - if session.ModelName != "gemini-2.5-flash" { - t.Errorf("CreateSession() ModelName = %q, want %q", session.ModelName, "gemini-2.5-flash") - } - if session.SystemPrompt != "" { - t.Errorf("CreateSession() SystemPrompt = %q, want empty string", session.SystemPrompt) - } // Retrieve should work retrieved, err := store.Session(ctx, session.ID) if err != nil { - t.Fatalf("GetSession(%v) unexpected error: %v", session.ID, err) + t.Fatalf("Session(%v) unexpected error: %v", session.ID, err) } if retrieved.Title != "" { - t.Errorf("GetSession() Title = %q, want empty string", retrieved.Title) - } - if retrieved.SystemPrompt != "" { - t.Errorf("GetSession() SystemPrompt = %q, want empty string", retrieved.SystemPrompt) + t.Errorf("Session() Title = %q, want empty string", retrieved.Title) } } // TestStore_ListSessions tests listing sessions with pagination func TestStore_ListSessions_Integration(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Create multiple sessions for i := 0; i < 5; i++ { - _, err := store.CreateSession(ctx, - fmt.Sprintf("Session %d", i+1), - "gemini-2.5-flash", - "") + _, err := store.CreateSession(ctx, fmt.Sprintf("Session %d", i+1)) if err != nil { t.Fatalf("CreateSession(%d) unexpected error: %v", i+1, err) } } // List all sessions - sessions, err := store.ListSessions(ctx, 10, 0) + sessions, err := store.Sessions(ctx, 10, 0) if err != nil { - t.Fatalf("ListSessions(10, 0) unexpected error: %v", err) + t.Fatalf("Sessions(10, 0) unexpected error: %v", err) } if len(sessions) < 5 { - t.Errorf("ListSessions(10, 0) returned %d sessions, want at least 5", len(sessions)) + t.Errorf("Sessions(10, 0) returned %d sessions, want at least 5", len(sessions)) } // Test pagination - first 3 - sessions, err = store.ListSessions(ctx, 3, 0) + sessions, err = store.Sessions(ctx, 3, 0) if err != nil { - t.Fatalf("ListSessions(3, 0) unexpected error: %v", err) + t.Fatalf("Sessions(3, 0) unexpected error: %v", err) } if len(sessions) != 3 { - t.Errorf("ListSessions(3, 0) returned %d sessions, want exactly 3", len(sessions)) + t.Errorf("Sessions(3, 0) returned %d sessions, want exactly 3", len(sessions)) } // Test pagination - next 2 - sessions, err = store.ListSessions(ctx, 3, 3) + sessions, err = store.Sessions(ctx, 3, 3) if err != nil { - t.Fatalf("ListSessions(3, 3) unexpected error: %v", err) + t.Fatalf("Sessions(3, 3) unexpected error: %v", err) } if len(sessions) < 2 { - t.Errorf("ListSessions(3, 3) returned %d sessions, want at least 2", len(sessions)) + t.Errorf("Sessions(3, 3) returned %d sessions, want at least 2", len(sessions)) } } // TestStore_DeleteSession tests deleting a session func TestStore_DeleteSession_Integration(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Create a session - session, err := store.CreateSession(ctx, "To Be Deleted", "gemini-2.5-flash", "") + session, err := store.CreateSession(ctx, "To Be Deleted") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -189,7 +153,7 @@ func TestStore_DeleteSession_Integration(t *testing.T) { // Verify it exists _, err = store.Session(ctx, session.ID) if err != nil { - t.Fatalf("GetSession(%v) before delete unexpected error: %v", session.ID, err) + t.Fatalf("Session(%v) before delete unexpected error: %v", session.ID, err) } // Delete the session @@ -201,22 +165,17 @@ func TestStore_DeleteSession_Integration(t *testing.T) { // Verify it no longer exists _, err = store.Session(ctx, session.ID) if err == nil { - t.Errorf("GetSession(%v) after delete should return error", session.ID) + t.Errorf("Session(%v) after delete should return error", session.ID) } } -// ============================================================================= -// Message Tests -// ============================================================================= - // TestStore_AddMessage tests adding messages to a session func TestStore_AddMessage(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Create a session - session, err := store.CreateSession(ctx, "Message Test", "gemini-2.5-flash", "") + session, err := store.CreateSession(ctx, "Message Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -250,35 +209,34 @@ func TestStore_AddMessage(t *testing.T) { // Retrieve messages messages, err := store.Messages(ctx, session.ID, 10, 0) if err != nil { - t.Fatalf("GetMessages(%v, 10, 0) unexpected error: %v", session.ID, err) + t.Fatalf("Messages(%v, 10, 0) unexpected error: %v", session.ID, err) } if len(messages) != 2 { - t.Fatalf("GetMessages() returned %d messages, want 2", len(messages)) + t.Fatalf("Messages() returned %d messages, want 2", len(messages)) } // Verify order (should be chronological) if messages[0].Role != string(ai.RoleUser) { - t.Errorf("GetMessages()[0].Role = %q, want %q", messages[0].Role, string(ai.RoleUser)) + t.Errorf("Messages()[0].Role = %q, want %q", messages[0].Role, string(ai.RoleUser)) } if messages[0].Content[0].Text != "Hello, how are you?" { - t.Errorf("GetMessages()[0].Content[0].Text = %q, want %q", messages[0].Content[0].Text, "Hello, how are you?") + t.Errorf("Messages()[0].Content[0].Text = %q, want %q", messages[0].Content[0].Text, "Hello, how are you?") } if messages[1].Role != string(ai.RoleModel) { - t.Errorf("GetMessages()[1].Role = %q, want %q", messages[1].Role, string(ai.RoleModel)) + t.Errorf("Messages()[1].Role = %q, want %q", messages[1].Role, string(ai.RoleModel)) } if messages[1].Content[0].Text != "I'm doing well, thank you!" { - t.Errorf("GetMessages()[1].Content[0].Text = %q, want %q", messages[1].Content[0].Text, "I'm doing well, thank you!") + t.Errorf("Messages()[1].Content[0].Text = %q, want %q", messages[1].Content[0].Text, "I'm doing well, thank you!") } } // TestStore_GetMessages tests retrieving messages with pagination func TestStore_GetMessages_Integration(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Create a session - session, err := store.CreateSession(ctx, "Pagination Test", "gemini-2.5-flash", "") + session, err := store.CreateSession(ctx, "Pagination Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -302,36 +260,35 @@ func TestStore_GetMessages_Integration(t *testing.T) { // Get first 5 messages retrieved, err := store.Messages(ctx, session.ID, 5, 0) if err != nil { - t.Fatalf("GetMessages(%v, 5, 0) unexpected error: %v", session.ID, err) + t.Fatalf("Messages(%v, 5, 0) unexpected error: %v", session.ID, err) } if len(retrieved) != 5 { - t.Errorf("GetMessages(%v, 5, 0) returned %d messages, want 5", session.ID, len(retrieved)) + t.Errorf("Messages(%v, 5, 0) returned %d messages, want 5", session.ID, len(retrieved)) } if retrieved[0].Content[0].Text != "Message 1" { - t.Errorf("GetMessages(%v, 5, 0)[0].Content[0].Text = %q, want %q", session.ID, retrieved[0].Content[0].Text, "Message 1") + t.Errorf("Messages(%v, 5, 0)[0].Content[0].Text = %q, want %q", session.ID, retrieved[0].Content[0].Text, "Message 1") } // Get next 5 messages retrieved, err = store.Messages(ctx, session.ID, 5, 5) if err != nil { - t.Fatalf("GetMessages(%v, 5, 5) unexpected error: %v", session.ID, err) + t.Fatalf("Messages(%v, 5, 5) unexpected error: %v", session.ID, err) } if len(retrieved) != 5 { - t.Errorf("GetMessages(%v, 5, 5) returned %d messages, want 5", session.ID, len(retrieved)) + t.Errorf("Messages(%v, 5, 5) returned %d messages, want 5", session.ID, len(retrieved)) } if retrieved[0].Content[0].Text != "Message 6" { - t.Errorf("GetMessages(%v, 5, 5)[0].Content[0].Text = %q, want %q", session.ID, retrieved[0].Content[0].Text, "Message 6") + t.Errorf("Messages(%v, 5, 5)[0].Content[0].Text = %q, want %q", session.ID, retrieved[0].Content[0].Text, "Message 6") } } // TestStore_MessageOrdering tests that messages maintain chronological order func TestStore_MessageOrdering(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Create a session - session, err := store.CreateSession(ctx, "Ordering Test", "gemini-2.5-flash", "") + session, err := store.CreateSession(ctx, "Ordering Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -361,10 +318,10 @@ func TestStore_MessageOrdering(t *testing.T) { // Retrieve all messages retrieved, err := store.Messages(ctx, session.ID, 100, 0) if err != nil { - t.Fatalf("GetMessages(%v, 100, 0) unexpected error: %v", session.ID, err) + t.Fatalf("Messages(%v, 100, 0) unexpected error: %v", session.ID, err) } if len(retrieved) != 6 { - t.Fatalf("GetMessages() returned %d messages, want 6", len(retrieved)) + t.Fatalf("Messages() returned %d messages, want 6", len(retrieved)) } // Verify order @@ -392,12 +349,11 @@ func TestStore_MessageOrdering(t *testing.T) { // TestStore_LargeMessageContent tests handling large message content func TestStore_LargeMessageContent(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Create a session - session, err := store.CreateSession(ctx, "Large Content Test", "gemini-2.5-flash", "") + session, err := store.CreateSession(ctx, "Large Content Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -420,24 +376,23 @@ func TestStore_LargeMessageContent(t *testing.T) { // Retrieve and verify messages, err := store.Messages(ctx, session.ID, 10, 0) if err != nil { - t.Fatalf("GetMessages(%v, 10, 0) unexpected error: %v", session.ID, err) + t.Fatalf("Messages(%v, 10, 0) unexpected error: %v", session.ID, err) } if len(messages) != 1 { - t.Fatalf("GetMessages() returned %d messages, want 1", len(messages)) + t.Fatalf("Messages() returned %d messages, want 1", len(messages)) } if messages[0].Content[0].Text != largeText { - t.Errorf("GetMessages()[0].Content[0].Text length = %d, want %d (large content not preserved)", len(messages[0].Content[0].Text), len(largeText)) + t.Errorf("Messages()[0].Content[0].Text length = %d, want %d (large content not preserved)", len(messages[0].Content[0].Text), len(largeText)) } } // TestStore_DeleteSessionWithMessages tests that deleting a session also deletes messages func TestStore_DeleteSessionWithMessages(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Create a session with messages - session, err := store.CreateSession(ctx, "Cascade Delete Test", "gemini-2.5-flash", "") + session, err := store.CreateSession(ctx, "Cascade Delete Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -456,10 +411,10 @@ func TestStore_DeleteSessionWithMessages(t *testing.T) { // Verify message exists retrieved, err := store.Messages(ctx, session.ID, 10, 0) if err != nil { - t.Fatalf("GetMessages(%v, 10, 0) before delete unexpected error: %v", session.ID, err) + t.Fatalf("Messages(%v, 10, 0) before delete unexpected error: %v", session.ID, err) } if len(retrieved) != 1 { - t.Errorf("GetMessages() before delete returned %d messages, want 1", len(retrieved)) + t.Errorf("Messages() before delete returned %d messages, want 1", len(retrieved)) } // Delete session @@ -471,24 +426,19 @@ func TestStore_DeleteSessionWithMessages(t *testing.T) { // Verify session is deleted _, err = store.Session(ctx, session.ID) if err == nil { - t.Error("GetSession() after delete should return error") + t.Error("Session() after delete should return error") } } -// ============================================================================= -// Race Condition Tests -// ============================================================================= - // TestStore_ConcurrentSessionCreation tests that multiple goroutines can create // sessions simultaneously without data corruption or race conditions. func TestStore_ConcurrentSessionCreation(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() const numGoroutines = 10 var wg sync.WaitGroup - errors := make(chan error, numGoroutines) + errs := make(chan error, numGoroutines) sessionIDs := make(chan string, numGoroutines) // Create sessions concurrently @@ -497,9 +447,9 @@ func TestStore_ConcurrentSessionCreation(t *testing.T) { go func(id int) { defer wg.Done() title := fmt.Sprintf("Race-Session-%d", id) - session, err := store.CreateSession(ctx, title, "test-model", "test-prompt") + session, err := store.CreateSession(ctx, title) if err != nil { - errors <- fmt.Errorf("goroutine %d: %w", id, err) + errs <- fmt.Errorf("goroutine %d: %w", id, err) return } sessionIDs <- session.ID.String() @@ -507,12 +457,12 @@ func TestStore_ConcurrentSessionCreation(t *testing.T) { } wg.Wait() - close(errors) + close(errs) close(sessionIDs) // Check for errors var errCount int - for err := range errors { + for err := range errs { t.Errorf("concurrent creation error: %v", err) errCount++ } @@ -530,18 +480,16 @@ func TestStore_ConcurrentSessionCreation(t *testing.T) { if len(ids) != expectedCount { t.Errorf("created %d unique sessions, want %d", len(ids), expectedCount) } - t.Logf("Successfully created %d sessions concurrently", len(ids)) } // TestStore_ConcurrentHistoryUpdate tests that multiple goroutines can add // messages to the same session without data corruption. func TestStore_ConcurrentHistoryUpdate(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Create a test session - session, err := store.CreateSession(ctx, "Race-History-Test", "", "") + session, err := store.CreateSession(ctx, "Race-History-Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -549,7 +497,7 @@ func TestStore_ConcurrentHistoryUpdate(t *testing.T) { const numGoroutines = 10 const messagesPerGoroutine = 5 var wg sync.WaitGroup - errors := make(chan error, numGoroutines) + errs := make(chan error, numGoroutines) var successCount atomic.Int32 // Add messages concurrently @@ -567,7 +515,7 @@ func TestStore_ConcurrentHistoryUpdate(t *testing.T) { } if err := store.AddMessages(ctx, session.ID, messages); err != nil { - errors <- fmt.Errorf("goroutine %d: %w", goroutineID, err) + errs <- fmt.Errorf("goroutine %d: %w", goroutineID, err) return } successCount.Add(1) @@ -575,22 +523,22 @@ func TestStore_ConcurrentHistoryUpdate(t *testing.T) { } wg.Wait() - close(errors) + close(errs) // Check for errors - for err := range errors { + for err := range errs { t.Errorf("concurrent history update error: %v", err) } // Verify message integrity allMessages, err := store.Messages(ctx, session.ID, 1000, 0) if err != nil { - t.Fatalf("GetMessages(%v, 1000, 0) unexpected error: %v", session.ID, err) + t.Fatalf("Messages(%v, 1000, 0) unexpected error: %v", session.ID, err) } expectedCount := int(successCount.Load()) * messagesPerGoroutine if len(allMessages) != expectedCount { - t.Errorf("GetMessages() returned %d messages, want %d", len(allMessages), expectedCount) + t.Errorf("Messages() returned %d messages, want %d", len(allMessages), expectedCount) } // Verify sequence numbers are unique and sequential @@ -608,19 +556,16 @@ func TestStore_ConcurrentHistoryUpdate(t *testing.T) { t.Errorf("missing sequence number: %d", i) } } - - t.Logf("Successfully added %d messages from %d goroutines", len(allMessages), successCount.Load()) } // TestStore_ConcurrentLoadAndSaveHistory tests simultaneous load and save // operations on the same session. func TestStore_ConcurrentLoadAndSaveHistory(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Create a test session with initial messages - session, err := store.CreateSession(ctx, "Race-LoadSave-Test", "", "") + session, err := store.CreateSession(ctx, "Race-LoadSave-Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -653,9 +598,9 @@ func TestStore_ConcurrentLoadAndSaveHistory(t *testing.T) { return } // Verify we got at least the initial messages - if len(history.Messages()) < 2 { + if len(history) < 2 { loadErrors <- fmt.Errorf("load goroutine %d: expected at least 2 messages, got %d", - id, len(history.Messages())) + id, len(history)) } }(i) @@ -687,22 +632,19 @@ func TestStore_ConcurrentLoadAndSaveHistory(t *testing.T) { // Verify final state finalHistory, err := store.History(ctx, sessionID) if err != nil { - t.Fatalf("GetHistory(%v) final state unexpected error: %v", sessionID, err) + t.Fatalf("History(%v) final state unexpected error: %v", sessionID, err) } // Should have at least initial messages + some concurrent messages - if len(finalHistory.Messages()) < 2 { - t.Errorf("GetHistory() final state returned %d messages, want at least 2", len(finalHistory.Messages())) + if len(finalHistory) < 2 { + t.Errorf("History() final state returned %d messages, want at least 2", len(finalHistory)) } - - t.Logf("Final history has %d messages after concurrent load/save", len(finalHistory.Messages())) } // TestStore_ConcurrentSessionDeletion tests that deleting sessions while // other operations are in progress doesn't cause crashes or data corruption. func TestStore_ConcurrentSessionDeletion(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() const numSessions = 5 @@ -710,7 +652,7 @@ func TestStore_ConcurrentSessionDeletion(t *testing.T) { // Create test sessions for i := 0; i < numSessions; i++ { - session, err := store.CreateSession(ctx, fmt.Sprintf("Race-Delete-Test-%d", i), "", "") + session, err := store.CreateSession(ctx, fmt.Sprintf("Race-Delete-Test-%d", i)) if err != nil { t.Fatalf("CreateSession() for session %d unexpected error: %v", i, err) } @@ -744,7 +686,7 @@ func TestStore_ConcurrentSessionDeletion(t *testing.T) { // List goroutine go func() { defer wg.Done() - _, _ = store.ListSessions(ctx, 100, 0) + _, _ = store.Sessions(ctx, 100, 0) }() // Get goroutine @@ -757,9 +699,9 @@ func TestStore_ConcurrentSessionDeletion(t *testing.T) { wg.Wait() // Verify all sessions are deleted - remaining, err := store.ListSessions(ctx, 100, 0) + remaining, err := store.Sessions(ctx, 100, 0) if err != nil { - t.Fatalf("ListSessions(100, 0) after concurrent deletion unexpected error: %v", err) + t.Fatalf("Sessions(100, 0) after concurrent deletion unexpected error: %v", err) } for _, session := range sessions { @@ -775,8 +717,6 @@ func TestStore_ConcurrentSessionDeletion(t *testing.T) { _ = store.DeleteSession(ctx, session.ID) } } - - t.Log("Concurrent deletion test completed without crashes") } // TestStore_RaceDetector is a comprehensive test designed to trigger @@ -784,12 +724,11 @@ func TestStore_ConcurrentSessionDeletion(t *testing.T) { // // Run with: go test -race -tags=integration ./internal/session/... func TestStore_RaceDetector(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Create a shared session - session, err := store.CreateSession(ctx, "Race-Detector-Test", "", "") + session, err := store.CreateSession(ctx, "Race-Detector-Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -833,20 +772,18 @@ func TestStore_RaceDetector(t *testing.T) { } wg.Wait() - t.Log("Race detector test completed - if no race detected, Store is thread-safe") } // TestStore_ConcurrentWrites tests concurrent writes to different sessions func TestStore_ConcurrentWrites(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Create multiple sessions numSessions := 5 sessions := make([]*Session, numSessions) for i := 0; i < numSessions; i++ { - session, err := store.CreateSession(ctx, fmt.Sprintf("Concurrent Session %d", i+1), "gemini-2.5-flash", "") + session, err := store.CreateSession(ctx, fmt.Sprintf("Concurrent Session %d", i+1)) if err != nil { t.Fatalf("CreateSession() for session %d unexpected error: %v", i+1, err) } @@ -855,7 +792,7 @@ func TestStore_ConcurrentWrites(t *testing.T) { // Concurrently write messages to different sessions var wg sync.WaitGroup - errors := make(chan error, numSessions*10) + errs := make(chan error, numSessions*10) for i := 0; i < numSessions; i++ { sessionID := sessions[i].ID @@ -872,17 +809,17 @@ func TestStore_ConcurrentWrites(t *testing.T) { } if err := store.AddMessages(ctx, sid, []*Message{message}); err != nil { - errors <- err + errs <- err } } }(sessionID, i) } wg.Wait() - close(errors) + close(errs) // Check for errors - for err := range errors { + for err := range errs { t.Errorf("Concurrent write error: %v", err) } @@ -890,7 +827,7 @@ func TestStore_ConcurrentWrites(t *testing.T) { for i, session := range sessions { messages, err := store.Messages(ctx, session.ID, 100, 0) if err != nil { - t.Fatalf("GetMessages(%v, 100, 0) for session %d unexpected error: %v", session.ID, i+1, err) + t.Fatalf("Messages(%v, 100, 0) for session %d unexpected error: %v", session.ID, i+1, err) } if len(messages) != 10 { t.Errorf("Session %d has %d messages, want 10", i+1, len(messages)) @@ -898,28 +835,23 @@ func TestStore_ConcurrentWrites(t *testing.T) { } } -// ============================================================================= -// SQL Injection Prevention Tests -// ============================================================================= - // TestStore_SQLInjectionPrevention verifies that SQL injection attacks are blocked. // Session store uses sqlc parameterized queries which should prevent all injection. func TestStore_SQLInjectionPrevention(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // First, create a legitimate session - legitSession, err := store.CreateSession(ctx, "Legitimate Session", "", "") + legitSession, err := store.CreateSession(ctx, "Legitimate Session") if err != nil { t.Fatalf("CreateSession() for legitimate session unexpected error: %v", err) } t.Logf("Created legitimate session: %s", legitSession.ID) // Count sessions before attacks - sessions, err := store.ListSessions(ctx, 100, 0) + sessions, err := store.Sessions(ctx, 100, 0) if err != nil { - t.Fatalf("ListSessions(100, 0) before attacks unexpected error: %v", err) + t.Fatalf("Sessions(100, 0) before attacks unexpected error: %v", err) } t.Logf("Initial session count: %d", len(sessions)) @@ -957,7 +889,7 @@ func TestStore_SQLInjectionPrevention(t *testing.T) { for _, tc := range maliciousTitles { t.Run("title_"+tc.name, func(t *testing.T) { // Attempt SQL injection via session title - session, err := store.CreateSession(ctx, tc.title, "", "") + session, err := store.CreateSession(ctx, tc.title) // Should either succeed (with escaped title) or fail safely if err != nil { @@ -971,42 +903,12 @@ func TestStore_SQLInjectionPrevention(t *testing.T) { }) } - // SQL injection via model name - maliciousModels := []string{ - "'; DROP TABLE sessions; --", - "model' UNION SELECT password FROM users --", - } - - for i, model := range maliciousModels { - t.Run("model_"+string(rune('a'+i)), func(t *testing.T) { - session, err := store.CreateSession(ctx, "Test", model, "") - if err == nil { - _ = store.DeleteSession(ctx, session.ID) - } - }) - } - - // SQL injection via system prompt - maliciousPrompts := []string{ - "'; DELETE FROM session_messages; --", - "You are helpful'); DROP TABLE sessions; --", - } - - for i, prompt := range maliciousPrompts { - t.Run("prompt_"+string(rune('a'+i)), func(t *testing.T) { - session, err := store.CreateSession(ctx, "Test", "", prompt) - if err == nil { - _ = store.DeleteSession(ctx, session.ID) - } - }) - } - // Verify database integrity t.Run("verify database integrity", func(t *testing.T) { // Sessions table should still exist - sessions, err := store.ListSessions(ctx, 100, 0) + sessions, err := store.Sessions(ctx, 100, 0) if err != nil { - t.Fatalf("ListSessions(100, 0) after attacks unexpected error: %v (sessions table should still exist)", err) + t.Fatalf("Sessions(100, 0) after attacks unexpected error: %v (sessions table should still exist)", err) } // Legitimate session should still exist @@ -1024,22 +926,21 @@ func TestStore_SQLInjectionPrevention(t *testing.T) { // Should be able to load the session loaded, err := store.Session(ctx, legitSession.ID) if err != nil { - t.Fatalf("GetSession(%v) after attacks unexpected error: %v", legitSession.ID, err) + t.Fatalf("Session(%v) after attacks unexpected error: %v", legitSession.ID, err) } if loaded.Title != "Legitimate Session" { - t.Errorf("GetSession(%v) Title = %q, want %q", legitSession.ID, loaded.Title, "Legitimate Session") + t.Errorf("Session(%v) Title = %q, want %q", legitSession.ID, loaded.Title, "Legitimate Session") } }) } // TestStore_SQLInjectionViaSessionID tests injection through session IDs. func TestStore_SQLInjectionViaSessionID(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Create a test session - session, err := store.CreateSession(ctx, "Test Session", "", "") + session, err := store.CreateSession(ctx, "Test Session") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -1072,82 +973,70 @@ func TestStore_SQLInjectionViaSessionID(t *testing.T) { // Verify the test session still exists loaded, err := store.Session(ctx, session.ID) if err != nil { - t.Fatalf("GetSession(%v) after malicious ID attempts unexpected error: %v", session.ID, err) + t.Fatalf("Session(%v) after malicious ID attempts unexpected error: %v", session.ID, err) } if loaded.Title != "Test Session" { - t.Errorf("GetSession(%v) Title = %q, want %q", session.ID, loaded.Title, "Test Session") + t.Errorf("Session(%v) Title = %q, want %q", session.ID, loaded.Title, "Test Session") } } -// ============================================================================= -// Error Handling Tests -// ============================================================================= - -// TestStore_GetHistory_SessionNotFound verifies that GetHistory returns ErrSessionNotFound -// sentinel error when the session doesn't exist. This test validates the A3 fix from -// Proposal 056 - proper sentinel error propagation without double-wrapping. +// TestStore_GetHistory_SessionNotFound verifies that History returns ErrNotFound +// sentinel error when the session doesn't exist. func TestStore_GetHistory_SessionNotFound(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Use a non-existent session ID nonExistentID := uuid.New() - // GetHistory should return ErrSessionNotFound + // GetHistory should return ErrNotFound _, err := store.History(ctx, nonExistentID) if err == nil { - t.Fatal("GetHistory() with non-existent session should return error") + t.Fatal("History() with non-existent session should return error") } - // Verify the error is the sentinel ErrSessionNotFound (errors.Is check) - if !errors.Is(err, ErrSessionNotFound) { - t.Errorf("GetHistory(%v) error = %v (type: %T), want ErrSessionNotFound sentinel", nonExistentID, err, err) + // Verify the error is the sentinel ErrNotFound (errors.Is check) + if !errors.Is(err, ErrNotFound) { + t.Errorf("History(%v) error = %v (type: %T), want ErrNotFound sentinel", nonExistentID, err, err) } // Verify error message is not double-wrapped errStr := err.Error() if strings.Contains(errStr, "session not found: session not found") { - t.Errorf("GetHistory(%v) error message is double-wrapped: %v", nonExistentID, err) + t.Errorf("History(%v) error message is double-wrapped: %v", nonExistentID, err) } } -// TestStore_GetSession_NotFound verifies GetSession returns ErrSessionNotFound sentinel. +// TestStore_GetSession_NotFound verifies GetSession returns ErrNotFound sentinel. func TestStore_GetSession_NotFound(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() nonExistentID := uuid.New() _, err := store.Session(ctx, nonExistentID) if err == nil { - t.Fatal("GetSession() with non-existent session should return error") + t.Fatal("Session() with non-existent session should return error") } - if !errors.Is(err, ErrSessionNotFound) { - t.Errorf("GetSession(%v) error = %v, want ErrSessionNotFound sentinel", nonExistentID, err) + if !errors.Is(err, ErrNotFound) { + t.Errorf("Session(%v) error = %v, want ErrNotFound sentinel", nonExistentID, err) } } -// ============================================================================= -// SQL Injection Prevention Tests -// ============================================================================= - // TestStore_SQLInjectionViaMessageContent tests injection through message content. func TestStore_SQLInjectionViaMessageContent(t *testing.T) { - store, cleanup := setupIntegrationTest(t) - defer cleanup() + store := setupIntegrationTest(t) ctx := context.Background() // Create a test session - session, err := store.CreateSession(ctx, "Message Test", "", "") + session, err := store.CreateSession(ctx, "Message Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } // Malicious message content maliciousMessages := []string{ - "'; DROP TABLE session_messages; --", + "'; DROP TABLE messages; --", "Hello'); DELETE FROM sessions WHERE '1'='1", "Test' UNION SELECT password FROM users --", "Message\x00'; DROP TABLE sessions; --", @@ -1174,15 +1063,15 @@ func TestStore_SQLInjectionViaMessageContent(t *testing.T) { // Session should still exist _, err := store.Session(ctx, session.ID) if err != nil { - t.Fatalf("GetSession(%v) after malicious message attempts unexpected error: %v (session should still exist)", session.ID, err) + t.Fatalf("Session(%v) after malicious message attempts unexpected error: %v (session should still exist)", session.ID, err) } // Should be able to load messages sessionID := session.ID history, err := store.History(ctx, sessionID) if err != nil { - t.Fatalf("GetHistory(%v) after malicious message attempts unexpected error: %v (should be able to load history)", sessionID, err) + t.Fatalf("History(%v) after malicious message attempts unexpected error: %v (should be able to load history)", sessionID, err) } - t.Logf("loaded history with %d messages", len(history.Messages())) + t.Logf("loaded history with %d messages", len(history)) }) } diff --git a/internal/session/session.go b/internal/session/session.go new file mode 100644 index 0000000..4975c54 --- /dev/null +++ b/internal/session/session.go @@ -0,0 +1,31 @@ +package session + +import ( + "errors" + "time" + + "github.com/firebase/genkit/go/ai" + "github.com/google/uuid" +) + +// ErrNotFound indicates the requested session does not exist in the database. +var ErrNotFound = errors.New("session not found") + +// Session represents a conversation session (application-level type). +type Session struct { + ID uuid.UUID + Title string + CreatedAt time.Time + UpdatedAt time.Time +} + +// Message represents a single conversation message (application-level type). +// Content field stores Genkit's ai.Part slice, serialized as JSONB in database. +type Message struct { + ID uuid.UUID + SessionID uuid.UUID + Role string // "user" | "assistant" | "system" | "tool" + Content []*ai.Part // Genkit Part slice (stored as JSONB) + SequenceNumber int + CreatedAt time.Time +} diff --git a/internal/session/state.go b/internal/session/state.go index 91a1876..db30f7f 100644 --- a/internal/session/state.go +++ b/internal/session/state.go @@ -64,7 +64,7 @@ func getStateFilePath() (string, error) { // LoadCurrentSessionID loads the currently active session ID from local state file. // -// Acquires shared file lock to allow concurrent reads but prevent writes during read. +// Acquires exclusive file lock to prevent concurrent access during read. // // Returns: // - *uuid.UUID: Current session ID (nil if no current session) @@ -146,9 +146,6 @@ func SaveCurrentSessionID(sessionID uuid.UUID) error { return fmt.Errorf("saving session: %w", err) } - // Clean up any orphaned temp files from previous crashed sessions - _ = cleanupOrphanedTempFiles() - // Acquire file lock to prevent concurrent access lock, err := acquireStateLock() if err != nil { @@ -156,6 +153,9 @@ func SaveCurrentSessionID(sessionID uuid.UUID) error { } defer func() { _ = lock.Unlock() }() + // Clean up any orphaned temp files from previous crashed sessions (under lock) + _ = cleanupOrphanedTempFiles() + // Write to temporary file first (atomic write pattern) tmpFile := filePath + ".tmp" if err := os.WriteFile(tmpFile, []byte(sessionID.String()), 0o600); err != nil { diff --git a/internal/session/state_test.go b/internal/session/state_test.go index 6a46db1..8007259 100644 --- a/internal/session/state_test.go +++ b/internal/session/state_test.go @@ -26,19 +26,19 @@ func TestGetStateFilePath(t *testing.T) { // Verify path is absolute if !filepath.IsAbs(path) { - t.Errorf("getStateFilePath() returned relative path: %s", path) + t.Errorf("getStateFilePath() returned relative path: %q", path) } // Verify path uses temp directory rel, err := filepath.Rel(tempDir, path) if err != nil || strings.HasPrefix(rel, "..") { - t.Errorf("getStateFilePath() = %s, expected to be within %s", path, tempDir) + t.Errorf("getStateFilePath() = %q, want within %q", path, tempDir) } // Verify directory was created dir := filepath.Dir(path) if _, err := os.Stat(dir); os.IsNotExist(err) { - t.Errorf("getStateFilePath() did not create directory: %s", dir) + t.Errorf("getStateFilePath() did not create directory: %q", dir) } } @@ -64,7 +64,6 @@ func TestSaveAndLoadCurrentSessionID(t *testing.T) { if loadedID == nil { t.Fatal("LoadCurrentSessionID() returned nil") - return } if *loadedID != testID { @@ -110,7 +109,6 @@ func TestSaveAndLoadCurrentSessionID(t *testing.T) { if loadedID == nil { t.Fatal("LoadCurrentSessionID() returned nil") - return } if *loadedID != secondID { diff --git a/internal/session/store.go b/internal/session/store.go index a48b28b..edb13e2 100644 --- a/internal/session/store.go +++ b/internal/session/store.go @@ -12,6 +12,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" + "github.com/koopa0/koopa/internal/config" "github.com/koopa0/koopa/internal/sqlc" ) @@ -52,32 +53,19 @@ func New(queries *sqlc.Queries, pool *pgxpool.Pool, logger *slog.Logger) *Store // Parameters: // - ctx: Context for the operation // - title: Session title (empty string = no title) -// - modelName: Model name used for this session (empty string = default) -// - systemPrompt: System prompt for this session (empty string = default) // // Returns: // - *Session: Created session with generated UUID // - error: If creation fails -func (s *Store) CreateSession(ctx context.Context, title, modelName, systemPrompt string) (*Session, error) { - // Convert empty strings to nil for nullable fields - var titlePtr, modelNamePtr, systemPromptPtr *string +func (s *Store) CreateSession(ctx context.Context, title string) (*Session, error) { + var titlePtr *string if title != "" { titlePtr = &title } - if modelName != "" { - modelNamePtr = &modelName - } - if systemPrompt != "" { - systemPromptPtr = &systemPrompt - } - sqlcSession, err := s.queries.CreateSession(ctx, sqlc.CreateSessionParams{ - Title: titlePtr, - ModelName: modelNamePtr, - SystemPrompt: systemPromptPtr, - }) + sqlcSession, err := s.queries.CreateSession(ctx, titlePtr) if err != nil { - return nil, fmt.Errorf("failed to create session: %w", err) + return nil, fmt.Errorf("creating session: %w", err) } session := s.sqlcSessionToSession(sqlcSession) @@ -86,21 +74,21 @@ func (s *Store) CreateSession(ctx context.Context, title, modelName, systemPromp } // Session retrieves a session by ID. -// Returns ErrSessionNotFound if the session does not exist. +// Returns ErrNotFound if the session does not exist. func (s *Store) Session(ctx context.Context, sessionID uuid.UUID) (*Session, error) { sqlcSession, err := s.queries.Session(ctx, sessionID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { // Return sentinel error directly (no wrapping per reviewer guidance) - return nil, ErrSessionNotFound + return nil, ErrNotFound } - return nil, fmt.Errorf("failed to get session %s: %w", sessionID, err) + return nil, fmt.Errorf("getting session %s: %w", sessionID, err) } return s.sqlcSessionToSession(sqlcSession), nil } -// ListSessions lists sessions with pagination, ordered by updated_at descending. +// Sessions lists sessions with pagination, ordered by updated_at descending. // // Parameters: // - ctx: Context for the operation @@ -110,13 +98,13 @@ func (s *Store) Session(ctx context.Context, sessionID uuid.UUID) (*Session, err // Returns: // - []*Session: List of sessions // - error: If listing fails -func (s *Store) ListSessions(ctx context.Context, limit, offset int32) ([]*Session, error) { - sqlcSessions, err := s.queries.ListSessions(ctx, sqlc.ListSessionsParams{ +func (s *Store) Sessions(ctx context.Context, limit, offset int32) ([]*Session, error) { + sqlcSessions, err := s.queries.Sessions(ctx, sqlc.SessionsParams{ ResultLimit: limit, ResultOffset: offset, }) if err != nil { - return nil, fmt.Errorf("failed to list sessions: %w", err) + return nil, fmt.Errorf("listing sessions: %w", err) } sessions := make([]*Session, 0, len(sqlcSessions)) @@ -128,35 +116,6 @@ func (s *Store) ListSessions(ctx context.Context, limit, offset int32) ([]*Sessi return sessions, nil } -// ListSessionsWithMessages lists sessions that have messages or titles. -// This is used for sidebar display where empty "New Chat" placeholder sessions should be hidden. -// -// Parameters: -// - ctx: Context for the operation -// - limit: Maximum number of sessions to return -// - offset: Number of sessions to skip (for pagination) -// -// Returns: -// - []*Session: List of sessions with messages or titles -// - error: If listing fails -func (s *Store) ListSessionsWithMessages(ctx context.Context, limit, offset int32) ([]*Session, error) { - sqlcSessions, err := s.queries.ListSessionsWithMessages(ctx, sqlc.ListSessionsWithMessagesParams{ - ResultLimit: limit, - ResultOffset: offset, - }) - if err != nil { - return nil, fmt.Errorf("failed to list sessions with messages: %w", err) - } - - sessions := make([]*Session, 0, len(sqlcSessions)) - for i := range sqlcSessions { - sessions = append(sessions, s.sqlcSessionToSession(sqlcSessions[i])) - } - - s.logger.Debug("listed sessions with messages", "count", len(sessions), "limit", limit, "offset", offset) - return sessions, nil -} - // DeleteSession deletes a session and all its messages (CASCADE). // // Parameters: @@ -167,7 +126,7 @@ func (s *Store) ListSessionsWithMessages(ctx context.Context, limit, offset int3 // - error: If deletion fails func (s *Store) DeleteSession(ctx context.Context, sessionID uuid.UUID) error { if err := s.queries.DeleteSession(ctx, sessionID); err != nil { - return fmt.Errorf("failed to delete session %s: %w", sessionID, err) + return fmt.Errorf("deleting session %s: %w", sessionID, err) } s.logger.Debug("deleted session", "id", sessionID) @@ -194,7 +153,7 @@ func (s *Store) UpdateSessionTitle(ctx context.Context, sessionID uuid.UUID, tit SessionID: sessionID, Title: titlePtr, }); err != nil { - return fmt.Errorf("failed to update session title %s: %w", sessionID, err) + return fmt.Errorf("updating session title %s: %w", sessionID, err) } s.logger.Debug("updated session title", "id", sessionID, "title", title) @@ -224,13 +183,13 @@ func (s *Store) AddMessages(ctx context.Context, sessionID uuid.UUID, messages [ // Database pool is required for transactional operations // Tests should use pgxmock or Testcontainers for proper transaction testing if s.pool == nil { - return fmt.Errorf("database pool required for AddMessages: use pgxmock or real database for testing") + return fmt.Errorf("database pool is required for transactional operations") } // Begin transaction for atomicity tx, err := s.pool.Begin(ctx) if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) + return fmt.Errorf("beginning transaction: %w", err) } // Rollback if not committed - log any rollback errors for debugging defer func() { @@ -246,16 +205,13 @@ func (s *Store) AddMessages(ctx context.Context, sessionID uuid.UUID, messages [ // This SELECT ... FOR UPDATE ensures that only one transaction can modify // this session at a time, preventing race conditions on sequence numbers if _, err = txQuerier.LockSession(ctx, sessionID); err != nil { - return fmt.Errorf("failed to lock session: %w", err) + return fmt.Errorf("locking session: %w", err) } // 1. Get current max sequence number within transaction - maxSeq, err := txQuerier.GetMaxSequenceNumber(ctx, sessionID) + maxSeq, err := txQuerier.MaxSequenceNumber(ctx, sessionID) if err != nil { - // If session doesn't exist yet or no messages, start from 0 - s.logger.Debug("no existing messages, starting from sequence 0", - "session_id", sessionID) - maxSeq = 0 + return fmt.Errorf("getting max sequence number: %w", err) } // 2. Insert messages in batch within transaction @@ -271,7 +227,7 @@ func (s *Store) AddMessages(ctx context.Context, sessionID uuid.UUID, messages [ contentJSON, marshalErr := json.Marshal(msg.Content) if marshalErr != nil { // Transaction will be rolled back by defer - return fmt.Errorf("failed to marshal message content at index %d: %w", i, marshalErr) + return fmt.Errorf("marshaling message content at index %d: %w", i, marshalErr) } // Calculate sequence number (maxSeq is now int32 from sqlc) @@ -285,24 +241,19 @@ func (s *Store) AddMessages(ctx context.Context, sessionID uuid.UUID, messages [ SequenceNumber: seqNum, }); err != nil { // Transaction will be rolled back by defer - return fmt.Errorf("failed to insert message %d: %w", i, err) + return fmt.Errorf("inserting message %d: %w", i, err) } } - // 3. Update session's updated_at and message_count within transaction - // Safe conversion: len(messages) is bounded by practical limits (< millions) - newCount := maxSeq + int32(len(messages)) // #nosec G115 -- len bounded by practical message limits - if err = txQuerier.UpdateSessionUpdatedAt(ctx, sqlc.UpdateSessionUpdatedAtParams{ - MessageCount: &newCount, - SessionID: sessionID, - }); err != nil { + // 3. Update session's updated_at within transaction + if err = txQuerier.UpdateSessionUpdatedAt(ctx, sessionID); err != nil { // Transaction will be rolled back by defer - return fmt.Errorf("failed to update session metadata: %w", err) + return fmt.Errorf("updating session metadata: %w", err) } // 4. Commit transaction if err := tx.Commit(ctx); err != nil { - return fmt.Errorf("failed to commit transaction: %w", err) + return fmt.Errorf("committing transaction: %w", err) } s.logger.Debug("added messages", "session_id", sessionID, "count", len(messages)) @@ -327,14 +278,14 @@ func (s *Store) Messages(ctx context.Context, sessionID uuid.UUID, limit, offset ResultOffset: offset, }) if err != nil { - return nil, fmt.Errorf("failed to get messages for session %s: %w", sessionID, err) + return nil, fmt.Errorf("getting messages for session %s: %w", sessionID, err) } messages := make([]*Message, 0, len(sqlcMessages)) for i := range sqlcMessages { msg, err := s.sqlcMessageToMessage(sqlcMessages[i]) if err != nil { - s.logger.Warn("failed to unmarshal message content", + s.logger.Warn("skipping malformed message", "message_id", sqlcMessages[i].ID, "error", err) continue // Skip malformed messages @@ -382,7 +333,7 @@ func (s *Store) AppendMessages(ctx context.Context, sessionID uuid.UUID, message // Use AddMessages if err := s.AddMessages(ctx, sessionID, sessionMessages); err != nil { - return fmt.Errorf("failed to append messages: %w", err) + return err // AddMessages already wraps with context } s.logger.Debug("appended messages", @@ -391,32 +342,21 @@ func (s *Store) AppendMessages(ctx context.Context, sessionID uuid.UUID, message return nil } -// History retrieves the conversation history for a session. -// Used by chat.Chat agent for session management. -// -// Parameters: -// - ctx: Context for the operation -// - sessionID: Session UUID -// -// Returns: -// - *History: Conversation history -// - error: If retrieval fails -func (s *Store) History(ctx context.Context, sessionID uuid.UUID) (*History, error) { +// History retrieves the conversation history for a session as a slice of ai.Message. +// Used by chat.Agent for session management. +func (s *Store) History(ctx context.Context, sessionID uuid.UUID) ([]*ai.Message, error) { // Verify session exists before loading history if _, err := s.Session(ctx, sessionID); err != nil { - if errors.Is(err, ErrSessionNotFound) { + if errors.Is(err, ErrNotFound) { return nil, err // Sentinel propagates unchanged } - return nil, fmt.Errorf("get history for session %s: %w", sessionID, err) + return nil, fmt.Errorf("getting history for session %s: %w", sessionID, err) } - // Use default limit for history retrieval - limit := DefaultHistoryLimit - // Retrieve messages - messages, err := s.Messages(ctx, sessionID, limit, 0) + messages, err := s.Messages(ctx, sessionID, config.DefaultMaxHistoryMessages, 0) if err != nil { - return nil, fmt.Errorf("failed to load history: %w", err) + return nil, fmt.Errorf("loading history: %w", err) } // Convert to ai.Message @@ -430,11 +370,42 @@ func (s *Store) History(ctx context.Context, sessionID uuid.UUID) (*History, err s.logger.Debug("loaded history", "session_id", sessionID, - "message_count", len(messages)) + "count", len(messages)) - history := NewHistory() - history.SetMessages(aiMessages) - return history, nil + return aiMessages, nil +} + +// ResolveCurrentSession loads the active session from the state file, +// validates it exists in the database, and creates a new session if needed. +// Returns the session ID. +func (s *Store) ResolveCurrentSession(ctx context.Context) (uuid.UUID, error) { + //nolint:contextcheck // LoadCurrentSessionID manages its own lock timeout context + savedID, err := LoadCurrentSessionID() + if err != nil { + return uuid.Nil, fmt.Errorf("loading current session: %w", err) + } + + if savedID != nil { + if _, err = s.Session(ctx, *savedID); err == nil { + return *savedID, nil + } + if !errors.Is(err, ErrNotFound) { + return uuid.Nil, fmt.Errorf("validating session: %w", err) + } + } + + newSess, err := s.CreateSession(ctx, "") + if err != nil { + return uuid.Nil, fmt.Errorf("creating session: %w", err) + } + + // best-effort: state file is non-critical, session already created in DB + //nolint:contextcheck // SaveCurrentSessionID manages its own lock timeout context + if saveErr := SaveCurrentSessionID(newSess.ID); saveErr != nil { + s.logger.Warn("saving session state", "error", saveErr) + } + + return newSess.ID, nil } // sqlcSessionToSession converts sqlc.Session to Session (application type). @@ -448,15 +419,6 @@ func (*Store) sqlcSessionToSession(ss sqlc.Session) *Session { if ss.Title != nil { session.Title = *ss.Title } - if ss.ModelName != nil { - session.ModelName = *ss.ModelName - } - if ss.SystemPrompt != nil { - session.SystemPrompt = *ss.SystemPrompt - } - if ss.MessageCount != nil { - session.MessageCount = int(*ss.MessageCount) - } return session } @@ -466,7 +428,7 @@ func (*Store) sqlcMessageToMessage(sm sqlc.Message) (*Message, error) { // Unmarshal JSONB content to ai.Part slice var content []*ai.Part if err := json.Unmarshal(sm.Content, &content); err != nil { - return nil, fmt.Errorf("failed to unmarshal content: %w", err) + return nil, fmt.Errorf("unmarshaling content: %w", err) } return &Message{ @@ -474,169 +436,7 @@ func (*Store) sqlcMessageToMessage(sm sqlc.Message) (*Message, error) { SessionID: sm.SessionID, Role: sm.Role, Content: content, - Status: sm.Status, SequenceNumber: int(sm.SequenceNumber), CreatedAt: sm.CreatedAt.Time, }, nil } - -// ============================================================================= -// Streaming Message Operations (for SSE chat flow) -// ============================================================================= - -// MessagePair represents a user-assistant message pair created for streaming. -// Used to track both messages atomically for SSE-based chat. -type MessagePair struct { - UserMsgID uuid.UUID - AssistantMsgID uuid.UUID - UserSeq int32 - AssistantSeq int32 -} - -// CreateMessagePair atomically creates a user message and empty assistant placeholder. -// The user message is marked as "completed", the assistant message as "streaming". -// This is used at the start of a chat turn before SSE streaming begins. -// -// Parameters: -// - ctx: Context for the operation -// - sessionID: UUID of the session -// - userContent: User message content as ai.Part slice -// - assistantID: Pre-generated UUID for the assistant message (used in SSE URL) -// -// Returns: -// - *MessagePair: Contains both message IDs and sequence numbers -// - error: If creation fails or session doesn't exist -func (s *Store) CreateMessagePair( - ctx context.Context, - sessionID uuid.UUID, - userContent []*ai.Part, - assistantID uuid.UUID, -) (*MessagePair, error) { - // Database pool is required for transactional operations - if s.pool == nil { - return nil, fmt.Errorf("database pool required for CreateMessagePair") - } - - // Begin transaction for atomicity - tx, err := s.pool.Begin(ctx) - if err != nil { - return nil, fmt.Errorf("begin transaction: %w", err) - } - defer func() { - if rollbackErr := tx.Rollback(ctx); rollbackErr != nil { - s.logger.Debug("transaction rollback (may be already committed)", "error", rollbackErr) - } - }() - - txQuerier := sqlc.New(tx) - - // Lock session row to prevent concurrent modifications - _, err = txQuerier.LockSession(ctx, sessionID) - if err != nil { - return nil, fmt.Errorf("lock session: %w", err) - } - - // Get current max sequence number - maxSeq, err := txQuerier.GetMaxSequenceNumber(ctx, sessionID) - if err != nil { - s.logger.Debug("no existing messages, starting from sequence 0", - "session_id", sessionID) - maxSeq = 0 - } - - // Marshal user content to JSON - userContentJSON, err := json.Marshal(userContent) - if err != nil { - return nil, fmt.Errorf("marshal user content: %w", err) - } - - // Generate user message ID - userMsgID := uuid.New() - userSeq := maxSeq + 1 - assistantSeq := maxSeq + 2 - - // Insert user message (status = completed) - _, err = txQuerier.AddMessageWithID(ctx, sqlc.AddMessageWithIDParams{ - ID: userMsgID, - SessionID: sessionID, - Role: "user", - Content: userContentJSON, - Status: StatusCompleted, - SequenceNumber: userSeq, - }) - if err != nil { - return nil, fmt.Errorf("insert user message: %w", err) - } - - // Insert empty assistant message placeholder (status = streaming) - emptyContent := []byte("[]") // Empty ai.Part slice - _, err = txQuerier.AddMessageWithID(ctx, sqlc.AddMessageWithIDParams{ - ID: assistantID, - SessionID: sessionID, - Role: RoleAssistant, - Content: emptyContent, - Status: StatusStreaming, - SequenceNumber: assistantSeq, - }) - if err != nil { - return nil, fmt.Errorf("insert assistant message: %w", err) - } - - // Update session metadata - if err = txQuerier.UpdateSessionUpdatedAt(ctx, sqlc.UpdateSessionUpdatedAtParams{ - MessageCount: &assistantSeq, - SessionID: sessionID, - }); err != nil { - return nil, fmt.Errorf("update session metadata: %w", err) - } - - // Commit transaction - if err := tx.Commit(ctx); err != nil { - return nil, fmt.Errorf("commit transaction: %w", err) - } - - s.logger.Debug("created message pair", - "session_id", sessionID, - "user_msg_id", userMsgID, - "assistant_msg_id", assistantID) - - return &MessagePair{ - UserMsgID: userMsgID, - AssistantMsgID: assistantID, - UserSeq: userSeq, - AssistantSeq: assistantSeq, - }, nil -} - -// UpdateMessageContent updates the content of a message and marks it as completed. -// Used after streaming is finished to save the final AI response. -func (s *Store) UpdateMessageContent(ctx context.Context, msgID uuid.UUID, content []*ai.Part) error { - contentJSON, err := json.Marshal(content) - if err != nil { - return fmt.Errorf("marshal content: %w", err) - } - - if err := s.queries.UpdateMessageContent(ctx, sqlc.UpdateMessageContentParams{ - ID: msgID, - Content: contentJSON, - }); err != nil { - return fmt.Errorf("update message content: %w", err) - } - - s.logger.Debug("updated message content", "msg_id", msgID) - return nil -} - -// UpdateMessageStatus updates the status of a message. -// Used to mark streaming messages as failed if an error occurs. -func (s *Store) UpdateMessageStatus(ctx context.Context, msgID uuid.UUID, status string) error { - if err := s.queries.UpdateMessageStatus(ctx, sqlc.UpdateMessageStatusParams{ - ID: msgID, - Status: status, - }); err != nil { - return fmt.Errorf("update message status: %w", err) - } - - s.logger.Debug("updated message status", "msg_id", msgID, "status", status) - return nil -} diff --git a/internal/session/store_test.go b/internal/session/store_test.go new file mode 100644 index 0000000..a1cafda --- /dev/null +++ b/internal/session/store_test.go @@ -0,0 +1,34 @@ +package session + +import "testing" + +// TestNormalizeRole tests the Genkit role normalization function. +// Genkit uses "model" for AI responses, but we store "assistant" in the database +// for consistency with the CHECK constraint. +func TestNormalizeRole(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want string + }{ + {name: "model to assistant", input: "model", want: "assistant"}, + {name: "user unchanged", input: "user", want: "user"}, + {name: "assistant unchanged", input: "assistant", want: "assistant"}, + {name: "system unchanged", input: "system", want: "system"}, + {name: "tool unchanged", input: "tool", want: "tool"}, + {name: "empty passthrough", input: "", want: ""}, + {name: "unknown passthrough", input: "unknown", want: "unknown"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := normalizeRole(tt.input) + if got != tt.want { + t.Errorf("normalizeRole(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/session/types.go b/internal/session/types.go deleted file mode 100644 index cbd9c6f..0000000 --- a/internal/session/types.go +++ /dev/null @@ -1,112 +0,0 @@ -// Package session provides session persistence functionality for conversation history. -// -// Responsibilities: Save/load conversation sessions to PostgreSQL database. -// Thread Safety: Not thread-safe - caller must synchronize access. -package session - -import ( - "sync" - "time" - - "github.com/firebase/genkit/go/ai" - "github.com/google/uuid" -) - -// History encapsulates conversation history with thread-safe access. -// -// Note: The zero value is NOT useful - use NewHistory() to create instances. -type History struct { - mu sync.RWMutex - messages []*ai.Message -} - -// NewHistory creates a new History instance. -func NewHistory() *History { - return &History{ - messages: make([]*ai.Message, 0), - } -} - -// SetMessages replaces all messages in the history. -// Used by SessionStore when loading history from the database. -// Makes a defensive copy to prevent external modification. -func (h *History) SetMessages(messages []*ai.Message) { - h.mu.Lock() - defer h.mu.Unlock() - h.messages = make([]*ai.Message, len(messages)) - copy(h.messages, messages) -} - -// Messages returns a copy of all messages for thread-safe access -func (h *History) Messages() []*ai.Message { - h.mu.RLock() - defer h.mu.RUnlock() - result := make([]*ai.Message, len(h.messages)) - copy(result, h.messages) - return result -} - -// Add appends user message and assistant response -func (h *History) Add(userInput, assistantResponse string) { - h.mu.Lock() - defer h.mu.Unlock() - h.messages = append(h.messages, - ai.NewUserMessage(ai.NewTextPart(userInput)), - ai.NewModelMessage(ai.NewTextPart(assistantResponse)), - ) -} - -// AddMessage appends a single message -// Returns without effect if msg is nil -func (h *History) AddMessage(msg *ai.Message) { - if msg == nil { - return - } - h.mu.Lock() - defer h.mu.Unlock() - h.messages = append(h.messages, msg) -} - -// Count returns the number of messages -func (h *History) Count() int { - h.mu.RLock() - defer h.mu.RUnlock() - return len(h.messages) -} - -// Clear removes all messages -func (h *History) Clear() { - h.mu.Lock() - defer h.mu.Unlock() - h.messages = make([]*ai.Message, 0) -} - -// Role constants define valid message roles for type safety. -const ( - RoleUser = "user" - RoleAssistant = "assistant" - RoleTool = "tool" -) - -// Session represents a conversation session (application-level type). -type Session struct { - ID uuid.UUID - Title string - CreatedAt time.Time - UpdatedAt time.Time - ModelName string - SystemPrompt string - MessageCount int -} - -// Message represents a single conversation message (application-level type). -// Content field stores Genkit's ai.Part slice, serialized as JSONB in database. -type Message struct { - ID uuid.UUID - SessionID uuid.UUID - Role string // "user" | "assistant" | "tool" - Content []*ai.Part // Genkit Part slice (stored as JSONB) - Status string // Message status: streaming/completed/failed - SequenceNumber int - CreatedAt time.Time -} diff --git a/internal/sqlc/documents.sql.go b/internal/sqlc/documents.sql.go deleted file mode 100644 index 48ac5f5..0000000 --- a/internal/sqlc/documents.sql.go +++ /dev/null @@ -1,280 +0,0 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.30.0 -// source: documents.sql - -package sqlc - -import ( - "context" - - "github.com/pgvector/pgvector-go" -) - -const countDocuments = `-- name: CountDocuments :one -SELECT COUNT(*) -FROM documents -WHERE metadata @> $1::jsonb -` - -func (q *Queries) CountDocuments(ctx context.Context, dollar_1 []byte) (int64, error) { - row := q.db.QueryRow(ctx, countDocuments, dollar_1) - var count int64 - err := row.Scan(&count) - return count, err -} - -const countDocumentsAll = `-- name: CountDocumentsAll :one -SELECT COUNT(*) -FROM documents -` - -func (q *Queries) CountDocumentsAll(ctx context.Context) (int64, error) { - row := q.db.QueryRow(ctx, countDocumentsAll) - var count int64 - err := row.Scan(&count) - return count, err -} - -const deleteDocument = `-- name: DeleteDocument :exec -DELETE FROM documents -WHERE id = $1 -` - -func (q *Queries) DeleteDocument(ctx context.Context, id string) error { - _, err := q.db.Exec(ctx, deleteDocument, id) - return err -} - -const getDocument = `-- name: GetDocument :one -SELECT id, content, metadata -FROM documents -WHERE id = $1 -` - -type GetDocumentRow struct { - ID string `json:"id"` - Content string `json:"content"` - Metadata []byte `json:"metadata"` -} - -func (q *Queries) GetDocument(ctx context.Context, id string) (GetDocumentRow, error) { - row := q.db.QueryRow(ctx, getDocument, id) - var i GetDocumentRow - err := row.Scan(&i.ID, &i.Content, &i.Metadata) - return i, err -} - -const listDocumentsBySourceType = `-- name: ListDocumentsBySourceType :many -SELECT id, content, metadata -FROM documents -WHERE source_type = $1::text -LIMIT $2 -` - -type ListDocumentsBySourceTypeParams struct { - SourceType string `json:"source_type"` - ResultLimit int32 `json:"result_limit"` -} - -type ListDocumentsBySourceTypeRow struct { - ID string `json:"id"` - Content string `json:"content"` - Metadata []byte `json:"metadata"` -} - -// List all documents by source_type using dedicated indexed column -// Used for listing indexed files without needing embeddings -func (q *Queries) ListDocumentsBySourceType(ctx context.Context, arg ListDocumentsBySourceTypeParams) ([]ListDocumentsBySourceTypeRow, error) { - rows, err := q.db.Query(ctx, listDocumentsBySourceType, arg.SourceType, arg.ResultLimit) - if err != nil { - return nil, err - } - defer rows.Close() - items := []ListDocumentsBySourceTypeRow{} - for rows.Next() { - var i ListDocumentsBySourceTypeRow - if err := rows.Scan(&i.ID, &i.Content, &i.Metadata); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const searchBySourceType = `-- name: SearchBySourceType :many - -SELECT id, content, metadata, - (1 - (embedding <=> $1::vector))::float8 AS similarity -FROM documents -WHERE source_type = $2::text -ORDER BY similarity DESC -LIMIT $3 -` - -type SearchBySourceTypeParams struct { - QueryEmbedding *pgvector.Vector `json:"query_embedding"` - SourceType string `json:"source_type"` - ResultLimit int32 `json:"result_limit"` -} - -type SearchBySourceTypeRow struct { - ID string `json:"id"` - Content string `json:"content"` - Metadata []byte `json:"metadata"` - Similarity float64 `json:"similarity"` -} - -// ===== Optimized RAG Queries (SQL-level filtering) ===== -// Generic search by source_type using dedicated indexed column -func (q *Queries) SearchBySourceType(ctx context.Context, arg SearchBySourceTypeParams) ([]SearchBySourceTypeRow, error) { - rows, err := q.db.Query(ctx, searchBySourceType, arg.QueryEmbedding, arg.SourceType, arg.ResultLimit) - if err != nil { - return nil, err - } - defer rows.Close() - items := []SearchBySourceTypeRow{} - for rows.Next() { - var i SearchBySourceTypeRow - if err := rows.Scan( - &i.ID, - &i.Content, - &i.Metadata, - &i.Similarity, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const searchDocuments = `-- name: SearchDocuments :many -SELECT id, content, metadata, - (1 - (embedding <=> $1::vector))::float8 AS similarity -FROM documents -WHERE metadata @> $2::jsonb -ORDER BY similarity DESC -LIMIT $3 -` - -type SearchDocumentsParams struct { - QueryEmbedding *pgvector.Vector `json:"query_embedding"` - FilterMetadata []byte `json:"filter_metadata"` - ResultLimit int32 `json:"result_limit"` -} - -type SearchDocumentsRow struct { - ID string `json:"id"` - Content string `json:"content"` - Metadata []byte `json:"metadata"` - Similarity float64 `json:"similarity"` -} - -func (q *Queries) SearchDocuments(ctx context.Context, arg SearchDocumentsParams) ([]SearchDocumentsRow, error) { - rows, err := q.db.Query(ctx, searchDocuments, arg.QueryEmbedding, arg.FilterMetadata, arg.ResultLimit) - if err != nil { - return nil, err - } - defer rows.Close() - items := []SearchDocumentsRow{} - for rows.Next() { - var i SearchDocumentsRow - if err := rows.Scan( - &i.ID, - &i.Content, - &i.Metadata, - &i.Similarity, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const searchDocumentsAll = `-- name: SearchDocumentsAll :many -SELECT id, content, metadata, - (1 - (embedding <=> $1::vector))::float8 AS similarity -FROM documents -ORDER BY similarity DESC -LIMIT $2 -` - -type SearchDocumentsAllParams struct { - QueryEmbedding *pgvector.Vector `json:"query_embedding"` - ResultLimit int32 `json:"result_limit"` -} - -type SearchDocumentsAllRow struct { - ID string `json:"id"` - Content string `json:"content"` - Metadata []byte `json:"metadata"` - Similarity float64 `json:"similarity"` -} - -func (q *Queries) SearchDocumentsAll(ctx context.Context, arg SearchDocumentsAllParams) ([]SearchDocumentsAllRow, error) { - rows, err := q.db.Query(ctx, searchDocumentsAll, arg.QueryEmbedding, arg.ResultLimit) - if err != nil { - return nil, err - } - defer rows.Close() - items := []SearchDocumentsAllRow{} - for rows.Next() { - var i SearchDocumentsAllRow - if err := rows.Scan( - &i.ID, - &i.Content, - &i.Metadata, - &i.Similarity, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const upsertDocument = `-- name: UpsertDocument :exec - -INSERT INTO documents (id, content, embedding, source_type, metadata) -VALUES ($1, $2, $3, $4, $5) -ON CONFLICT (id) DO UPDATE SET - content = EXCLUDED.content, - embedding = EXCLUDED.embedding, - source_type = EXCLUDED.source_type, - metadata = EXCLUDED.metadata -` - -type UpsertDocumentParams struct { - ID string `json:"id"` - Content string `json:"content"` - Embedding *pgvector.Vector `json:"embedding"` - SourceType *string `json:"source_type"` - Metadata []byte `json:"metadata"` -} - -// Documents queries for sqlc -// Generated code will be in internal/sqlc/documents.sql.go -func (q *Queries) UpsertDocument(ctx context.Context, arg UpsertDocumentParams) error { - _, err := q.db.Exec(ctx, upsertDocument, - arg.ID, - arg.Content, - arg.Embedding, - arg.SourceType, - arg.Metadata, - ) - return err -} diff --git a/internal/sqlc/models.go b/internal/sqlc/models.go index 90afee5..4ec4daf 100644 --- a/internal/sqlc/models.go +++ b/internal/sqlc/models.go @@ -25,16 +25,11 @@ type Message struct { Content []byte `json:"content"` SequenceNumber int32 `json:"sequence_number"` CreatedAt pgtype.Timestamptz `json:"created_at"` - Status string `json:"status"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` } type Session struct { - ID uuid.UUID `json:"id"` - Title *string `json:"title"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` - ModelName *string `json:"model_name"` - SystemPrompt *string `json:"system_prompt"` - MessageCount *int32 `json:"message_count"` + ID uuid.UUID `json:"id"` + Title *string `json:"title"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` } diff --git a/internal/sqlc/sessions.sql.go b/internal/sqlc/sessions.sql.go index ade5590..df5e238 100644 --- a/internal/sqlc/sessions.sql.go +++ b/internal/sqlc/sessions.sql.go @@ -12,7 +12,7 @@ import ( ) const addMessage = `-- name: AddMessage :exec -INSERT INTO message (session_id, role, content, sequence_number) +INSERT INTO messages (session_id, role, content, sequence_number) VALUES ($1, $2, $3, $4) ` @@ -34,100 +34,27 @@ func (q *Queries) AddMessage(ctx context.Context, arg AddMessageParams) error { return err } -const addMessageWithID = `-- name: AddMessageWithID :one -INSERT INTO message (id, session_id, role, content, status, sequence_number) -VALUES ($1, $2, $3, $4, $5, $6) -RETURNING id, session_id, role, content, sequence_number, created_at, status, updated_at -` - -type AddMessageWithIDParams struct { - ID uuid.UUID `json:"id"` - SessionID uuid.UUID `json:"session_id"` - Role string `json:"role"` - Content []byte `json:"content"` - Status string `json:"status"` - SequenceNumber int32 `json:"sequence_number"` -} - -// Add message with pre-assigned ID and status (for streaming) -func (q *Queries) AddMessageWithID(ctx context.Context, arg AddMessageWithIDParams) (Message, error) { - row := q.db.QueryRow(ctx, addMessageWithID, - arg.ID, - arg.SessionID, - arg.Role, - arg.Content, - arg.Status, - arg.SequenceNumber, - ) - var i Message - err := row.Scan( - &i.ID, - &i.SessionID, - &i.Role, - &i.Content, - &i.SequenceNumber, - &i.CreatedAt, - &i.Status, - &i.UpdatedAt, - ) - return i, err -} - -const countMessages = `-- name: CountMessages :one -SELECT COUNT(*)::integer AS count -FROM message -WHERE session_id = $1 -` - -// Count messages in a session -func (q *Queries) CountMessages(ctx context.Context, sessionID uuid.UUID) (int32, error) { - row := q.db.QueryRow(ctx, countMessages, sessionID) - var count int32 - err := row.Scan(&count) - return count, err -} - const createSession = `-- name: CreateSession :one -INSERT INTO sessions (title, model_name, system_prompt) -VALUES ($1, $2, $3) -RETURNING id, title, created_at, updated_at, model_name, system_prompt, message_count +INSERT INTO sessions (title) +VALUES ($1) +RETURNING id, title, created_at, updated_at ` -type CreateSessionParams struct { - Title *string `json:"title"` - ModelName *string `json:"model_name"` - SystemPrompt *string `json:"system_prompt"` -} - -// Sessions queries for sqlc +// Sessions and messages queries for sqlc // Generated code will be in internal/sqlc/sessions.sql.go -func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) { - row := q.db.QueryRow(ctx, createSession, arg.Title, arg.ModelName, arg.SystemPrompt) +func (q *Queries) CreateSession(ctx context.Context, title *string) (Session, error) { + row := q.db.QueryRow(ctx, createSession, title) var i Session err := row.Scan( &i.ID, &i.Title, &i.CreatedAt, &i.UpdatedAt, - &i.ModelName, - &i.SystemPrompt, - &i.MessageCount, ) return i, err } -const deleteMessages = `-- name: DeleteMessages :exec -DELETE FROM message -WHERE session_id = $1 -` - -// Delete all messages in a session -func (q *Queries) DeleteMessages(ctx context.Context, sessionID uuid.UUID) error { - _, err := q.db.Exec(ctx, deleteMessages, sessionID) - return err -} - const deleteSession = `-- name: DeleteSession :exec DELETE FROM sessions WHERE id = $1 @@ -138,105 +65,20 @@ func (q *Queries) DeleteSession(ctx context.Context, id uuid.UUID) error { return err } -const getMaxSequenceNumber = `-- name: GetMaxSequenceNumber :one +const maxSequenceNumber = `-- name: MaxSequenceNumber :one SELECT COALESCE(MAX(sequence_number), 0)::integer AS max_seq -FROM message +FROM messages WHERE session_id = $1 ` -// Get max sequence number for a session -func (q *Queries) GetMaxSequenceNumber(ctx context.Context, sessionID uuid.UUID) (int32, error) { - row := q.db.QueryRow(ctx, getMaxSequenceNumber, sessionID) +// MaxSequenceNumber returns the max sequence number for a session (returns 0 if no messages). +func (q *Queries) MaxSequenceNumber(ctx context.Context, sessionID uuid.UUID) (int32, error) { + row := q.db.QueryRow(ctx, maxSequenceNumber, sessionID) var max_seq int32 err := row.Scan(&max_seq) return max_seq, err } -const listSessions = `-- name: ListSessions :many -SELECT id, title, created_at, updated_at, model_name, system_prompt, message_count -FROM sessions -ORDER BY updated_at DESC -LIMIT $2 -OFFSET $1 -` - -type ListSessionsParams struct { - ResultOffset int32 `json:"result_offset"` - ResultLimit int32 `json:"result_limit"` -} - -func (q *Queries) ListSessions(ctx context.Context, arg ListSessionsParams) ([]Session, error) { - rows, err := q.db.Query(ctx, listSessions, arg.ResultOffset, arg.ResultLimit) - if err != nil { - return nil, err - } - defer rows.Close() - items := []Session{} - for rows.Next() { - var i Session - if err := rows.Scan( - &i.ID, - &i.Title, - &i.CreatedAt, - &i.UpdatedAt, - &i.ModelName, - &i.SystemPrompt, - &i.MessageCount, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const listSessionsWithMessages = `-- name: ListSessionsWithMessages :many -SELECT id, title, created_at, updated_at, model_name, system_prompt, message_count -FROM sessions -WHERE message_count > 0 OR title IS NOT NULL -ORDER BY updated_at DESC -LIMIT $2 -OFFSET $1 -` - -type ListSessionsWithMessagesParams struct { - ResultOffset int32 `json:"result_offset"` - ResultLimit int32 `json:"result_limit"` -} - -// Only list sessions that have messages or titles (not empty sessions) -// This is used for sidebar to hide "New Chat" placeholder sessions -func (q *Queries) ListSessionsWithMessages(ctx context.Context, arg ListSessionsWithMessagesParams) ([]Session, error) { - rows, err := q.db.Query(ctx, listSessionsWithMessages, arg.ResultOffset, arg.ResultLimit) - if err != nil { - return nil, err - } - defer rows.Close() - items := []Session{} - for rows.Next() { - var i Session - if err := rows.Scan( - &i.ID, - &i.Title, - &i.CreatedAt, - &i.UpdatedAt, - &i.ModelName, - &i.SystemPrompt, - &i.MessageCount, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const lockSession = `-- name: LockSession :one SELECT id FROM sessions WHERE id = $1 FOR UPDATE ` @@ -249,8 +91,8 @@ func (q *Queries) LockSession(ctx context.Context, id uuid.UUID) (uuid.UUID, err } const messages = `-- name: Messages :many -SELECT id, session_id, role, content, sequence_number, created_at, status, updated_at -FROM message +SELECT id, session_id, role, content, sequence_number, created_at +FROM messages WHERE session_id = $1 ORDER BY sequence_number ASC LIMIT $3 @@ -280,8 +122,6 @@ func (q *Queries) Messages(ctx context.Context, arg MessagesParams) ([]Message, &i.Content, &i.SequenceNumber, &i.CreatedAt, - &i.Status, - &i.UpdatedAt, ); err != nil { return nil, err } @@ -294,7 +134,7 @@ func (q *Queries) Messages(ctx context.Context, arg MessagesParams) ([]Message, } const session = `-- name: Session :one -SELECT id, title, created_at, updated_at, model_name, system_prompt, message_count +SELECT id, title, created_at, updated_at FROM sessions WHERE id = $1 ` @@ -307,48 +147,46 @@ func (q *Queries) Session(ctx context.Context, id uuid.UUID) (Session, error) { &i.Title, &i.CreatedAt, &i.UpdatedAt, - &i.ModelName, - &i.SystemPrompt, - &i.MessageCount, ) return i, err } -const updateMessageContent = `-- name: UpdateMessageContent :exec -UPDATE message -SET content = $2, - status = 'completed', - updated_at = NOW() -WHERE id = $1 -` - -type UpdateMessageContentParams struct { - ID uuid.UUID `json:"id"` - Content []byte `json:"content"` -} - -// Update message content and mark as completed -func (q *Queries) UpdateMessageContent(ctx context.Context, arg UpdateMessageContentParams) error { - _, err := q.db.Exec(ctx, updateMessageContent, arg.ID, arg.Content) - return err -} - -const updateMessageStatus = `-- name: UpdateMessageStatus :exec -UPDATE message -SET status = $2, - updated_at = NOW() -WHERE id = $1 +const sessions = `-- name: Sessions :many +SELECT id, title, created_at, updated_at +FROM sessions +ORDER BY updated_at DESC +LIMIT $2 +OFFSET $1 ` -type UpdateMessageStatusParams struct { - ID uuid.UUID `json:"id"` - Status string `json:"status"` +type SessionsParams struct { + ResultOffset int32 `json:"result_offset"` + ResultLimit int32 `json:"result_limit"` } -// Update message status (streaming/completed/failed) -func (q *Queries) UpdateMessageStatus(ctx context.Context, arg UpdateMessageStatusParams) error { - _, err := q.db.Exec(ctx, updateMessageStatus, arg.ID, arg.Status) - return err +func (q *Queries) Sessions(ctx context.Context, arg SessionsParams) ([]Session, error) { + rows, err := q.db.Query(ctx, sessions, arg.ResultOffset, arg.ResultLimit) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Session{} + for rows.Next() { + var i Session + if err := rows.Scan( + &i.ID, + &i.Title, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil } const updateSessionTitle = `-- name: UpdateSessionTitle :exec @@ -371,17 +209,11 @@ func (q *Queries) UpdateSessionTitle(ctx context.Context, arg UpdateSessionTitle const updateSessionUpdatedAt = `-- name: UpdateSessionUpdatedAt :exec UPDATE sessions -SET updated_at = NOW(), - message_count = $1 -WHERE id = $2 +SET updated_at = NOW() +WHERE id = $1 ` -type UpdateSessionUpdatedAtParams struct { - MessageCount *int32 `json:"message_count"` - SessionID uuid.UUID `json:"session_id"` -} - -func (q *Queries) UpdateSessionUpdatedAt(ctx context.Context, arg UpdateSessionUpdatedAtParams) error { - _, err := q.db.Exec(ctx, updateSessionUpdatedAt, arg.MessageCount, arg.SessionID) +func (q *Queries) UpdateSessionUpdatedAt(ctx context.Context, sessionID uuid.UUID) error { + _, err := q.db.Exec(ctx, updateSessionUpdatedAt, sessionID) return err } diff --git a/internal/testutil/googleai.go b/internal/testutil/googleai.go index 62669bb..bf1265e 100644 --- a/internal/testutil/googleai.go +++ b/internal/testutil/googleai.go @@ -51,9 +51,9 @@ func SetupGoogleAI(tb testing.TB) *GoogleAISetup { ctx := context.Background() // Find project root to get absolute path to prompts directory - projectRoot, err := findProjectRoot() + projectRoot, err := FindProjectRoot() if err != nil { - tb.Fatalf("Failed to find project root: %v", err) + tb.Fatalf("finding project root: %v", err) } promptsDir := filepath.Join(projectRoot, "prompts") @@ -64,19 +64,18 @@ func SetupGoogleAI(tb testing.TB) *GoogleAISetup { // Nil check: genkit.Init returns nil on internal initialization failure if g == nil { - tb.Fatal("Failed to initialize Genkit: genkit.Init returned nil") + tb.Fatal("genkit.Init returned nil") } // Create embedder using config constant for maintainability - embedder := googlegenai.GoogleAIEmbedder(g, config.DefaultEmbedderModel) + embedder := googlegenai.GoogleAIEmbedder(g, config.DefaultGeminiEmbedderModel) // Nil check: GoogleAIEmbedder returns nil if model lookup fails if embedder == nil { - tb.Fatalf("Failed to create embedder: GoogleAIEmbedder returned nil for model %q", config.DefaultEmbedderModel) + tb.Fatalf("GoogleAIEmbedder returned nil for model %q", config.DefaultGeminiEmbedderModel) } - // Create quiet logger for tests (discard all logs) - logger := slog.New(slog.DiscardHandler) + logger := DiscardLogger() return &GoogleAISetup{ Embedder: embedder, diff --git a/internal/testutil/logger.go b/internal/testutil/logger.go index b1acf34..2bf5fba 100644 --- a/internal/testutil/logger.go +++ b/internal/testutil/logger.go @@ -4,15 +4,8 @@ import ( "log/slog" ) -// DiscardLogger returns a slog.Logger that discards all output. -// This is the standard library pattern for test loggers. -// -// Use this in tests to reduce noise. For components that use log.Logger -// (which is a type alias for *slog.Logger), use log.NewNop() directly. -// -// Note: log.Logger is a type alias for *slog.Logger, so this function -// and log.NewNop() return the same type. Prefer log.NewNop() when working -// with the internal/log package. +// DiscardLogger returns a *slog.Logger that discards all output. +// Use this in tests to reduce log noise. func DiscardLogger() *slog.Logger { return slog.New(slog.DiscardHandler) } diff --git a/internal/testutil/postgres.go b/internal/testutil/postgres.go index 30cf689..a056ec0 100644 --- a/internal/testutil/postgres.go +++ b/internal/testutil/postgres.go @@ -26,15 +26,14 @@ import ( // Provides: // - Isolated PostgreSQL instance with pgvector extension // - Connection pool for database operations -// - Automatic cleanup via cleanup function +// - Automatic cleanup via tb.Cleanup (no manual cleanup needed) // // Usage: // -// db, cleanup := testutil.SetupTestDB(t) -// defer cleanup() +// db := testutil.SetupTestDB(t) // // Use db.Pool for database operations type TestDBContainer struct { - Container *postgres.PostgresContainer + container *postgres.PostgresContainer Pool *pgxpool.Pool ConnStr string } @@ -46,25 +45,24 @@ type TestDBContainer struct { // - Test database schema (via migrations) // - Connection pool ready for use // -// Returns: -// - TestDBContainer: Container with connection pool -// - cleanup function: Must be called to terminate container +// Cleanup is registered via tb.Cleanup and runs automatically when the test ends. // // Example: // // func TestMyFeature(t *testing.T) { -// db, cleanup := testutil.SetupTestDB(t) -// defer cleanup() +// db := testutil.SetupTestDB(t) // // // Use db.Pool for queries // var count int // err := db.Pool.QueryRow(ctx, "SELECT COUNT(*) FROM documents").Scan(&count) -// require.NoError(t, err) +// if err != nil { +// t.Fatalf("QueryRow() unexpected error: %v", err) +// } // } // // Note: Accepts testing.TB interface to support both *testing.T (tests) and // *testing.B (benchmarks). This allows the same setup to be used in both contexts. -func SetupTestDB(tb testing.TB) (*TestDBContainer, func()) { +func SetupTestDB(tb testing.TB) *TestDBContainer { tb.Helper() ctx := context.Background() @@ -81,62 +79,58 @@ func SetupTestDB(tb testing.TB) (*TestDBContainer, func()) { WithStartupTimeout(60*time.Second)), ) if err != nil { - tb.Fatalf("Failed to start PostgreSQL container: %v", err) + tb.Fatalf("starting PostgreSQL container: %v", err) } // Get connection string connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable") if err != nil { - _ = pgContainer.Terminate(ctx) - tb.Fatalf("Failed to get connection string: %v", err) + _ = pgContainer.Terminate(ctx) // best-effort cleanup + tb.Fatalf("getting connection string: %v", err) } // Create connection pool pool, err := pgxpool.New(ctx, connStr) if err != nil { - _ = pgContainer.Terminate(ctx) - tb.Fatalf("Failed to create connection pool: %v", err) + _ = pgContainer.Terminate(ctx) // best-effort cleanup + tb.Fatalf("creating connection pool: %v", err) } // Verify connection if err := pool.Ping(ctx); err != nil { pool.Close() - _ = pgContainer.Terminate(ctx) - tb.Fatalf("Failed to ping database: %v", err) + _ = pgContainer.Terminate(ctx) // best-effort cleanup + tb.Fatalf("pinging database: %v", err) } // Run migrations if err := runMigrations(ctx, pool); err != nil { pool.Close() _ = pgContainer.Terminate(ctx) - tb.Fatalf("Failed to run migrations: %v", err) + tb.Fatalf("running migrations: %v", err) } container := &TestDBContainer{ - Container: pgContainer, + container: pgContainer, Pool: pool, ConnStr: connStr, } - cleanup := func() { - if pool != nil { - pool.Close() - } - if pgContainer != nil { - _ = pgContainer.Terminate(context.Background()) - } - } + tb.Cleanup(func() { + pool.Close() + _ = pgContainer.Terminate(context.Background()) + }) - return container, cleanup + return container } -// findProjectRoot finds the project root directory by looking for go.mod. +// FindProjectRoot finds the project root directory by looking for go.mod. // This allows tests to run from any subdirectory and still find migration files. -func findProjectRoot() (string, error) { +func FindProjectRoot() (string, error) { // Start from the current file's directory _, filename, _, ok := runtime.Caller(0) if !ok { - return "", fmt.Errorf("failed to get current file path") + return "", fmt.Errorf("getting current file path") } dir := filepath.Dir(filename) @@ -168,9 +162,9 @@ func findProjectRoot() (string, error) { //nolint:gocognit // Complex error handling necessary for transaction safety in test utility func runMigrations(ctx context.Context, pool *pgxpool.Pool) error { // Find project root to build absolute paths to migrations - projectRoot, err := findProjectRoot() + projectRoot, err := FindProjectRoot() if err != nil { - return fmt.Errorf("failed to find project root: %w", err) + return fmt.Errorf("finding project root: %w", err) } // Read and execute migration files in order @@ -183,7 +177,7 @@ func runMigrations(ctx context.Context, pool *pgxpool.Pool) error { // #nosec G304 -- migration paths are hardcoded constants, not from user input migrationSQL, readErr := os.ReadFile(migrationPath) if readErr != nil { - return fmt.Errorf("failed to read migration %s: %w", migrationPath, readErr) + return fmt.Errorf("reading migration %s: %w", migrationPath, readErr) } // Skip empty migration files to avoid unnecessary execution @@ -198,7 +192,7 @@ func runMigrations(ctx context.Context, pool *pgxpool.Pool) error { // This ensures that if a migration fails, changes are rolled back tx, beginErr := pool.Begin(ctx) if beginErr != nil { - return fmt.Errorf("failed to begin transaction for migration %s: %w", migrationPath, beginErr) + return fmt.Errorf("beginning transaction for migration %s: %w", migrationPath, beginErr) } // Ensure transaction is always closed (rollback unless committed) @@ -215,11 +209,11 @@ func runMigrations(ctx context.Context, pool *pgxpool.Pool) error { _, execErr := tx.Exec(ctx, string(migrationSQL)) if execErr != nil { - return fmt.Errorf("failed to execute migration %s: %w", migrationPath, execErr) + return fmt.Errorf("executing migration %s: %w", migrationPath, execErr) } if commitErr := tx.Commit(ctx); commitErr != nil { - return fmt.Errorf("failed to commit migration %s: %w", migrationPath, commitErr) + return fmt.Errorf("committing migration %s: %w", migrationPath, commitErr) } committed = true return nil diff --git a/internal/testutil/postgres_test.go b/internal/testutil/postgres_test.go index 3a8f2e2..e5468ee 100644 --- a/internal/testutil/postgres_test.go +++ b/internal/testutil/postgres_test.go @@ -21,14 +21,13 @@ import ( // Run with: go test -tags=integration ./internal/testutil -v func TestSetupTestDB_Integration(t *testing.T) { // Setup test database - dbContainer, cleanup := SetupTestDB(t) - defer cleanup() + dbContainer := SetupTestDB(t) // Verify database is accessible ctx := context.Background() err := dbContainer.Pool.Ping(ctx) if err != nil { - t.Fatalf("Failed to ping database: %v", err) + t.Fatalf("Pool.Ping() unexpected error: %v", err) } // Verify pgvector extension is installed @@ -36,26 +35,24 @@ func TestSetupTestDB_Integration(t *testing.T) { err = dbContainer.Pool.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname = 'vector')").Scan(&hasExtension) if err != nil { - t.Fatalf("Failed to check for vector extension: %v", err) + t.Fatalf("QueryRow(vector extension check) unexpected error: %v", err) } if !hasExtension { - t.Error("pgvector extension not installed") + t.Error("pgvector extension installed = false, want true") } // Verify all required tables exist - tables := []string{"documents", "sessions", "session_messages"} + tables := []string{"documents", "sessions", "messages"} for _, table := range tables { var exists bool err = dbContainer.Pool.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name = $1)", table).Scan(&exists) if err != nil { - t.Fatalf("Failed to check for table %s: %v", table, err) + t.Fatalf("QueryRow(table %q check) unexpected error: %v", table, err) } if !exists { - t.Errorf("Table %s does not exist", table) + t.Errorf("table %q exists = false, want true", table) } } - - t.Log("Database setup successful with all required tables") } diff --git a/internal/testutil/rag.go b/internal/testutil/rag.go index 2b2c918..712d5bc 100644 --- a/internal/testutil/rag.go +++ b/internal/testutil/rag.go @@ -2,6 +2,8 @@ package testutil import ( "context" + "os" + "path/filepath" "testing" "github.com/firebase/genkit/go/ai" @@ -46,8 +48,7 @@ type RAGSetup struct { // Example: // // func TestRAGFeature(t *testing.T) { -// db, dbCleanup := testutil.SetupTestDB(t) -// defer dbCleanup() +// db := testutil.SetupTestDB(t) // // rag := testutil.SetupRAG(t, db.Pool) // @@ -64,10 +65,12 @@ type RAGSetup struct { func SetupRAG(tb testing.TB, pool *pgxpool.Pool) *RAGSetup { tb.Helper() - // Check for API key (SetupGoogleAI skips test if not set) - // We don't use the returned setup because we need to create a - // Genkit instance with both GoogleAI and PostgreSQL plugins - _ = SetupGoogleAI(tb) + // Check for API key — skip test if not set. + // We don't use SetupGoogleAI here because we need a Genkit instance + // with both GoogleAI and PostgreSQL plugins. + if os.Getenv("GEMINI_API_KEY") == "" { + tb.Skip("GEMINI_API_KEY not set - skipping test requiring embedder") + } ctx := context.Background() @@ -77,7 +80,7 @@ func SetupRAG(tb testing.TB, pool *pgxpool.Pool) *RAGSetup { postgresql.WithDatabase("koopa_test"), ) if err != nil { - tb.Fatalf("Failed to create PostgresEngine: %v", err) + tb.Fatalf("creating PostgresEngine: %v", err) } // Create PostgreSQL plugin @@ -85,30 +88,30 @@ func SetupRAG(tb testing.TB, pool *pgxpool.Pool) *RAGSetup { // Re-initialize Genkit with both plugins // We need to create a new Genkit instance that has both plugins - projectRoot, err := findProjectRoot() + projectRoot, err := FindProjectRoot() if err != nil { - tb.Fatalf("Failed to find project root: %v", err) + tb.Fatalf("finding project root: %v", err) } g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{}, postgres), - genkit.WithPromptDir(projectRoot+"/prompts"), + genkit.WithPromptDir(filepath.Join(projectRoot, "prompts")), ) if g == nil { - tb.Fatal("Failed to initialize Genkit with PostgreSQL plugin") + tb.Fatal("genkit.Init with PostgreSQL plugin returned nil") } // Create embedder - embedder := googlegenai.GoogleAIEmbedder(g, config.DefaultEmbedderModel) + embedder := googlegenai.GoogleAIEmbedder(g, config.DefaultGeminiEmbedderModel) if embedder == nil { - tb.Fatalf("Failed to create embedder for model %q", config.DefaultEmbedderModel) + tb.Fatalf("GoogleAIEmbedder returned nil for model %q", config.DefaultGeminiEmbedderModel) } // Create DocStore and Retriever using shared config factory cfg := rag.NewDocStoreConfig(embedder) docStore, retriever, err := postgresql.DefineRetriever(ctx, g, postgres, cfg) if err != nil { - tb.Fatalf("Failed to define retriever: %v", err) + tb.Fatalf("defining retriever: %v", err) } return &RAGSetup{ diff --git a/internal/tools/README.md b/internal/tools/README.md deleted file mode 100644 index b8a6772..0000000 --- a/internal/tools/README.md +++ /dev/null @@ -1,132 +0,0 @@ -# Tools Package - -Secure AI agent toolkit with structured error handling and zero-global-state design. - -[繁體中文](./README_ZH_TW.md) - ---- - -## Design Philosophy - -The Tools building maintainable, testable, and LLM-friendly agent tools. - -### Core Principles - -**1. Config Struct Pattern** -All dependencies injected via `KitConfig` structure—no global state, explicit dependencies, testable by design. - -**2. Functional Options Pattern** -Optional features (logging, metrics) configured through functional options, preserving backward compatibility. - -**3. Structured Results** -Standardized `Result{Status, Data, Error}` format enables consistent LLM interaction and rich error context. - -**4. Error Semantics: Agent vs System** -- **Agent Errors**: Recoverable failures (file not found, permission denied) returned in `Result` → LLM can retry -- **System Errors**: Infrastructure failures (database down, OOM) returned as Go `error` → requires human intervention - -**5. Zero Maintenance Tool Registry** -Leverages `genkit.ListTools()` API—single source of truth, no manual list synchronization. - ---- - -## Architecture - -```mermaid -graph TD - Kit[Kit - Tool Collection
Config Struct Pattern
Functional Options
No Global State] - - Kit --> FileTools[File Tools] - Kit --> SystemTools[System Tools] - Kit --> NetworkTools[Network Tools] - Kit --> KnowledgeTools[Knowledge Tools] - - FileTools --> ReadFile[readFile] - FileTools --> WriteFile[writeFile] - FileTools --> ListFiles[listFiles] - FileTools --> DeleteFile[deleteFile] - FileTools --> GetFileInfo[getFileInfo] - - SystemTools --> ExecuteCommand[executeCommand] - SystemTools --> GetEnv[getEnv] - SystemTools --> CurrentTime[currentTime] - - NetworkTools --> HTTPGet[httpGet
SSRF Protection] - - KnowledgeTools --> SearchHistory[searchHistory] - KnowledgeTools --> SearchDocuments[searchDocuments] - KnowledgeTools --> SearchSystemKnowledge[searchSystemKnowledge] -``` - -**12 tools across 4 categories**: File (5) • System (3) • Network (1) • Knowledge (3) - ---- - -## Design Decisions - -### Why Structured Result over Raw Returns? - -**Problem**: Inconsistent error handling across tools makes LLM interaction unpredictable. - -**Solution**: Standardized `Result` type with status, data, and structured errors. - -**Benefits**: -- LLM can parse status semantically -- Rich debugging context through error codes -- Consistent interface across all tools -- Enables programmatic error handling - -### Why Distinguish Agent Error from System Error? - -**Problem**: Treating all errors equally prevents LLM from recovering gracefully. - -**Solution**: Agent errors in `Result` (LLM-visible), system errors as Go `error` (human-visible). - -**Benefits**: -- LLM can retry on agent errors (wrong path, missing permission) -- System errors halt execution (database failure requires ops intervention) -- Genkit framework handles each type appropriately - -### Why Zero Tool List Maintenance? - -**Problem**: Manual tool lists drift out of sync with registrations, causing runtime failures. - -**Solution**: Use Genkit's `ListTools()` API as single source of truth. - -**Benefits**: -- DRY principle—register once, enumerate anywhere -- No sync bugs when adding/removing tools -- Runtime tool discovery support - ---- - -## Security Model - -All tools enforce **defense-in-depth** validation: - -- **Path Validation**: Prevents traversal attacks, enforces allow-list directories, resolves symlinks -- **Command Validation**: Blocks dangerous commands (`rm -rf`, `dd`, `format`, `sudo`, ...) -- **Environment Filtering**: Blocks sensitive variables (API keys, passwords, tokens) -- **HTTP Protection**: SSRF defense (blocks internal IPs, localhost, metadata services), size limits, timeouts - -**Security Philosophy**: Fail closed—deny by default, allow explicitly validated operations only. - ---- - -## Design Influences - -### Genkit Framework - -**Source**: [Firebase Genkit](https://github.com/firebase/genkit) - Google's AI framework for building production-ready AI agents - -**Design Philosophy**: -- **Tool-centric**: AI agents invoke tools to interact with the world -- **Registry pattern**: Central tool registry (`ListTools()`, `LookupTool()`) eliminates manual list maintenance -- **Structured I/O**: Typed inputs/outputs with JSON schema validation -- **Framework-managed lifecycle**: Genkit handles tool registration, discovery, and invocation - -**Why we adopted it**: -- Zero boilerplate—register once, use anywhere -- Type-safe tool definitions with schema validation -- Built-in observability and tracing -- Production-ready error handling diff --git a/internal/tools/README_ZH_TW.md b/internal/tools/README_ZH_TW.md deleted file mode 100644 index 63cd50c..0000000 --- a/internal/tools/README_ZH_TW.md +++ /dev/null @@ -1,136 +0,0 @@ -# Tools 套件 - -具備結構化錯誤處理和零全局狀態設計的安全 AI 代理工具包。 - -[English](./README.md) - ---- - -## 設計理念 - -Tools 用於構建可維護、可測試、LLM 友好的代理工具。 - -### 核心原則 - -**1. Config Struct Pattern** -所有依賴通過 `KitConfig` 結構注入—無全局狀態、明確依賴、設計即可測試。 - -**2. Functional Options Pattern** -可選功能(日誌、指標)通過函數選項配置,保持向後兼容性。 - -**3. 結構化結果** -標準化的 `Result{Status, Data, Error}` 格式實現一致的 LLM 互動和豐富的錯誤上下文。 - -**4. 錯誤語意:Agent vs System** - -- **Agent Errors**:可恢復故障(檔案不存在、權限拒絕)在 `Result` 中返回 → LLM 可重試 -- **System Errors**:基礎設施故障(資料庫停機、記憶體不足)作為 Go `error` 返回 → 需人工介入 - -**5. 零維護工具註冊表** -利用 `genkit.ListTools()` API—單一事實來源,無需手動列表同步。 - ---- - -## 架構 - -```mermaid -graph TD - Kit[Kit - 工具集合
Config Struct Pattern
Functional Options
無全局狀態] - - Kit --> FileTools[檔案工具] - Kit --> SystemTools[系統工具] - Kit --> NetworkTools[網路工具] - Kit --> KnowledgeTools[知識工具] - - FileTools --> ReadFile[readFile] - FileTools --> WriteFile[writeFile] - FileTools --> ListFiles[listFiles] - FileTools --> DeleteFile[deleteFile] - FileTools --> GetFileInfo[getFileInfo] - - SystemTools --> ExecuteCommand[executeCommand] - SystemTools --> GetEnv[getEnv] - SystemTools --> CurrentTime[currentTime] - - NetworkTools --> HTTPGet[httpGet
SSRF 防護] - - KnowledgeTools --> SearchHistory[searchHistory] - KnowledgeTools --> SearchDocuments[searchDocuments] - KnowledgeTools --> SearchSystemKnowledge[searchSystemKnowledge] -``` - -**4 大類別 12 個工具**:檔案(5)• 系統(3)• 網路(1)• 知識(3) - ---- - -### 為什麼用結構化 Result 而非原始返回值? - -**問題**:工具間不一致的錯誤處理使 LLM 互動不可預測。 - -**解決方案**:標準化的 `Result` 類型包含狀態、數據和結構化錯誤。 - -**優點**: - -- LLM 可語意化解析狀態 -- 通過錯誤代碼提供豐富除錯上下文 -- 所有工具間一致的介面 -- 支援程式化錯誤處理 - -### 為什麼區分 Agent Error 和 System Error? - -**問題**:平等對待所有錯誤阻止 LLM 優雅恢復。 - -**解決方案**:Agent 錯誤放在 `Result` 中(LLM 可見),系統錯誤作為 Go `error`(人類可見)。 - -**優點**: - -- LLM 可在 agent 錯誤時重試(錯誤路徑、缺少權限) -- 系統錯誤停止執行(資料庫故障需運維介入) -- Genkit 框架適當處理每種類型 - -### 為什麼零工具列表維護? - -**問題**:手動工具列表與註冊不同步,導致執行時故障。 - -**解決方案**:使用 Genkit 的 `ListTools()` API 作為單一事實來源。 - -**優點**: - -- DRY 原則—註冊一次,隨處枚舉 -- 添加/刪除工具時無同步錯誤 -- 支援執行時工具發現 - ---- - -## 安全模型 - -所有工具強制執行**縱深防禦**驗證: - -- **路徑驗證**:防止穿越攻擊、強制允許列表目錄、解析符號連結 -- **命令驗證**:阻擋危險命令(`rm -rf`、`dd`、`format`、`sudo`、...) -- **環境變數過濾**:阻擋敏感變數(API 金鑰、密碼、令牌) -- **HTTP 防護**:SSRF 防禦(阻擋內部 IP、localhost、元數據服務)、大小限制、逾時 - -**安全理念**:預設關閉—預設拒絕,僅允許明確驗證的操作。 - ---- - -## 設計影響 - -### Genkit 框架 - -**來源**:[Firebase Genkit](https://github.com/firebase/genkit) - Google 的 AI 框架,用於構建生產級 AI 代理 - -**設計理念**: - -- **以工具為中心**:AI 代理通過調用工具與世界互動 -- **註冊表模式**:中央工具註冊表(`ListTools()`、`LookupTool()`)消除手動列表維護 -- **結構化 I/O**:帶 JSON schema 驗證的類型化輸入/輸出 -- **框架管理生命週期**:Genkit 處理工具註冊、發現和調用 - -**為什麼採用它**: - -- 零樣板代碼—註冊一次,隨處使用 -- 帶 schema 驗證的類型安全工具定義 -- 內建可觀察性和追蹤 -- 生產就緒的錯誤處理 diff --git a/internal/tools/benchmark_test.go b/internal/tools/benchmark_test.go index 1b9e5c9..d6a4a0e 100644 --- a/internal/tools/benchmark_test.go +++ b/internal/tools/benchmark_test.go @@ -8,7 +8,7 @@ import ( // BenchmarkClampTopK benchmarks the clampTopK function. func BenchmarkClampTopK(b *testing.B) { b.ReportAllocs() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = clampTopK(5, 3) } } @@ -17,7 +17,7 @@ func BenchmarkClampTopK(b *testing.B) { func BenchmarkResultConstruction(b *testing.B) { b.Run("success", func(b *testing.B) { b.ReportAllocs() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = Result{ Status: StatusSuccess, Data: map[string]any{"key": "value"}, @@ -27,7 +27,7 @@ func BenchmarkResultConstruction(b *testing.B) { b.Run("error", func(b *testing.B) { b.ReportAllocs() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = Result{ Status: StatusError, Error: &Error{Code: ErrCodeSecurity, Message: "test error"}, @@ -36,9 +36,9 @@ func BenchmarkResultConstruction(b *testing.B) { }) } -// BenchmarkNetworkToolsCreation benchmarks NetworkTools constructor. -func BenchmarkNetworkToolsCreation(b *testing.B) { - cfg := NetworkConfig{ +// BenchmarkNetworkCreation benchmarks Network constructor. +func BenchmarkNetworkCreation(b *testing.B) { + cfg := NetConfig{ SearchBaseURL: "http://localhost:8080", FetchParallelism: 2, FetchDelay: time.Second, @@ -48,20 +48,20 @@ func BenchmarkNetworkToolsCreation(b *testing.B) { b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = NewNetworkTools(cfg, logger) + for b.Loop() { + _, _ = NewNetwork(cfg, logger) } } // BenchmarkFilterURLs benchmarks URL filtering and validation. func BenchmarkFilterURLs(b *testing.B) { - cfg := NetworkConfig{ + cfg := NetConfig{ SearchBaseURL: "http://localhost:8080", FetchParallelism: 2, FetchDelay: time.Second, FetchTimeout: 30 * time.Second, } - nt, _ := NewNetworkTools(cfg, testLogger()) + nt, _ := NewNetwork(cfg, testLogger()) urls := []string{ "https://example.com/", @@ -74,21 +74,19 @@ func BenchmarkFilterURLs(b *testing.B) { b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { _, _ = nt.filterURLs(urls) } } // BenchmarkExtractNonHTMLContent benchmarks non-HTML content extraction. func BenchmarkExtractNonHTMLContent(b *testing.B) { - nt := &NetworkTools{} - b.Run("json", func(b *testing.B) { body := []byte(`{"key": "value", "nested": {"a": 1, "b": 2}}`) b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = nt.extractNonHTMLContent(body, "application/json") + for b.Loop() { + _, _ = extractNonHTMLContent(body, "application/json") } }) @@ -96,8 +94,8 @@ func BenchmarkExtractNonHTMLContent(b *testing.B) { body := []byte("Plain text content here") b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = nt.extractNonHTMLContent(body, "text/plain") + for b.Loop() { + _, _ = extractNonHTMLContent(body, "text/plain") } }) @@ -112,8 +110,8 @@ func BenchmarkExtractNonHTMLContent(b *testing.B) { b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = nt.extractNonHTMLContent(body, "application/json") + for b.Loop() { + _, _ = extractNonHTMLContent(body, "application/json") } }) } @@ -133,7 +131,7 @@ func BenchmarkFetchState(b *testing.B) { b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { state.addResult(result) } }) @@ -145,23 +143,18 @@ func BenchmarkFetchState(b *testing.B) { b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { + i := 0 + for b.Loop() { state.markProcessed("https://example.com/page" + string(rune(i%100))) + i++ } }) } -// BenchmarkFileToolsCreation benchmarks FileTools constructor. -func BenchmarkFileToolsCreation(b *testing.B) { - // Note: This benchmark requires a valid path validator - // Skip if we can't create one - b.Skip("requires valid path validator setup") -} - // BenchmarkKnowledgeSearchInput benchmarks the unified search input struct. func BenchmarkKnowledgeSearchInput(b *testing.B) { b.ReportAllocs() - for i := 0; i < b.N; i++ { + for b.Loop() { _ = KnowledgeSearchInput{ Query: "test query", TopK: 5, diff --git a/internal/tools/doc.go b/internal/tools/doc.go index 533ab36..b29b97b 100644 --- a/internal/tools/doc.go +++ b/internal/tools/doc.go @@ -10,35 +10,36 @@ // // Tools are organized into four categories: // -// - FileTools: File operations (read, write, list, delete, info) -// - SystemTools: System operations (time, command execution, environment) -// - NetworkTools: Network operations (web search, web fetch) -// - knowledgeTools: Knowledge base operations (semantic search) +// - File: File operations (read, write, list, delete, info) +// - System: System operations (time, command execution, environment) +// - Network: Network operations (web search, web fetch) +// - Knowledge: Knowledge base operations (semantic search) // // Each tool struct is created with a constructor, then registered with Genkit. // // # Available Tools // -// File tools (FileTools): +// File tools (File): // - read_file: Read file contents (max 10MB) // - write_file: Write or create files // - list_files: List directory contents // - delete_file: Delete a file // - get_file_info: Get file metadata // -// System tools (SystemTools): +// System tools (System): // - current_time: Get current system time // - execute_command: Execute shell commands (whitelist enforced) // - get_env: Read environment variables (secrets protected) // -// Network tools (NetworkTools): +// Network tools (Network): // - web_search: Search via SearXNG // - web_fetch: Fetch web content with SSRF protection // -// Knowledge tools (knowledgeTools): +// Knowledge tools (Knowledge): // - search_history: Search conversation history // - search_documents: Search indexed documents // - search_system_knowledge: Search system knowledge base +// - knowledge_store: Store knowledge documents (when DocStore is available) // // # Security // @@ -50,7 +51,9 @@ // // # Result Type // -// All tools return the unified Result type: +// File, System, and Knowledge tools return the unified Result type. +// Network tools (Search, Fetch) use typed output structs (SearchOutput, FetchOutput) +// with an Error string field for LLM-facing business errors. // // type Result struct { // Status Status // StatusSuccess or StatusError @@ -97,11 +100,11 @@ // // // Create tools with security validators // pathVal, _ := security.NewPath([]string{"/allowed/path"}) -// fileTools, err := tools.NewFileTools(pathVal, logger) +// fileTools, err := tools.NewFile(pathVal, logger) // if err != nil { // return err // } // // // Register with Genkit -// fileToolList, _ := tools.RegisterFileTools(g, fileTools) +// fileToolList, _ := tools.RegisterFile(g, fileTools) package tools diff --git a/internal/tools/emitter.go b/internal/tools/emitter.go index b130b66..4dc6ca3 100644 --- a/internal/tools/emitter.go +++ b/internal/tools/emitter.go @@ -1,25 +1,22 @@ -// Package tools provides tool abstractions for AI agent interactions. package tools import ( "context" ) -// emitterKey uses empty struct for zero-allocation context key. -// Per Rob Pike: empty struct is idiomatic for context keys. +// emitterKey is an unexported context key for zero-allocation type safety. type emitterKey struct{} -// ToolEventEmitter receives tool lifecycle events. +// Emitter receives tool lifecycle events. // Interface is minimal - only tool name, no UI concerns. -// Per architecture-master: Interface for loose coupling between tools and SSE layer. -// UI presentation logic moved to web/handlers layer. +// UI presentation logic is handled by the SSE/API layer. // // Usage: // 1. Handler creates emitter bound to SSE writer // 2. Handler stores emitter in context via ContextWithEmitter() // 3. Wrapped tool retrieves emitter via EmitterFromContext() // 4. Tool calls OnToolStart/Complete/Error during execution -type ToolEventEmitter interface { +type Emitter interface { // OnToolStart signals that a tool has started execution. // name: tool name (e.g., "web_search") // UI presentation (messages, icons) handled by web layer. @@ -35,16 +32,14 @@ type ToolEventEmitter interface { OnToolError(name string) } -// EmitterFromContext retrieves ToolEventEmitter from context. +// EmitterFromContext retrieves Emitter from context. // Returns nil if not set, allowing graceful degradation (no events emitted). -// Per architecture-master: Non-streaming code paths won't have emitter set. -func EmitterFromContext(ctx context.Context) ToolEventEmitter { - emitter, _ := ctx.Value(emitterKey{}).(ToolEventEmitter) +func EmitterFromContext(ctx context.Context) Emitter { + emitter, _ := ctx.Value(emitterKey{}).(Emitter) return emitter } -// ContextWithEmitter stores ToolEventEmitter in context. -// Per architecture-master: Per-request binding via context.Context. -func ContextWithEmitter(ctx context.Context, emitter ToolEventEmitter) context.Context { +// ContextWithEmitter stores Emitter in context for per-request binding. +func ContextWithEmitter(ctx context.Context, emitter Emitter) context.Context { return context.WithValue(ctx, emitterKey{}, emitter) } diff --git a/internal/tools/emitter_test.go b/internal/tools/emitter_test.go index f4761e1..1a318f7 100644 --- a/internal/tools/emitter_test.go +++ b/internal/tools/emitter_test.go @@ -7,7 +7,7 @@ import ( "github.com/koopa0/koopa/internal/tools" ) -// mockEmitter is a test implementation of ToolEventEmitter. +// mockEmitter is a test implementation of Emitter. // Interface simplified to only tool name parameter. type mockEmitter struct { startCalls []string @@ -27,8 +27,8 @@ func (m *mockEmitter) OnToolError(name string) { m.errorCalls = append(m.errorCalls, name) } -// Verify mockEmitter implements ToolEventEmitter. -var _ tools.ToolEventEmitter = (*mockEmitter)(nil) +// Verify mockEmitter implements Emitter. +var _ tools.Emitter = (*mockEmitter)(nil) func TestContextWithEmitter(t *testing.T) { t.Parallel() @@ -43,7 +43,7 @@ func TestContextWithEmitter(t *testing.T) { retrieved := tools.EmitterFromContext(ctxWithEmitter) if retrieved == nil { - t.Fatal("expected emitter to be retrieved from context") + t.Fatal("EmitterFromContext() = nil, want non-nil") } // Compare via interface method behavior instead of pointer equality retrieved.OnToolStart("test") @@ -65,7 +65,7 @@ func TestContextWithEmitter(t *testing.T) { retrieved.OnToolStart("test") // emitter2 should receive the call, not emitter1 if len(emitter2.startCalls) != 1 { - t.Error("expected second emitter to overwrite first") + t.Error("ContextWithEmitter() did not overwrite previous emitter") } if len(emitter1.startCalls) != 0 { t.Error("first emitter should not receive calls") @@ -83,7 +83,7 @@ func TestEmitterFromContext(t *testing.T) { emitter := tools.EmitterFromContext(ctx) if emitter != nil { - t.Error("expected nil emitter from empty context") + t.Error("EmitterFromContext(empty) = non-nil, want nil") } }) @@ -96,7 +96,7 @@ func TestEmitterFromContext(t *testing.T) { emitter := tools.EmitterFromContext(ctx) if emitter != nil { - t.Error("expected nil for missing emitter") + t.Error("EmitterFromContext(no-emitter) = non-nil, want nil") } }) @@ -124,7 +124,7 @@ func TestEmitterInterface(t *testing.T) { emitter.OnToolStart("web_search") if len(emitter.startCalls) != 1 { - t.Fatalf("expected 1 start call, got %d", len(emitter.startCalls)) + t.Fatalf("OnToolStart() call count = %d, want 1", len(emitter.startCalls)) } if emitter.startCalls[0] != "web_search" { @@ -139,7 +139,7 @@ func TestEmitterInterface(t *testing.T) { emitter.OnToolComplete("web_search") if len(emitter.completeCalls) != 1 { - t.Fatalf("expected 1 complete call, got %d", len(emitter.completeCalls)) + t.Fatalf("OnToolComplete() call count = %d, want 1", len(emitter.completeCalls)) } if emitter.completeCalls[0] != "web_search" { @@ -154,7 +154,7 @@ func TestEmitterInterface(t *testing.T) { emitter.OnToolError("web_search") if len(emitter.errorCalls) != 1 { - t.Fatalf("expected 1 error call, got %d", len(emitter.errorCalls)) + t.Fatalf("OnToolError() call count = %d, want 1", len(emitter.errorCalls)) } if emitter.errorCalls[0] != "web_search" { diff --git a/internal/tools/events_test.go b/internal/tools/events_test.go index 130e9e4..f66bfa3 100644 --- a/internal/tools/events_test.go +++ b/internal/tools/events_test.go @@ -8,7 +8,7 @@ import ( "github.com/firebase/genkit/go/ai" ) -// mockEmitterForEvents is a test implementation of ToolEventEmitter. +// mockEmitterForEvents is a test implementation of Emitter. type mockEmitterForEvents struct { startCalls []string completeCalls []string @@ -27,8 +27,8 @@ func (m *mockEmitterForEvents) OnToolError(name string) { m.errorCalls = append(m.errorCalls, name) } -// Verify mockEmitterForEvents implements ToolEventEmitter. -var _ ToolEventEmitter = (*mockEmitterForEvents)(nil) +// Verify mockEmitterForEvents implements Emitter. +var _ Emitter = (*mockEmitterForEvents)(nil) func TestWithEvents_Success(t *testing.T) { emitter := &mockEmitterForEvents{} diff --git a/internal/tools/file.go b/internal/tools/file.go index 9f7c81c..2d92d10 100644 --- a/internal/tools/file.go +++ b/internal/tools/file.go @@ -3,13 +3,13 @@ package tools import ( "fmt" "io" + "log/slog" "os" "path/filepath" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/genkit" - "github.com/koopa0/koopa/internal/log" "github.com/koopa0/koopa/internal/security" ) @@ -25,12 +25,18 @@ type FileEntry struct { Type string `json:"type"` // "file" or "directory" } +// Tool name constants for file operations registered with Genkit. const ( - ToolReadFile = "read_file" - ToolWriteFile = "write_file" - ToolListFiles = "list_files" - ToolDeleteFile = "delete_file" - ToolGetFileInfo = "get_file_info" + // ReadFileName is the Genkit tool name for reading file contents. + ReadFileName = "read_file" + // WriteFileName is the Genkit tool name for writing file contents. + WriteFileName = "write_file" + // ListFilesName is the Genkit tool name for listing directory contents. + ListFilesName = "list_files" + // DeleteFileName is the Genkit tool name for deleting files. + DeleteFileName = "delete_file" + // FileInfoName is the Genkit tool name for retrieving file metadata. + FileInfoName = "get_file_info" ) // MaxReadFileSize is the maximum file size allowed for ReadFile (10 MB). @@ -63,78 +69,78 @@ type GetFileInfoInput struct { Path string `json:"path" jsonschema_description:"The file path to get info for"` } -// FileTools provides file operation handlers. -// Use NewFileTools to create an instance, then either: +// File provides file operation handlers. +// Use NewFile to create an instance, then either: // - Call methods directly (for MCP) -// - Use RegisterFileTools to register with Genkit -type FileTools struct { +// - Use RegisterFile to register with Genkit +type File struct { pathVal *security.Path - logger log.Logger + logger *slog.Logger } -// NewFileTools creates a FileTools instance. -func NewFileTools(pathVal *security.Path, logger log.Logger) (*FileTools, error) { +// NewFile creates a File instance. +func NewFile(pathVal *security.Path, logger *slog.Logger) (*File, error) { if pathVal == nil { return nil, fmt.Errorf("path validator is required") } if logger == nil { return nil, fmt.Errorf("logger is required") } - return &FileTools{pathVal: pathVal, logger: logger}, nil + return &File{pathVal: pathVal, logger: logger}, nil } -// RegisterFileTools registers all file operation tools with Genkit. -func RegisterFileTools(g *genkit.Genkit, ft *FileTools) ([]ai.Tool, error) { +// RegisterFile registers all file operation tools with Genkit. +func RegisterFile(g *genkit.Genkit, ft *File) ([]ai.Tool, error) { if g == nil { return nil, fmt.Errorf("genkit instance is required") } if ft == nil { - return nil, fmt.Errorf("FileTools is required") + return nil, fmt.Errorf("File is required") } return []ai.Tool{ - genkit.DefineTool(g, ToolReadFile, + genkit.DefineTool(g, ReadFileName, "Read the complete content of a text-based file. "+ "Use this to examine source code, configuration files, logs, or documentation. "+ "Supports files up to 10MB. Binary files are not supported and will return an error. "+ "Returns: file path, content (UTF-8), size in bytes, and line count. "+ "Common errors: file not found (verify path with list_files), "+ "permission denied, file too large, binary file detected.", - WithEvents(ToolReadFile, ft.ReadFile)), - genkit.DefineTool(g, ToolWriteFile, + WithEvents(ReadFileName, ft.ReadFile)), + genkit.DefineTool(g, WriteFileName, "Write or create a text-based file with the specified content. "+ "Creates parent directories automatically if they don't exist. "+ "Overwrites existing files without confirmation. "+ "Use this for: creating new files, updating configuration, saving generated content. "+ "Returns: file path, bytes written, whether file was created or updated. "+ "Common errors: permission denied, disk full, invalid path.", - WithEvents(ToolWriteFile, ft.WriteFile)), - genkit.DefineTool(g, ToolListFiles, + WithEvents(WriteFileName, ft.WriteFile)), + genkit.DefineTool(g, ListFilesName, "List files and subdirectories in a directory. "+ "Returns file names, sizes, types (file/directory), and modification times. "+ "Does not recurse into subdirectories (use recursively for deep exploration). "+ "Use this to: explore project structure, find files by name, verify paths. "+ "Tip: Start from the project root and navigate down to find specific files.", - WithEvents(ToolListFiles, ft.ListFiles)), - genkit.DefineTool(g, ToolDeleteFile, + WithEvents(ListFilesName, ft.ListFiles)), + genkit.DefineTool(g, DeleteFileName, "Permanently delete a file or empty directory. "+ "WARNING: This action cannot be undone. "+ "Only deletes empty directories (use with caution). "+ "Returns: confirmation of deletion with file path. "+ "Common errors: file not found, directory not empty, permission denied.", - WithEvents(ToolDeleteFile, ft.DeleteFile)), - genkit.DefineTool(g, ToolGetFileInfo, + WithEvents(DeleteFileName, ft.DeleteFile)), + genkit.DefineTool(g, FileInfoName, "Get detailed metadata about a file without reading its contents. "+ "Returns: file size, modification time, permissions, and type (file/directory). "+ "Use this to: check if a file exists, verify file size before reading, "+ "determine file type without opening it. "+ "More efficient than read_file when you only need metadata.", - WithEvents(ToolGetFileInfo, ft.GetFileInfo)), + WithEvents(FileInfoName, ft.GetFileInfo)), }, nil } // ReadFile reads and returns the complete content of a file with security validation. -func (f *FileTools) ReadFile(_ *ai.ToolContext, input ReadFileInput) (Result, error) { +func (f *File) ReadFile(_ *ai.ToolContext, input ReadFileInput) (Result, error) { f.logger.Info("ReadFile called", "path", input.Path) safePath, err := f.pathVal.Validate(input.Path) @@ -143,7 +149,7 @@ func (f *FileTools) ReadFile(_ *ai.ToolContext, input ReadFileInput) (Result, er Status: StatusError, Error: &Error{ Code: ErrCodeSecurity, - Message: fmt.Sprintf("path validation failed: %v", err), + Message: fmt.Sprintf("validating path: %v", err), }, }, nil } @@ -212,7 +218,7 @@ func (f *FileTools) ReadFile(_ *ai.ToolContext, input ReadFileInput) (Result, er } // WriteFile writes content to a file with security validation. -func (f *FileTools) WriteFile(_ *ai.ToolContext, input WriteFileInput) (Result, error) { +func (f *File) WriteFile(_ *ai.ToolContext, input WriteFileInput) (Result, error) { f.logger.Info("WriteFile called", "path", input.Path) safePath, err := f.pathVal.Validate(input.Path) @@ -221,7 +227,7 @@ func (f *FileTools) WriteFile(_ *ai.ToolContext, input WriteFileInput) (Result, Status: StatusError, Error: &Error{ Code: ErrCodeSecurity, - Message: fmt.Sprintf("path validation failed: %v", err), + Message: fmt.Sprintf("validating path: %v", err), }, }, nil } @@ -270,7 +276,7 @@ func (f *FileTools) WriteFile(_ *ai.ToolContext, input WriteFileInput) (Result, } // ListFiles lists files in a directory. -func (f *FileTools) ListFiles(_ *ai.ToolContext, input ListFilesInput) (Result, error) { +func (f *File) ListFiles(_ *ai.ToolContext, input ListFilesInput) (Result, error) { f.logger.Info("ListFiles called", "path", input.Path) safePath, err := f.pathVal.Validate(input.Path) @@ -279,7 +285,7 @@ func (f *FileTools) ListFiles(_ *ai.ToolContext, input ListFilesInput) (Result, Status: StatusError, Error: &Error{ Code: ErrCodeSecurity, - Message: fmt.Sprintf("path validation failed: %v", err), + Message: fmt.Sprintf("validating path: %v", err), }, }, nil } @@ -318,7 +324,7 @@ func (f *FileTools) ListFiles(_ *ai.ToolContext, input ListFilesInput) (Result, } // DeleteFile permanently deletes a file with security validation. -func (f *FileTools) DeleteFile(_ *ai.ToolContext, input DeleteFileInput) (Result, error) { +func (f *File) DeleteFile(_ *ai.ToolContext, input DeleteFileInput) (Result, error) { f.logger.Info("DeleteFile called", "path", input.Path) safePath, err := f.pathVal.Validate(input.Path) @@ -327,7 +333,7 @@ func (f *FileTools) DeleteFile(_ *ai.ToolContext, input DeleteFileInput) (Result Status: StatusError, Error: &Error{ Code: ErrCodeSecurity, - Message: fmt.Sprintf("path validation failed: %v", err), + Message: fmt.Sprintf("validating path: %v", err), }, }, nil } @@ -351,7 +357,7 @@ func (f *FileTools) DeleteFile(_ *ai.ToolContext, input DeleteFileInput) (Result } // GetFileInfo gets file metadata. -func (f *FileTools) GetFileInfo(_ *ai.ToolContext, input GetFileInfoInput) (Result, error) { +func (f *File) GetFileInfo(_ *ai.ToolContext, input GetFileInfoInput) (Result, error) { f.logger.Info("GetFileInfo called", "path", input.Path) safePath, err := f.pathVal.Validate(input.Path) @@ -360,7 +366,7 @@ func (f *FileTools) GetFileInfo(_ *ai.ToolContext, input GetFileInfoInput) (Resu Status: StatusError, Error: &Error{ Code: ErrCodeSecurity, - Message: fmt.Sprintf("path validation failed: %v", err), + Message: fmt.Sprintf("validating path: %v", err), }, }, nil } diff --git a/internal/tools/file_integration_test.go b/internal/tools/file_integration_test.go index 34f67a1..ee0b204 100644 --- a/internal/tools/file_integration_test.go +++ b/internal/tools/file_integration_test.go @@ -9,7 +9,7 @@ import ( "github.com/koopa0/koopa/internal/security" ) -// fileTools provides test utilities for FileTools. +// fileTools provides test utilities for File. type fileTools struct { t *testing.T tempDir string @@ -21,20 +21,20 @@ func newfileTools(t *testing.T) *fileTools { // Resolve symlinks (macOS /var -> /private/var) realTempDir, err := filepath.EvalSymlinks(tempDir) if err != nil { - t.Fatalf("failed to resolve temp dir symlinks: %v", err) + t.Fatalf("resolving temp dir symlinks: %v", err) } return &fileTools{t: t, tempDir: realTempDir} } -func (h *fileTools) createFileTools() *FileTools { +func (h *fileTools) createFile() *File { h.t.Helper() pathVal, err := security.NewPath([]string{h.tempDir}) if err != nil { - h.t.Fatalf("failed to create path validator: %v", err) + h.t.Fatalf("creating path validator: %v", err) } - ft, err := NewFileTools(pathVal, testLogger()) + ft, err := NewFile(pathVal, testLogger()) if err != nil { - h.t.Fatalf("failed to create file tools: %v", err) + h.t.Fatalf("creating file tools: %v", err) } return ft } @@ -44,16 +44,12 @@ func (h *fileTools) createTestFile(name, content string) string { path := filepath.Join(h.tempDir, name) err := os.WriteFile(path, []byte(content), 0o600) if err != nil { - h.t.Fatalf("failed to create test file: %v", err) + h.t.Fatalf("creating test file: %v", err) } return path } -// ============================================================================ -// ReadFile Integration Tests -// ============================================================================ - -func TestFileTools_ReadFile_PathSecurity(t *testing.T) { +func TestFile_ReadFile_PathSecurity(t *testing.T) { t.Parallel() tests := []struct { @@ -87,11 +83,11 @@ func TestFileTools_ReadFile_PathSecurity(t *testing.T) { t.Parallel() h := newfileTools(t) - ft := h.createFileTools() + ft := h.createFile() result, err := ft.ReadFile(nil, ReadFileInput{Path: tt.path}) - // FileTools returns business errors in Result, not Go errors + // File returns business errors in Result, not Go errors if err != nil { t.Fatalf("ReadFile(%q) unexpected Go error: %v (should not return Go error)", tt.path, err) } @@ -104,18 +100,18 @@ func TestFileTools_ReadFile_PathSecurity(t *testing.T) { if got, want := result.Error.Code, tt.wantErrCode; got != want { t.Errorf("ReadFile(%q).Error.Code = %v, want %v", tt.path, got, want) } - if !strings.Contains(result.Error.Message, "path validation failed") { - t.Errorf("ReadFile(%q).Error.Message = %q, want contains %q", tt.path, result.Error.Message, "path validation failed") + if !strings.Contains(result.Error.Message, "validating path") { + t.Errorf("ReadFile(%q).Error.Message = %q, want contains %q", tt.path, result.Error.Message, "validating path") } }) } } -func TestFileTools_ReadFile_Success(t *testing.T) { +func TestFile_ReadFile_Success(t *testing.T) { t.Parallel() h := newfileTools(t) - ft := h.createFileTools() + ft := h.createFile() // Create a test file testContent := "Hello, World!" @@ -143,11 +139,11 @@ func TestFileTools_ReadFile_Success(t *testing.T) { } } -func TestFileTools_ReadFile_NotFound(t *testing.T) { +func TestFile_ReadFile_NotFound(t *testing.T) { t.Parallel() h := newfileTools(t) - ft := h.createFileTools() + ft := h.createFile() // Try to read non-existent file within allowed directory nonExistentPath := filepath.Join(h.tempDir, "does-not-exist.txt") @@ -168,11 +164,11 @@ func TestFileTools_ReadFile_NotFound(t *testing.T) { } } -func TestFileTools_ReadFile_FileTooLarge(t *testing.T) { +func TestFile_ReadFile_FileTooLarge(t *testing.T) { t.Parallel() h := newfileTools(t) - ft := h.createFileTools() + ft := h.createFile() // Create a file larger than MaxReadFileSize (10MB) largePath := filepath.Join(h.tempDir, "large.txt") @@ -209,11 +205,7 @@ func TestFileTools_ReadFile_FileTooLarge(t *testing.T) { } } -// ============================================================================ -// WriteFile Integration Tests -// ============================================================================ - -func TestFileTools_WriteFile_PathSecurity(t *testing.T) { +func TestFile_WriteFile_PathSecurity(t *testing.T) { t.Parallel() tests := []struct { @@ -247,7 +239,7 @@ func TestFileTools_WriteFile_PathSecurity(t *testing.T) { t.Parallel() h := newfileTools(t) - ft := h.createFileTools() + ft := h.createFile() result, err := ft.WriteFile(nil, WriteFileInput{ Path: tt.path, @@ -270,11 +262,11 @@ func TestFileTools_WriteFile_PathSecurity(t *testing.T) { } } -func TestFileTools_WriteFile_Success(t *testing.T) { +func TestFile_WriteFile_Success(t *testing.T) { t.Parallel() h := newfileTools(t) - ft := h.createFileTools() + ft := h.createFile() testPath := filepath.Join(h.tempDir, "new-file.txt") testContent := "New content" @@ -304,11 +296,11 @@ func TestFileTools_WriteFile_Success(t *testing.T) { } } -func TestFileTools_WriteFile_CreatesDirectories(t *testing.T) { +func TestFile_WriteFile_CreatesDirectories(t *testing.T) { t.Parallel() h := newfileTools(t) - ft := h.createFileTools() + ft := h.createFile() // Write to a nested path that doesn't exist nestedPath := filepath.Join(h.tempDir, "subdir", "nested", "file.txt") @@ -336,11 +328,7 @@ func TestFileTools_WriteFile_CreatesDirectories(t *testing.T) { } } -// ============================================================================ -// DeleteFile Integration Tests -// ============================================================================ - -func TestFileTools_DeleteFile_PathSecurity(t *testing.T) { +func TestFile_DeleteFile_PathSecurity(t *testing.T) { t.Parallel() tests := []struct { @@ -368,7 +356,7 @@ func TestFileTools_DeleteFile_PathSecurity(t *testing.T) { t.Parallel() h := newfileTools(t) - ft := h.createFileTools() + ft := h.createFile() result, err := ft.DeleteFile(nil, DeleteFileInput{Path: tt.path}) @@ -388,11 +376,11 @@ func TestFileTools_DeleteFile_PathSecurity(t *testing.T) { } } -func TestFileTools_DeleteFile_Success(t *testing.T) { +func TestFile_DeleteFile_Success(t *testing.T) { t.Parallel() h := newfileTools(t) - ft := h.createFileTools() + ft := h.createFile() // Create a file to delete testPath := h.createTestFile("to-delete.txt", "content") @@ -417,11 +405,7 @@ func TestFileTools_DeleteFile_Success(t *testing.T) { } } -// ============================================================================ -// ListFiles Integration Tests -// ============================================================================ - -func TestFileTools_ListFiles_PathSecurity(t *testing.T) { +func TestFile_ListFiles_PathSecurity(t *testing.T) { t.Parallel() tests := []struct { @@ -449,7 +433,7 @@ func TestFileTools_ListFiles_PathSecurity(t *testing.T) { t.Parallel() h := newfileTools(t) - ft := h.createFileTools() + ft := h.createFile() result, err := ft.ListFiles(nil, ListFilesInput{Path: tt.path}) @@ -469,11 +453,11 @@ func TestFileTools_ListFiles_PathSecurity(t *testing.T) { } } -func TestFileTools_ListFiles_Success(t *testing.T) { +func TestFile_ListFiles_Success(t *testing.T) { t.Parallel() h := newfileTools(t) - ft := h.createFileTools() + ft := h.createFile() // Create some test files h.createTestFile("file1.txt", "content1") @@ -516,15 +500,11 @@ func TestFileTools_ListFiles_Success(t *testing.T) { } } -// ============================================================================ -// GetFileInfo Integration Tests -// ============================================================================ - -func TestFileTools_GetFileInfo_PathSecurity(t *testing.T) { +func TestFile_GetFileInfo_PathSecurity(t *testing.T) { t.Parallel() h := newfileTools(t) - ft := h.createFileTools() + ft := h.createFile() result, err := ft.GetFileInfo(nil, GetFileInfoInput{Path: "/etc/passwd"}) @@ -542,11 +522,11 @@ func TestFileTools_GetFileInfo_PathSecurity(t *testing.T) { } } -func TestFileTools_GetFileInfo_Success(t *testing.T) { +func TestFile_GetFileInfo_Success(t *testing.T) { t.Parallel() h := newfileTools(t) - ft := h.createFileTools() + ft := h.createFile() // Create a test file testPath := h.createTestFile("info.txt", "test content") diff --git a/internal/tools/file_test.go b/internal/tools/file_test.go index 904ed3c..33ed506 100644 --- a/internal/tools/file_test.go +++ b/internal/tools/file_test.go @@ -6,44 +6,44 @@ import ( "github.com/koopa0/koopa/internal/security" ) -func TestFileTools_Constructor(t *testing.T) { +func TestFile_Constructor(t *testing.T) { t.Run("valid inputs", func(t *testing.T) { pathVal, err := security.NewPath([]string{"/tmp"}) if err != nil { - t.Fatalf("failed to create path validator: %v", err) + t.Fatalf("creating path validator: %v", err) } - ft, err := NewFileTools(pathVal, testLogger()) + ft, err := NewFile(pathVal, testLogger()) if err != nil { - t.Errorf("NewFileTools() error = %v, want nil", err) + t.Errorf("NewFile() error = %v, want nil", err) } if ft == nil { - t.Error("NewFileTools() returned nil, want non-nil") + t.Error("NewFile() returned nil, want non-nil") } }) t.Run("nil path validator", func(t *testing.T) { - ft, err := NewFileTools(nil, testLogger()) + ft, err := NewFile(nil, testLogger()) if err == nil { - t.Error("NewFileTools() error = nil, want error") + t.Error("NewFile() error = nil, want error") } if ft != nil { - t.Error("NewFileTools() returned non-nil, want nil") + t.Error("NewFile() returned non-nil, want nil") } }) t.Run("nil logger", func(t *testing.T) { pathVal, err := security.NewPath([]string{"/tmp"}) if err != nil { - t.Fatalf("failed to create path validator: %v", err) + t.Fatalf("creating path validator: %v", err) } - ft, err := NewFileTools(pathVal, nil) + ft, err := NewFile(pathVal, nil) if err == nil { - t.Error("NewFileTools() error = nil, want error") + t.Error("NewFile() error = nil, want error") } if ft != nil { - t.Error("NewFileTools() returned non-nil, want nil") + t.Error("NewFile() returned non-nil, want nil") } }) } @@ -51,27 +51,27 @@ func TestFileTools_Constructor(t *testing.T) { func TestFileToolConstants(t *testing.T) { // Verify tool name constants are correct expectedNames := map[string]string{ - "ToolReadFile": "read_file", - "ToolWriteFile": "write_file", - "ToolListFiles": "list_files", - "ToolDeleteFile": "delete_file", - "ToolGetFileInfo": "get_file_info", + "ReadFileName": "read_file", + "WriteFileName": "write_file", + "ListFilesName": "list_files", + "DeleteFileName": "delete_file", + "FileInfoName": "get_file_info", } - if ToolReadFile != expectedNames["ToolReadFile"] { - t.Errorf("ToolReadFile = %q, want %q", ToolReadFile, expectedNames["ToolReadFile"]) + if ReadFileName != expectedNames["ReadFileName"] { + t.Errorf("ReadFileName = %q, want %q", ReadFileName, expectedNames["ReadFileName"]) } - if ToolWriteFile != expectedNames["ToolWriteFile"] { - t.Errorf("ToolWriteFile = %q, want %q", ToolWriteFile, expectedNames["ToolWriteFile"]) + if WriteFileName != expectedNames["WriteFileName"] { + t.Errorf("WriteFileName = %q, want %q", WriteFileName, expectedNames["WriteFileName"]) } - if ToolListFiles != expectedNames["ToolListFiles"] { - t.Errorf("ToolListFiles = %q, want %q", ToolListFiles, expectedNames["ToolListFiles"]) + if ListFilesName != expectedNames["ListFilesName"] { + t.Errorf("ListFilesName = %q, want %q", ListFilesName, expectedNames["ListFilesName"]) } - if ToolDeleteFile != expectedNames["ToolDeleteFile"] { - t.Errorf("ToolDeleteFile = %q, want %q", ToolDeleteFile, expectedNames["ToolDeleteFile"]) + if DeleteFileName != expectedNames["DeleteFileName"] { + t.Errorf("DeleteFileName = %q, want %q", DeleteFileName, expectedNames["DeleteFileName"]) } - if ToolGetFileInfo != expectedNames["ToolGetFileInfo"] { - t.Errorf("ToolGetFileInfo = %q, want %q", ToolGetFileInfo, expectedNames["ToolGetFileInfo"]) + if FileInfoName != expectedNames["FileInfoName"] { + t.Errorf("FileInfoName = %q, want %q", FileInfoName, expectedNames["FileInfoName"]) } } @@ -82,45 +82,3 @@ func TestMaxReadFileSize(t *testing.T) { t.Errorf("MaxReadFileSize = %d, want %d (10MB)", MaxReadFileSize, expected) } } - -func TestReadFileInput(t *testing.T) { - // Test that ReadFileInput struct can be created - input := ReadFileInput{Path: "/tmp/test.txt"} - if input.Path != "/tmp/test.txt" { - t.Errorf("ReadFileInput.Path = %q, want %q", input.Path, "/tmp/test.txt") - } -} - -func TestWriteFileInput(t *testing.T) { - input := WriteFileInput{ - Path: "/tmp/test.txt", - Content: "hello world", - } - if input.Path != "/tmp/test.txt" { - t.Errorf("WriteFileInput.Path = %q, want %q", input.Path, "/tmp/test.txt") - } - if input.Content != "hello world" { - t.Errorf("WriteFileInput.Content = %q, want %q", input.Content, "hello world") - } -} - -func TestListFilesInput(t *testing.T) { - input := ListFilesInput{Path: "/tmp"} - if input.Path != "/tmp" { - t.Errorf("ListFilesInput.Path = %q, want %q", input.Path, "/tmp") - } -} - -func TestDeleteFileInput(t *testing.T) { - input := DeleteFileInput{Path: "/tmp/test.txt"} - if input.Path != "/tmp/test.txt" { - t.Errorf("DeleteFileInput.Path = %q, want %q", input.Path, "/tmp/test.txt") - } -} - -func TestGetFileInfoInput(t *testing.T) { - input := GetFileInfoInput{Path: "/tmp/test.txt"} - if input.Path != "/tmp/test.txt" { - t.Errorf("GetFileInfoInput.Path = %q, want %q", input.Path, "/tmp/test.txt") - } -} diff --git a/internal/tools/fuzz_test.go b/internal/tools/fuzz_test.go index 639167d..3fb4026 100644 --- a/internal/tools/fuzz_test.go +++ b/internal/tools/fuzz_test.go @@ -22,17 +22,17 @@ func FuzzClampTopK(f *testing.F) { if topK <= 0 { if result != defaultVal { - t.Errorf("clampTopK(%d, %d) = %d, expected defaultVal %d for zero/negative topK", + t.Errorf("clampTopK(%d, %d) = %d, want %d for zero/negative topK", topK, defaultVal, result, defaultVal) } } else if topK > 10 { if result != 10 { - t.Errorf("clampTopK(%d, %d) = %d, expected 10 for topK > 10", + t.Errorf("clampTopK(%d, %d) = %d, want 10 for topK > 10", topK, defaultVal, result) } } else { if result != topK { - t.Errorf("clampTopK(%d, %d) = %d, expected %d for valid topK", + t.Errorf("clampTopK(%d, %d) = %d, want %d for valid topK", topK, defaultVal, result, topK) } } @@ -63,10 +63,6 @@ func FuzzResultConstruction(f *testing.F) { }) } -// ============================================================================= -// Security Fuzz Tests -// ============================================================================= - // FuzzPathTraversal tests path validation never panics and handles edge cases. // The validator allows relative paths (resolved from working directory) and // absolute paths within allowed directories. @@ -227,7 +223,7 @@ func FuzzCommandInjection(f *testing.F) { f.Fuzz(func(t *testing.T, cmd, args string) { validator := security.NewCommand() argList := strings.Fields(args) - err := validator.ValidateCommand(cmd, argList) + err := validator.Validate(cmd, argList) // If validation passes, verify it's not a dangerous command if err == nil { @@ -278,7 +274,7 @@ func FuzzEnvVarBypass(f *testing.F) { f.Fuzz(func(t *testing.T, envName string) { validator := security.NewEnv() - err := validator.ValidateEnvAccess(envName) + err := validator.Validate(envName) // If validation passes, verify it's not a sensitive variable if err == nil { diff --git a/internal/tools/knowledge.go b/internal/tools/knowledge.go index 18af242..f1bce84 100644 --- a/internal/tools/knowledge.go +++ b/internal/tools/knowledge.go @@ -10,20 +10,25 @@ import ( "context" "crypto/sha256" "fmt" + "log/slog" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/plugins/postgresql" - "github.com/koopa0/koopa/internal/log" "github.com/koopa0/koopa/internal/rag" ) +// Tool name constants for knowledge operations registered with Genkit. const ( - ToolSearchHistory = "search_history" - ToolSearchDocuments = "search_documents" - ToolSearchSystemKnowledge = "search_system_knowledge" - ToolStoreKnowledge = "knowledge_store" + // SearchHistoryName is the Genkit tool name for searching conversation history. + SearchHistoryName = "search_history" + // SearchDocumentsName is the Genkit tool name for searching indexed documents. + SearchDocumentsName = "search_documents" + // SearchSystemKnowledgeName is the Genkit tool name for searching system knowledge. + SearchSystemKnowledgeName = "search_system_knowledge" + // StoreKnowledgeName is the Genkit tool name for storing new knowledge documents. + StoreKnowledgeName = "knowledge_store" ) // Default TopK values for knowledge searches. @@ -47,67 +52,73 @@ type KnowledgeStoreInput struct { Content string `json:"content" jsonschema_description:"The knowledge content to store"` } -// KnowledgeTools holds dependencies for knowledge operation handlers. -type KnowledgeTools struct { +// Knowledge holds dependencies for knowledge operation handlers. +type Knowledge struct { retriever ai.Retriever docStore *postgresql.DocStore // nil disables knowledge_store tool - logger log.Logger + logger *slog.Logger } -// NewKnowledgeTools creates a KnowledgeTools instance. +// HasDocStore reports whether the document store is available. +// Used by MCP server to conditionally register the knowledge_store tool. +func (k *Knowledge) HasDocStore() bool { + return k.docStore != nil +} + +// NewKnowledge creates a Knowledge instance. // docStore is optional: when nil, the knowledge_store tool is not registered. -func NewKnowledgeTools(retriever ai.Retriever, docStore *postgresql.DocStore, logger log.Logger) (*KnowledgeTools, error) { +func NewKnowledge(retriever ai.Retriever, docStore *postgresql.DocStore, logger *slog.Logger) (*Knowledge, error) { if retriever == nil { return nil, fmt.Errorf("retriever is required") } if logger == nil { return nil, fmt.Errorf("logger is required") } - return &KnowledgeTools{retriever: retriever, docStore: docStore, logger: logger}, nil + return &Knowledge{retriever: retriever, docStore: docStore, logger: logger}, nil } -// RegisterKnowledgeTools registers all knowledge search tools with Genkit. +// RegisterKnowledge registers all knowledge search tools with Genkit. // Tools are registered with event emission wrappers for streaming support. -func RegisterKnowledgeTools(g *genkit.Genkit, kt *KnowledgeTools) ([]ai.Tool, error) { +func RegisterKnowledge(g *genkit.Genkit, kt *Knowledge) ([]ai.Tool, error) { if g == nil { return nil, fmt.Errorf("genkit instance is required") } if kt == nil { - return nil, fmt.Errorf("KnowledgeTools is required") + return nil, fmt.Errorf("Knowledge is required") } tools := []ai.Tool{ - genkit.DefineTool(g, ToolSearchHistory, + genkit.DefineTool(g, SearchHistoryName, "Search conversation history using semantic similarity. "+ "Finds past exchanges that are conceptually related to the query. "+ "Returns: matched conversation turns with timestamps and similarity scores. "+ "Use this to: recall past discussions, find context from earlier conversations. "+ "Default topK: 3. Maximum topK: 10.", - WithEvents(ToolSearchHistory, kt.SearchHistory)), - genkit.DefineTool(g, ToolSearchDocuments, + WithEvents(SearchHistoryName, kt.SearchHistory)), + genkit.DefineTool(g, SearchDocumentsName, "Search indexed documents (PDFs, code files, notes) using semantic similarity. "+ "Finds document sections that are conceptually related to the query. "+ "Returns: document titles, content excerpts, and similarity scores. "+ "Use this to: find relevant documentation, locate code examples, research topics. "+ "Default topK: 5. Maximum topK: 10.", - WithEvents(ToolSearchDocuments, kt.SearchDocuments)), - genkit.DefineTool(g, ToolSearchSystemKnowledge, + WithEvents(SearchDocumentsName, kt.SearchDocuments)), + genkit.DefineTool(g, SearchSystemKnowledgeName, "Search system knowledge base (tool usage, commands, patterns) using semantic similarity. "+ "Finds internal system documentation and usage patterns. "+ "Returns: knowledge entries with descriptions and examples. "+ "Use this to: understand tool capabilities, find command syntax, learn system patterns. "+ "Default topK: 3. Maximum topK: 10.", - WithEvents(ToolSearchSystemKnowledge, kt.SearchSystemKnowledge)), + WithEvents(SearchSystemKnowledgeName, kt.SearchSystemKnowledge)), } // Register knowledge_store only when DocStore is available. if kt.docStore != nil { - tools = append(tools, genkit.DefineTool(g, ToolStoreKnowledge, + tools = append(tools, genkit.DefineTool(g, StoreKnowledgeName, "Store a knowledge entry for later retrieval via search_documents. "+ "Use this to save important information, notes, or learnings "+ "that the user wants to remember across sessions. "+ "Each entry gets a unique ID and is indexed for semantic search.", - WithEvents(ToolStoreKnowledge, kt.StoreKnowledge))) + WithEvents(StoreKnowledgeName, kt.StoreKnowledge))) } return tools, nil @@ -135,13 +146,17 @@ var validSourceTypes = map[string]bool{ // search performs a knowledge search with the given source type filter. // Returns error if sourceType is not in the allowed whitelist. -func (k *KnowledgeTools) search(ctx context.Context, query string, topK int, sourceType string) ([]*ai.Document, error) { +func (k *Knowledge) search(ctx context.Context, query string, topK int, sourceType string) ([]*ai.Document, error) { // Validate source type against whitelist (SQL injection prevention) if !validSourceTypes[sourceType] { return nil, fmt.Errorf("invalid source type: %q", sourceType) } - // Build WHERE clause filter for source_type (safe: sourceType is validated) + // Build WHERE clause filter for source_type. + // SECURITY: sourceType is SQL injection-safe because it's validated against + // a hardcoded whitelist (validSourceTypes). This filter is passed to the + // Genkit PostgreSQL retriever which includes it in a SQL query. + // DO NOT bypass the whitelist validation above. filter := fmt.Sprintf("source_type = '%s'", sourceType) req := &ai.RetrieverRequest{ @@ -161,7 +176,7 @@ func (k *KnowledgeTools) search(ctx context.Context, query string, topK int, sou } // SearchHistory searches conversation history using semantic similarity. -func (k *KnowledgeTools) SearchHistory(ctx *ai.ToolContext, input KnowledgeSearchInput) (Result, error) { +func (k *Knowledge) SearchHistory(ctx *ai.ToolContext, input KnowledgeSearchInput) (Result, error) { k.logger.Info("SearchHistory called", "query", input.Query, "topK", input.TopK) topK := clampTopK(input.TopK, DefaultHistoryTopK) @@ -173,7 +188,7 @@ func (k *KnowledgeTools) SearchHistory(ctx *ai.ToolContext, input KnowledgeSearc Status: StatusError, Error: &Error{ Code: ErrCodeExecution, - Message: fmt.Sprintf("history search failed: %v", err), + Message: fmt.Sprintf("searching history: %v", err), }, }, nil } @@ -190,7 +205,7 @@ func (k *KnowledgeTools) SearchHistory(ctx *ai.ToolContext, input KnowledgeSearc } // SearchDocuments searches indexed documents using semantic similarity. -func (k *KnowledgeTools) SearchDocuments(ctx *ai.ToolContext, input KnowledgeSearchInput) (Result, error) { +func (k *Knowledge) SearchDocuments(ctx *ai.ToolContext, input KnowledgeSearchInput) (Result, error) { k.logger.Info("SearchDocuments called", "query", input.Query, "topK", input.TopK) topK := clampTopK(input.TopK, DefaultDocumentsTopK) @@ -202,7 +217,7 @@ func (k *KnowledgeTools) SearchDocuments(ctx *ai.ToolContext, input KnowledgeSea Status: StatusError, Error: &Error{ Code: ErrCodeExecution, - Message: fmt.Sprintf("document search failed: %v", err), + Message: fmt.Sprintf("searching documents: %v", err), }, }, nil } @@ -219,7 +234,7 @@ func (k *KnowledgeTools) SearchDocuments(ctx *ai.ToolContext, input KnowledgeSea } // StoreKnowledge stores a new knowledge document for later retrieval. -func (k *KnowledgeTools) StoreKnowledge(ctx *ai.ToolContext, input KnowledgeStoreInput) (Result, error) { +func (k *Knowledge) StoreKnowledge(ctx *ai.ToolContext, input KnowledgeStoreInput) (Result, error) { k.logger.Info("StoreKnowledge called", "title", input.Title) if k.docStore == nil { @@ -252,6 +267,7 @@ func (k *KnowledgeTools) StoreKnowledge(ctx *ai.ToolContext, input KnowledgeStor } // Generate a deterministic document ID from the title using SHA-256. + // Changing the title creates a new document; the old entry remains. // Prefix "user:" namespaces user-created knowledge (vs "system:" for built-in). docID := fmt.Sprintf("user:%x", sha256.Sum256([]byte(input.Title))) @@ -267,7 +283,7 @@ func (k *KnowledgeTools) StoreKnowledge(ctx *ai.ToolContext, input KnowledgeStor Status: StatusError, Error: &Error{ Code: ErrCodeExecution, - Message: fmt.Sprintf("failed to store knowledge: %v", err), + Message: fmt.Sprintf("storing knowledge: %v", err), }, }, nil } @@ -283,7 +299,7 @@ func (k *KnowledgeTools) StoreKnowledge(ctx *ai.ToolContext, input KnowledgeStor } // SearchSystemKnowledge searches system knowledge base using semantic similarity. -func (k *KnowledgeTools) SearchSystemKnowledge(ctx *ai.ToolContext, input KnowledgeSearchInput) (Result, error) { +func (k *Knowledge) SearchSystemKnowledge(ctx *ai.ToolContext, input KnowledgeSearchInput) (Result, error) { k.logger.Info("SearchSystemKnowledge called", "query", input.Query, "topK", input.TopK) topK := clampTopK(input.TopK, DefaultSystemKnowledgeTopK) @@ -295,7 +311,7 @@ func (k *KnowledgeTools) SearchSystemKnowledge(ctx *ai.ToolContext, input Knowle Status: StatusError, Error: &Error{ Code: ErrCodeExecution, - Message: fmt.Sprintf("system knowledge search failed: %v", err), + Message: fmt.Sprintf("searching system knowledge: %v", err), }, }, nil } diff --git a/internal/tools/knowledge_test.go b/internal/tools/knowledge_test.go index 2e0abbe..4500edf 100644 --- a/internal/tools/knowledge_test.go +++ b/internal/tools/knowledge_test.go @@ -2,11 +2,11 @@ package tools import ( "context" + "log/slog" "testing" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core/api" - "github.com/koopa0/koopa/internal/log" ) // mockRetriever is a minimal ai.Retriever implementation for testing. @@ -25,12 +25,12 @@ func TestClampTopK(t *testing.T) { defaultVal int want int }{ - {"zero uses default", 0, 3, 3}, - {"negative uses default", -5, 5, 5}, - {"value in range unchanged", 5, 3, 5}, - {"max boundary", 10, 3, 10}, - {"exceeds max clamped to 10", 50, 3, 10}, - {"min value", 1, 3, 1}, + {name: "zero uses default", topK: 0, defaultVal: 3, want: 3}, + {name: "negative uses default", topK: -5, defaultVal: 5, want: 5}, + {name: "value in range unchanged", topK: 5, defaultVal: 3, want: 5}, + {name: "max boundary", topK: 10, defaultVal: 3, want: 10}, + {name: "exceeds max clamped to 10", topK: 50, defaultVal: 3, want: 10}, + {name: "min value", topK: 1, defaultVal: 3, want: 1}, } for _, tt := range tests { @@ -44,30 +44,30 @@ func TestClampTopK(t *testing.T) { } func TestKnowledgeToolConstants(t *testing.T) { - if ToolSearchHistory != "search_history" { - t.Errorf("ToolSearchHistory = %q, want %q", ToolSearchHistory, "search_history") + if SearchHistoryName != "search_history" { + t.Errorf("SearchHistoryName = %q, want %q", SearchHistoryName, "search_history") } - if ToolSearchDocuments != "search_documents" { - t.Errorf("ToolSearchDocuments = %q, want %q", ToolSearchDocuments, "search_documents") + if SearchDocumentsName != "search_documents" { + t.Errorf("SearchDocumentsName = %q, want %q", SearchDocumentsName, "search_documents") } - if ToolSearchSystemKnowledge != "search_system_knowledge" { - t.Errorf("ToolSearchSystemKnowledge = %q, want %q", ToolSearchSystemKnowledge, "search_system_knowledge") + if SearchSystemKnowledgeName != "search_system_knowledge" { + t.Errorf("SearchSystemKnowledgeName = %q, want %q", SearchSystemKnowledgeName, "search_system_knowledge") } - if ToolStoreKnowledge != "knowledge_store" { - t.Errorf("ToolStoreKnowledge = %q, want %q", ToolStoreKnowledge, "knowledge_store") + if StoreKnowledgeName != "knowledge_store" { + t.Errorf("StoreKnowledgeName = %q, want %q", StoreKnowledgeName, "knowledge_store") } } -func TestNewKnowledgeTools(t *testing.T) { +func TestNewKnowledge(t *testing.T) { t.Run("nil retriever returns error", func(t *testing.T) { - if _, err := NewKnowledgeTools(nil, nil, log.NewNop()); err == nil { - t.Error("expected error for nil retriever") + if _, err := NewKnowledge(nil, nil, slog.New(slog.DiscardHandler)); err == nil { + t.Error("NewKnowledge(nil, nil, logger) error = nil, want non-nil") } }) t.Run("nil logger returns error", func(t *testing.T) { - if _, err := NewKnowledgeTools(&mockRetriever{}, nil, nil); err == nil { - t.Error("expected error for nil logger") + if _, err := NewKnowledge(&mockRetriever{}, nil, nil); err == nil { + t.Error("NewKnowledge(retriever, nil, nil) error = nil, want non-nil") } }) } @@ -92,7 +92,7 @@ func TestValidSourceTypes(t *testing.T) { validTypes := []string{"conversation", "file", "system"} for _, st := range validTypes { if !validSourceTypes[st] { - t.Errorf("expected %q to be valid source type", st) + t.Errorf("validSourceTypes[%q] = false, want true", st) } } @@ -106,7 +106,7 @@ func TestValidSourceTypes(t *testing.T) { } for _, st := range invalidTypes { if validSourceTypes[st] { - t.Errorf("expected %q to be invalid source type (SQL injection risk)", st) + t.Errorf("validSourceTypes[%q] = true, want false (SQL injection risk)", st) } } } diff --git a/internal/tools/network.go b/internal/tools/network.go index 6ef3986..32f3c64 100644 --- a/internal/tools/network.go +++ b/internal/tools/network.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "log/slog" "net/http" "net/url" "strings" @@ -16,13 +17,15 @@ import ( "github.com/go-shiori/go-readability" "github.com/gocolly/colly/v2" - "github.com/koopa0/koopa/internal/log" "github.com/koopa0/koopa/internal/security" ) +// Tool name constants for network operations registered with Genkit. const ( - ToolWebSearch = "web_search" - ToolWebFetch = "web_fetch" + // WebSearchName is the Genkit tool name for performing web searches. + WebSearchName = "web_search" + // WebFetchName is the Genkit tool name for fetching web page content. + WebFetchName = "web_fetch" ) // Content limits. @@ -39,11 +42,11 @@ const ( MaxRedirects = 5 ) -// NetworkTools holds dependencies for network operation handlers. -// Use NewNetworkTools to create an instance, then either: +// Network holds dependencies for network operation handlers. +// Use NewNetwork to create an instance, then either: // - Call methods directly (for MCP) -// - Use RegisterNetworkTools to register with Genkit -type NetworkTools struct { +// - Use RegisterNetwork to register with Genkit +type Network struct { // Search configuration (SearXNG) searchBaseURL string searchClient *http.Client @@ -60,19 +63,19 @@ type NetworkTools struct { // Only settable within the tools package (unexported field). skipSSRFCheck bool - logger log.Logger + logger *slog.Logger } -// NetworkConfig holds configuration for network tools. -type NetworkConfig struct { +// NetConfig holds configuration for network tools. +type NetConfig struct { SearchBaseURL string FetchParallelism int FetchDelay time.Duration FetchTimeout time.Duration } -// NewNetworkTools creates a NetworkTools instance. -func NewNetworkTools(cfg NetworkConfig, logger log.Logger) (*NetworkTools, error) { +// NewNetwork creates a Network instance. +func NewNetwork(cfg NetConfig, logger *slog.Logger) (*Network, error) { if cfg.SearchBaseURL == "" { return nil, fmt.Errorf("search base URL is required") } @@ -91,33 +94,38 @@ func NewNetworkTools(cfg NetworkConfig, logger log.Logger) (*NetworkTools, error cfg.FetchTimeout = 30 * time.Second } - return &NetworkTools{ - searchBaseURL: strings.TrimSuffix(cfg.SearchBaseURL, "/"), - searchClient: &http.Client{Timeout: 30 * time.Second}, + urlValidator := security.NewURL() + + return &Network{ + searchBaseURL: strings.TrimSuffix(cfg.SearchBaseURL, "/"), + searchClient: &http.Client{ + Timeout: 30 * time.Second, + Transport: urlValidator.SafeTransport(), + }, fetchParallelism: cfg.FetchParallelism, fetchDelay: cfg.FetchDelay, fetchTimeout: cfg.FetchTimeout, - urlValidator: security.NewURL(), + urlValidator: urlValidator, logger: logger, }, nil } -// RegisterNetworkTools registers all network operation tools with Genkit. +// RegisterNetwork registers all network operation tools with Genkit. // Tools are registered with event emission wrappers for streaming support. -func RegisterNetworkTools(g *genkit.Genkit, nt *NetworkTools) ([]ai.Tool, error) { +func RegisterNetwork(g *genkit.Genkit, nt *Network) ([]ai.Tool, error) { if g == nil { return nil, fmt.Errorf("genkit instance is required") } if nt == nil { - return nil, fmt.Errorf("NetworkTools is required") + return nil, fmt.Errorf("Network is required") } return []ai.Tool{ - genkit.DefineTool(g, ToolWebSearch, + genkit.DefineTool(g, WebSearchName, "Search the web for information. Returns relevant results with titles, URLs, and content snippets. "+ "Use this to find current information, news, or facts from the internet.", - WithEvents(ToolWebSearch, nt.Search)), - genkit.DefineTool(g, ToolWebFetch, + WithEvents(WebSearchName, nt.Search)), + genkit.DefineTool(g, WebFetchName, "Fetch and extract content from one or more URLs (max 10). "+ "Supports HTML pages, JSON APIs, and plain text. "+ "For HTML: uses Readability algorithm to extract main content. "+ @@ -126,14 +134,10 @@ func RegisterNetworkTools(g *genkit.Genkit, nt *NetworkTools) ([]ai.Tool, error) "Supports parallel fetching with rate limiting. "+ "Returns extracted content (max 50KB per URL). "+ "Note: Does not render JavaScript - for SPA pages, content may be incomplete.", - WithEvents(ToolWebFetch, nt.Fetch)), + WithEvents(WebFetchName, nt.Fetch)), }, nil } -// ============================================================================ -// web_search: Search the web via SearXNG -// ============================================================================ - // SearchInput defines the input for web_search tool. type SearchInput struct { // Query is the search query string. Required. @@ -166,7 +170,7 @@ type SearchResult struct { } // Search performs web search via SearXNG. -func (n *NetworkTools) Search(ctx *ai.ToolContext, input SearchInput) (SearchOutput, error) { +func (n *Network) Search(ctx *ai.ToolContext, input SearchInput) (SearchOutput, error) { // Validate required fields if strings.TrimSpace(input.Query) == "" { return SearchOutput{Error: "Query is required. Please provide a search query."}, nil @@ -202,8 +206,7 @@ func (n *NetworkTools) Search(ctx *ai.ToolContext, input SearchInput) (SearchOut // Execute request resp, err := n.searchClient.Do(req) if err != nil { - n.logger.Error("search request failed", "error", err) - return SearchOutput{}, fmt.Errorf("search request failed: %w", err) + return SearchOutput{}, fmt.Errorf("executing search request: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -275,10 +278,6 @@ type searxngResponse struct { Results []searxngResult `json:"results"` } -// ============================================================================ -// web_fetch: Fetch and extract content from URLs -// ============================================================================ - // FetchInput defines the input for web_fetch tool. type FetchInput struct { // URLs is one or more URLs to fetch. Required, max 10. @@ -352,7 +351,7 @@ func (s *fetchState) markProcessed(urlStr string) bool { // Fetch retrieves and extracts content from one or more URLs. // Includes SSRF protection to block private IPs and cloud metadata endpoints. -func (n *NetworkTools) Fetch(ctx *ai.ToolContext, input FetchInput) (FetchOutput, error) { +func (n *Network) Fetch(ctx *ai.ToolContext, input FetchInput) (FetchOutput, error) { // Validate input if len(input.URLs) == 0 { return FetchOutput{Error: "At least one URL is required. Please provide URLs to fetch."}, nil @@ -403,7 +402,7 @@ func (n *NetworkTools) Fetch(ctx *ai.ToolContext, input FetchInput) (FetchOutput } // filterURLs validates and deduplicates URLs, returning safe and failed lists. -func (n *NetworkTools) filterURLs(urls []string) (safe []string, failed []FailedURL) { +func (n *Network) filterURLs(urls []string) (safe []string, failed []FailedURL) { urlSet := make(map[string]struct{}) for _, u := range urls { if _, exists := urlSet[u]; exists { @@ -424,7 +423,7 @@ func (n *NetworkTools) filterURLs(urls []string) (safe []string, failed []Failed } // createCollector creates a configured Colly collector. -func (n *NetworkTools) createCollector() *colly.Collector { +func (n *Network) createCollector() *colly.Collector { c := colly.NewCollector( colly.Async(true), colly.MaxDepth(1), @@ -462,7 +461,7 @@ func (n *NetworkTools) createCollector() *colly.Collector { } // setupCallbacks configures all Colly callbacks. -func (n *NetworkTools) setupCallbacks(c *colly.Collector, ctx *ai.ToolContext, state *fetchState, selector string) { +func (n *Network) setupCallbacks(c *colly.Collector, ctx *ai.ToolContext, state *fetchState, selector string) { c.OnRequest(func(r *colly.Request) { select { case <-ctx.Done(): @@ -472,7 +471,7 @@ func (n *NetworkTools) setupCallbacks(c *colly.Collector, ctx *ai.ToolContext, s }) c.OnResponse(func(r *colly.Response) { - n.handleNonHTMLResponse(r, state) + handleNonHTMLResponse(r, state) }) c.OnHTML("html", func(e *colly.HTMLElement) { @@ -485,7 +484,7 @@ func (n *NetworkTools) setupCallbacks(c *colly.Collector, ctx *ai.ToolContext, s } // handleNonHTMLResponse processes JSON, XML, and text responses. -func (n *NetworkTools) handleNonHTMLResponse(r *colly.Response, state *fetchState) { +func handleNonHTMLResponse(r *colly.Response, state *fetchState) { contentType := r.Headers.Get("Content-Type") if strings.Contains(contentType, "text/html") { return @@ -496,7 +495,7 @@ func (n *NetworkTools) handleNonHTMLResponse(r *colly.Response, state *fetchStat return } - title, content := n.extractNonHTMLContent(r.Body, contentType) + title, content := extractNonHTMLContent(r.Body, contentType) if len(content) > MaxContentLength { content = content[:MaxContentLength] + "\n\n[Content truncated...]" } @@ -510,7 +509,7 @@ func (n *NetworkTools) handleNonHTMLResponse(r *colly.Response, state *fetchStat } // extractNonHTMLContent extracts content from non-HTML responses. -func (*NetworkTools) extractNonHTMLContent(body []byte, contentType string) (title, content string) { +func extractNonHTMLContent(body []byte, contentType string) (title, content string) { switch { case strings.Contains(contentType, "application/json"): var jsonData any @@ -532,7 +531,7 @@ func (*NetworkTools) extractNonHTMLContent(body []byte, contentType string) (tit } // handleHTMLResponse processes HTML responses using readability. -func (n *NetworkTools) handleHTMLResponse(e *colly.HTMLElement, state *fetchState, selector string) { +func (n *Network) handleHTMLResponse(e *colly.HTMLElement, state *fetchState, selector string) { urlStr := e.Request.URL.String() if state.markProcessed(urlStr) { return @@ -558,7 +557,7 @@ func (n *NetworkTools) handleHTMLResponse(e *colly.HTMLElement, state *fetchStat } // handleError processes fetch errors. -func (n *NetworkTools) handleError(r *colly.Response, err error, state *fetchState) { +func (n *Network) handleError(r *colly.Response, err error, state *fetchState) { reason := err.Error() statusCode := 0 if r.StatusCode > 0 { @@ -578,7 +577,7 @@ func (n *NetworkTools) handleError(r *colly.Response, err error, state *fetchSta // extractWithReadability extracts content using go-readability with CSS selector fallback. // Returns (title, content). // u should be the final URL after redirects (e.Request.URL from Colly). -func (n *NetworkTools) extractWithReadability(u *url.URL, html string, e *colly.HTMLElement, selector string) (title, content string) { +func (n *Network) extractWithReadability(u *url.URL, html string, e *colly.HTMLElement, selector string) (title, content string) { // Try go-readability first (Mozilla Readability algorithm) // u is already parsed and reflects the final URL after any redirects article, err := readability.FromReader(bytes.NewReader([]byte(html)), u) @@ -590,7 +589,7 @@ func (n *NetworkTools) extractWithReadability(u *url.URL, html string, e *colly. } // Convert HTML content to plain text (remove tags) - content = n.htmlToText(article.Content) + content = htmlToText(article.Content) if content != "" { return title, content } @@ -598,12 +597,12 @@ func (n *NetworkTools) extractWithReadability(u *url.URL, html string, e *colly. // Readability failed or returned empty - fallback to CSS selector n.logger.Debug("readability fallback to selector", "url", u.String()) - return n.extractWithSelector(e, selector) + return extractWithSelector(e, selector) } // extractWithSelector extracts content using CSS selectors (fallback method). // Returns (title, content). -func (*NetworkTools) extractWithSelector(e *colly.HTMLElement, selector string) (extractedTitle, extractedContent string) { +func extractWithSelector(e *colly.HTMLElement, selector string) (extractedTitle, extractedContent string) { // Extract title extractedTitle = e.ChildText("title") if extractedTitle == "" { @@ -631,7 +630,7 @@ func (*NetworkTools) extractWithSelector(e *colly.HTMLElement, selector string) } // htmlToText converts HTML content to plain text. -func (*NetworkTools) htmlToText(html string) string { +func htmlToText(html string) string { doc, err := goquery.NewDocumentFromReader(strings.NewReader(html)) if err != nil { return "" diff --git a/internal/tools/network_integration_test.go b/internal/tools/network_integration_test.go index 642c54f..3010d0e 100644 --- a/internal/tools/network_integration_test.go +++ b/internal/tools/network_integration_test.go @@ -12,7 +12,7 @@ import ( "github.com/firebase/genkit/go/ai" ) -// networkTools provides test utilities for NetworkTools. +// networkTools provides test utilities for Network. type networkTools struct { t *testing.T } @@ -22,17 +22,17 @@ func newnetworkTools(t *testing.T) *networkTools { return &networkTools{t: t} } -func (h *networkTools) createNetworkTools(serverURL string) *NetworkTools { +func (h *networkTools) createNetwork(serverURL string) *Network { h.t.Helper() - cfg := NetworkConfig{ + cfg := NetConfig{ SearchBaseURL: serverURL, FetchParallelism: 2, FetchDelay: 10 * time.Millisecond, FetchTimeout: 5 * time.Second, } - nt, err := NewNetworkTools(cfg, testLogger()) + nt, err := NewNetwork(cfg, testLogger()) if err != nil { - h.t.Fatalf("failed to create network tools: %v", err) + h.t.Fatalf("creating network tools: %v", err) } return nt } @@ -61,11 +61,7 @@ func (*networkTools) toolContext() *ai.ToolContext { return &ai.ToolContext{Context: context.Background()} } -// ============================================================================ -// SSRF Protection Tests - Blocked Hosts -// ============================================================================ - -func TestNetworkTools_Fetch_SSRFBlockedHosts(t *testing.T) { +func TestNetwork_Fetch_SSRFBlockedHosts(t *testing.T) { t.Parallel() tests := []struct { @@ -165,7 +161,7 @@ func TestNetworkTools_Fetch_SSRFBlockedHosts(t *testing.T) { h := newnetworkTools(t) server := h.createMockServer() - nt := h.createNetworkTools(server.URL) + nt := h.createNetwork(server.URL) ctx := h.toolContext() output, err := nt.Fetch(ctx, FetchInput{URLs: []string{tt.url}}) @@ -192,11 +188,7 @@ func TestNetworkTools_Fetch_SSRFBlockedHosts(t *testing.T) { } } -// ============================================================================ -// SSRF Protection Tests - Scheme Validation -// ============================================================================ - -func TestNetworkTools_Fetch_SchemeValidation(t *testing.T) { +func TestNetwork_Fetch_SchemeValidation(t *testing.T) { t.Parallel() tests := []struct { @@ -237,7 +229,7 @@ func TestNetworkTools_Fetch_SchemeValidation(t *testing.T) { h := newnetworkTools(t) server := h.createMockServer() - nt := h.createNetworkTools(server.URL) + nt := h.createNetwork(server.URL) ctx := h.toolContext() output, err := nt.Fetch(ctx, FetchInput{URLs: []string{tt.url}}) @@ -261,24 +253,20 @@ func TestNetworkTools_Fetch_SchemeValidation(t *testing.T) { } } -// ============================================================================ -// SSRF Protection Tests - Mixed URLs -// ============================================================================ - -func TestNetworkTools_Fetch_MixedURLsFiltered(t *testing.T) { +func TestNetwork_Fetch_MixedURLsFiltered(t *testing.T) { t.Parallel() h := newnetworkTools(t) server := h.createMockServer() - // Create NetworkTools with testing mode (SSRF protection enabled but using mock server) - cfg := NetworkConfig{ + // Create Network with testing mode (SSRF protection enabled but using mock server) + cfg := NetConfig{ SearchBaseURL: server.URL, FetchParallelism: 2, FetchDelay: 10 * time.Millisecond, FetchTimeout: 5 * time.Second, } - nt := newNetworkToolsForTesting(t, cfg, testLogger()) + nt := newNetworkForTesting(t, cfg, testLogger()) ctx := h.toolContext() @@ -290,24 +278,13 @@ func TestNetworkTools_Fetch_MixedURLsFiltered(t *testing.T) { "http://169.254.169.254/", // Cloud metadata - blocked } - output, err := nt.Fetch(ctx, FetchInput{URLs: urls}) - + _, err := nt.Fetch(ctx, FetchInput{URLs: urls}) if err != nil { t.Fatalf("Fetch(mixed URLs) unexpected error: %v", err) } - - // The mock server URL should succeed in testing mode - // Private IPs should fail - // Note: skipSSRFCheck affects URL validation, so in testing mode - // even private URLs might pass. Let's verify the test setup. - t.Logf("Results: %d, Failed: %d", len(output.Results), len(output.FailedURLs)) } -// ============================================================================ -// SSRF Protection Tests - Redirect Protection -// ============================================================================ - -func TestNetworkTools_Fetch_RedirectSSRFProtection(t *testing.T) { +func TestNetwork_Fetch_RedirectSSRFProtection(t *testing.T) { t.Parallel() h := newnetworkTools(t) @@ -331,7 +308,7 @@ func TestNetworkTools_Fetch_RedirectSSRFProtection(t *testing.T) { })) t.Cleanup(func() { redirectServer.Close() }) - nt := h.createNetworkTools(redirectServer.URL) + nt := h.createNetwork(redirectServer.URL) ctx := h.toolContext() tests := []struct { @@ -393,16 +370,12 @@ func TestNetworkTools_Fetch_RedirectSSRFProtection(t *testing.T) { } } -// ============================================================================ -// Input Validation Tests -// ============================================================================ - -func TestNetworkTools_Fetch_InputValidation(t *testing.T) { +func TestNetwork_Fetch_InputValidation(t *testing.T) { t.Parallel() h := newnetworkTools(t) server := h.createMockServer() - nt := h.createNetworkTools(server.URL) + nt := h.createNetwork(server.URL) ctx := h.toolContext() t.Run("empty URL list", func(t *testing.T) { @@ -467,16 +440,12 @@ func TestNetworkTools_Fetch_InputValidation(t *testing.T) { }) } -// ============================================================================ -// Search Input Validation Tests -// ============================================================================ - -func TestNetworkTools_Search_InputValidation(t *testing.T) { +func TestNetwork_Search_InputValidation(t *testing.T) { t.Parallel() h := newnetworkTools(t) server := h.createMockServer() - nt := h.createNetworkTools(server.URL) + nt := h.createNetwork(server.URL) ctx := h.toolContext() t.Run("empty query rejected", func(t *testing.T) { @@ -514,24 +483,20 @@ func TestNetworkTools_Search_InputValidation(t *testing.T) { }) } -// ============================================================================ -// Public URL Success Test (using httptest) -// ============================================================================ - -func TestNetworkTools_Fetch_PublicURLSuccess(t *testing.T) { +func TestNetwork_Fetch_PublicURLSuccess(t *testing.T) { t.Parallel() h := newnetworkTools(t) server := h.createMockServer() // Use ForTesting to allow httptest server (which uses localhost) - cfg := NetworkConfig{ + cfg := NetConfig{ SearchBaseURL: server.URL, FetchParallelism: 2, FetchDelay: 10 * time.Millisecond, FetchTimeout: 5 * time.Second, } - nt := newNetworkToolsForTesting(t, cfg, testLogger()) + nt := newNetworkForTesting(t, cfg, testLogger()) ctx := h.toolContext() @@ -558,11 +523,7 @@ func TestNetworkTools_Fetch_PublicURLSuccess(t *testing.T) { } } -// ============================================================================ -// Concurrent Fetch Test -// ============================================================================ - -func TestNetworkTools_Fetch_Concurrent(t *testing.T) { +func TestNetwork_Fetch_Concurrent(t *testing.T) { t.Parallel() h := newnetworkTools(t) @@ -576,13 +537,13 @@ func TestNetworkTools_Fetch_Concurrent(t *testing.T) { })) t.Cleanup(func() { server.Close() }) - cfg := NetworkConfig{ + cfg := NetConfig{ SearchBaseURL: server.URL, FetchParallelism: 5, FetchDelay: 5 * time.Millisecond, FetchTimeout: 5 * time.Second, } - nt := newNetworkToolsForTesting(t, cfg, testLogger()) + nt := newNetworkForTesting(t, cfg, testLogger()) ctx := h.toolContext() diff --git a/internal/tools/network_test.go b/internal/tools/network_test.go index d495cdc..fe29113 100644 --- a/internal/tools/network_test.go +++ b/internal/tools/network_test.go @@ -5,64 +5,64 @@ import ( "time" ) -func TestNetworkTools_Constructor(t *testing.T) { +func TestNetwork_Constructor(t *testing.T) { t.Run("valid inputs", func(t *testing.T) { - cfg := NetworkConfig{ + cfg := NetConfig{ SearchBaseURL: "http://localhost:8080", FetchParallelism: 2, FetchDelay: time.Second, FetchTimeout: 30 * time.Second, } - nt, err := NewNetworkTools(cfg, testLogger()) + nt, err := NewNetwork(cfg, testLogger()) if err != nil { - t.Errorf("NewNetworkTools() error = %v, want nil", err) + t.Errorf("NewNetwork() error = %v, want nil", err) } if nt == nil { - t.Error("NewNetworkTools() returned nil, want non-nil") + t.Error("NewNetwork() returned nil, want non-nil") } }) t.Run("empty search URL", func(t *testing.T) { - cfg := NetworkConfig{ + cfg := NetConfig{ SearchBaseURL: "", } - nt, err := NewNetworkTools(cfg, testLogger()) + nt, err := NewNetwork(cfg, testLogger()) if err == nil { - t.Error("NewNetworkTools() error = nil, want error") + t.Error("NewNetwork() error = nil, want error") } if nt != nil { - t.Error("NewNetworkTools() returned non-nil, want nil") + t.Error("NewNetwork() returned non-nil, want nil") } }) t.Run("nil logger", func(t *testing.T) { - cfg := NetworkConfig{ + cfg := NetConfig{ SearchBaseURL: "http://localhost:8080", } - nt, err := NewNetworkTools(cfg, nil) + nt, err := NewNetwork(cfg, nil) if err == nil { - t.Error("NewNetworkTools() error = nil, want error") + t.Error("NewNetwork() error = nil, want error") } if nt != nil { - t.Error("NewNetworkTools() returned non-nil, want nil") + t.Error("NewNetwork() returned non-nil, want nil") } }) t.Run("defaults applied", func(t *testing.T) { - cfg := NetworkConfig{ + cfg := NetConfig{ SearchBaseURL: "http://localhost:8080", // Leave other fields as zero values } - nt, err := NewNetworkTools(cfg, testLogger()) + nt, err := NewNetwork(cfg, testLogger()) if err != nil { - t.Errorf("NewNetworkTools() error = %v, want nil", err) + t.Errorf("NewNetwork() error = %v, want nil", err) } if nt == nil { - t.Fatal("NewNetworkTools() returned nil") + t.Fatal("NewNetwork() returned nil") } // Verify defaults were applied (internal fields not accessible, but no error means success) }) @@ -70,19 +70,19 @@ func TestNetworkTools_Constructor(t *testing.T) { func TestNetworkToolConstants(t *testing.T) { expectedNames := map[string]string{ - "ToolWebSearch": "web_search", - "ToolWebFetch": "web_fetch", + "WebSearchName": "web_search", + "WebFetchName": "web_fetch", } - if ToolWebSearch != expectedNames["ToolWebSearch"] { - t.Errorf("ToolWebSearch = %q, want %q", ToolWebSearch, expectedNames["ToolWebSearch"]) + if WebSearchName != expectedNames["WebSearchName"] { + t.Errorf("WebSearchName = %q, want %q", WebSearchName, expectedNames["WebSearchName"]) } - if ToolWebFetch != expectedNames["ToolWebFetch"] { - t.Errorf("ToolWebFetch = %q, want %q", ToolWebFetch, expectedNames["ToolWebFetch"]) + if WebFetchName != expectedNames["WebFetchName"] { + t.Errorf("WebFetchName = %q, want %q", WebFetchName, expectedNames["WebFetchName"]) } } -func TestNetworkConfigConstants(t *testing.T) { +func TestNetConfigConstants(t *testing.T) { // Verify content limits if MaxURLsPerRequest != 10 { t.Errorf("MaxURLsPerRequest = %d, want 10", MaxURLsPerRequest) @@ -97,137 +97,3 @@ func TestNetworkConfigConstants(t *testing.T) { t.Errorf("DefaultSearchResults = %d, want 10", DefaultSearchResults) } } - -func TestSearchInput(t *testing.T) { - input := SearchInput{ - Query: "test query", - Categories: []string{"general", "news"}, - Language: "en", - MaxResults: 20, - } - if input.Query != "test query" { - t.Errorf("SearchInput.Query = %q, want %q", input.Query, "test query") - } - if len(input.Categories) != 2 { - t.Errorf("SearchInput.Categories length = %d, want 2", len(input.Categories)) - } - if input.Language != "en" { - t.Errorf("SearchInput.Language = %q, want %q", input.Language, "en") - } - if input.MaxResults != 20 { - t.Errorf("SearchInput.MaxResults = %d, want 20", input.MaxResults) - } -} - -func TestFetchInput(t *testing.T) { - input := FetchInput{ - URLs: []string{"https://example.com", "https://test.com"}, - Selector: "article", - } - if len(input.URLs) != 2 { - t.Errorf("FetchInput.URLs length = %d, want 2", len(input.URLs)) - } - if input.Selector != "article" { - t.Errorf("FetchInput.Selector = %q, want %q", input.Selector, "article") - } -} - -func TestSearchOutput(t *testing.T) { - output := SearchOutput{ - Results: []SearchResult{ - { - Title: "Test", - URL: "https://example.com", - Content: "Test content", - Engine: "google", - }, - }, - Query: "test", - } - if len(output.Results) != 1 { - t.Errorf("SearchOutput.Results length = %d, want 1", len(output.Results)) - } - if output.Query != "test" { - t.Errorf("SearchOutput.Query = %q, want %q", output.Query, "test") - } -} - -func TestFetchOutput(t *testing.T) { - output := FetchOutput{ - Results: []FetchResult{ - { - URL: "https://example.com", - Title: "Example", - Content: "Content", - ContentType: "text/html", - }, - }, - FailedURLs: []FailedURL{ - { - URL: "https://failed.com", - Reason: "connection refused", - StatusCode: 503, - }, - }, - } - if len(output.Results) != 1 { - t.Errorf("FetchOutput.Results length = %d, want 1", len(output.Results)) - } - if len(output.FailedURLs) != 1 { - t.Errorf("FetchOutput.FailedURLs length = %d, want 1", len(output.FailedURLs)) - } -} - -func TestSearchResult(t *testing.T) { - result := SearchResult{ - Title: "Test Title", - URL: "https://example.com", - Content: "Test content", - Engine: "google", - PublishedAt: "2024-01-01", - } - if result.Title != "Test Title" { - t.Errorf("SearchResult.Title = %q, want %q", result.Title, "Test Title") - } - if result.URL != "https://example.com" { - t.Errorf("SearchResult.URL = %q, want %q", result.URL, "https://example.com") - } - if result.Engine != "google" { - t.Errorf("SearchResult.Engine = %q, want %q", result.Engine, "google") - } - if result.PublishedAt != "2024-01-01" { - t.Errorf("SearchResult.PublishedAt = %q, want %q", result.PublishedAt, "2024-01-01") - } -} - -func TestFetchResult(t *testing.T) { - result := FetchResult{ - URL: "https://example.com", - Title: "Example", - Content: "Content", - ContentType: "text/html", - } - if result.URL != "https://example.com" { - t.Errorf("FetchResult.URL = %q, want %q", result.URL, "https://example.com") - } - if result.ContentType != "text/html" { - t.Errorf("FetchResult.ContentType = %q, want %q", result.ContentType, "text/html") - } -} - -func TestFailedURL(t *testing.T) { - failed := FailedURL{ - URL: "https://failed.com", - Reason: "timeout", - StatusCode: 504, - } - if failed.URL != "https://failed.com" { - t.Errorf("FailedURL.URL = %q, want %q", failed.URL, "https://failed.com") - } - if failed.Reason != "timeout" { - t.Errorf("FailedURL.Reason = %q, want %q", failed.Reason, "timeout") - } - if failed.StatusCode != 504 { - t.Errorf("FailedURL.StatusCode = %d, want 504", failed.StatusCode) - } -} diff --git a/internal/tools/register_test.go b/internal/tools/register_test.go index af4982e..abb2bb7 100644 --- a/internal/tools/register_test.go +++ b/internal/tools/register_test.go @@ -17,11 +17,7 @@ func setupTestGenkit(t *testing.T) *genkit.Genkit { return genkit.Init(context.Background()) } -// ============================================================================ -// NewFileTools Tests -// ============================================================================ - -func TestNewFileTools(t *testing.T) { +func TestNewFile(t *testing.T) { t.Parallel() t.Run("successful creation", func(t *testing.T) { @@ -31,27 +27,27 @@ func TestNewFileTools(t *testing.T) { t.Fatalf("NewPath() unexpected error: %v", err) } - ft, err := NewFileTools(pathVal, testLogger()) + ft, err := NewFile(pathVal, testLogger()) if err != nil { - t.Fatalf("NewFileTools() unexpected error: %v", err) + t.Fatalf("NewFile() unexpected error: %v", err) } if ft == nil { - t.Fatal("NewFileTools() = nil, want non-nil") + t.Fatal("NewFile() = nil, want non-nil") } }) t.Run("nil path validator", func(t *testing.T) { t.Parallel() - ft, err := NewFileTools(nil, testLogger()) + ft, err := NewFile(nil, testLogger()) if err == nil { - t.Fatal("NewFileTools(nil, logger) expected error, got nil") + t.Fatal("NewFile(nil, logger) expected error, got nil") } if ft != nil { - t.Errorf("NewFileTools(nil, logger) = %v, want nil", ft) + t.Errorf("NewFile(nil, logger) = %v, want nil", ft) } if !strings.Contains(err.Error(), "path validator is required") { - t.Errorf("NewFileTools(nil, logger) error = %q, want contains %q", err.Error(), "path validator is required") + t.Errorf("NewFile(nil, logger) error = %q, want contains %q", err.Error(), "path validator is required") } }) @@ -62,24 +58,20 @@ func TestNewFileTools(t *testing.T) { t.Fatalf("NewPath() unexpected error: %v", err) } - ft, err := NewFileTools(pathVal, nil) + ft, err := NewFile(pathVal, nil) if err == nil { - t.Fatal("NewFileTools(pathVal, nil) expected error, got nil") + t.Fatal("NewFile(pathVal, nil) expected error, got nil") } if ft != nil { - t.Errorf("NewFileTools(pathVal, nil) = %v, want nil", ft) + t.Errorf("NewFile(pathVal, nil) = %v, want nil", ft) } if !strings.Contains(err.Error(), "logger is required") { - t.Errorf("NewFileTools(pathVal, nil) error = %q, want contains %q", err.Error(), "logger is required") + t.Errorf("NewFile(pathVal, nil) error = %q, want contains %q", err.Error(), "logger is required") } }) } -// ============================================================================ -// RegisterFileTools Tests -// ============================================================================ - -func TestRegisterFileTools(t *testing.T) { +func TestRegisterFile(t *testing.T) { t.Parallel() t.Run("successful registration", func(t *testing.T) { @@ -90,24 +82,24 @@ func TestRegisterFileTools(t *testing.T) { t.Fatalf("NewPath() unexpected error: %v", err) } - ft, err := NewFileTools(pathVal, testLogger()) + ft, err := NewFile(pathVal, testLogger()) if err != nil { - t.Fatalf("NewFileTools() unexpected error: %v", err) + t.Fatalf("NewFile() unexpected error: %v", err) } - tools, err := RegisterFileTools(g, ft) + tools, err := RegisterFile(g, ft) if err != nil { - t.Fatalf("RegisterFileTools() unexpected error: %v", err) + t.Fatalf("RegisterFile() unexpected error: %v", err) } if got, want := len(tools), 5; got != want { - t.Errorf("RegisterFileTools() tool count = %d, want %d (should register 5 file tools)", got, want) + t.Errorf("RegisterFile() tool count = %d, want %d (should register 5 file tools)", got, want) } // Verify tool names - expectedNames := []string{ToolReadFile, ToolWriteFile, ToolListFiles, ToolDeleteFile, ToolGetFileInfo} + expectedNames := []string{ReadFileName, WriteFileName, ListFilesName, DeleteFileName, FileInfoName} actualNames := extractToolNames(tools) if !slicesEqual(expectedNames, actualNames) { - t.Errorf("RegisterFileTools() tool names = %v, want %v", actualNames, expectedNames) + t.Errorf("RegisterFile() tool names = %v, want %v", actualNames, expectedNames) } }) @@ -118,45 +110,41 @@ func TestRegisterFileTools(t *testing.T) { t.Fatalf("NewPath() unexpected error: %v", err) } - ft, err := NewFileTools(pathVal, testLogger()) + ft, err := NewFile(pathVal, testLogger()) if err != nil { - t.Fatalf("NewFileTools() unexpected error: %v", err) + t.Fatalf("NewFile() unexpected error: %v", err) } - tools, err := RegisterFileTools(nil, ft) + tools, err := RegisterFile(nil, ft) if err == nil { - t.Fatal("RegisterFileTools(nil, ft) expected error, got nil") + t.Fatal("RegisterFile(nil, ft) expected error, got nil") } if tools != nil { - t.Errorf("RegisterFileTools(nil, ft) = %v, want nil", tools) + t.Errorf("RegisterFile(nil, ft) = %v, want nil", tools) } if !strings.Contains(err.Error(), "genkit instance is required") { - t.Errorf("RegisterFileTools(nil, ft) error = %q, want contains %q", err.Error(), "genkit instance is required") + t.Errorf("RegisterFile(nil, ft) error = %q, want contains %q", err.Error(), "genkit instance is required") } }) - t.Run("nil FileTools", func(t *testing.T) { + t.Run("nil File", func(t *testing.T) { t.Parallel() g := setupTestGenkit(t) - tools, err := RegisterFileTools(g, nil) + tools, err := RegisterFile(g, nil) if err == nil { - t.Fatal("RegisterFileTools(g, nil) expected error, got nil") + t.Fatal("RegisterFile(g, nil) expected error, got nil") } if tools != nil { - t.Errorf("RegisterFileTools(g, nil) = %v, want nil", tools) + t.Errorf("RegisterFile(g, nil) = %v, want nil", tools) } - if !strings.Contains(err.Error(), "FileTools is required") { - t.Errorf("RegisterFileTools(g, nil) error = %q, want contains %q", err.Error(), "FileTools is required") + if !strings.Contains(err.Error(), "File is required") { + t.Errorf("RegisterFile(g, nil) error = %q, want contains %q", err.Error(), "File is required") } }) } -// ============================================================================ -// NewSystemTools Tests -// ============================================================================ - -func TestNewSystemTools(t *testing.T) { +func TestNewSystem(t *testing.T) { t.Parallel() t.Run("successful creation", func(t *testing.T) { @@ -164,12 +152,12 @@ func TestNewSystemTools(t *testing.T) { cmdVal := security.NewCommand() envVal := security.NewEnv() - st, err := NewSystemTools(cmdVal, envVal, testLogger()) + st, err := NewSystem(cmdVal, envVal, testLogger()) if err != nil { - t.Fatalf("NewSystemTools() unexpected error: %v", err) + t.Fatalf("NewSystem() unexpected error: %v", err) } if st == nil { - t.Fatal("NewSystemTools() = nil, want non-nil") + t.Fatal("NewSystem() = nil, want non-nil") } }) @@ -177,15 +165,15 @@ func TestNewSystemTools(t *testing.T) { t.Parallel() envVal := security.NewEnv() - st, err := NewSystemTools(nil, envVal, testLogger()) + st, err := NewSystem(nil, envVal, testLogger()) if err == nil { - t.Fatal("NewSystemTools(nil, envVal, logger) expected error, got nil") + t.Fatal("NewSystem(nil, envVal, logger) expected error, got nil") } if st != nil { - t.Errorf("NewSystemTools(nil, envVal, logger) = %v, want nil", st) + t.Errorf("NewSystem(nil, envVal, logger) = %v, want nil", st) } if !strings.Contains(err.Error(), "command validator is required") { - t.Errorf("NewSystemTools(nil, envVal, logger) error = %q, want contains %q", err.Error(), "command validator is required") + t.Errorf("NewSystem(nil, envVal, logger) error = %q, want contains %q", err.Error(), "command validator is required") } }) @@ -193,15 +181,15 @@ func TestNewSystemTools(t *testing.T) { t.Parallel() cmdVal := security.NewCommand() - st, err := NewSystemTools(cmdVal, nil, testLogger()) + st, err := NewSystem(cmdVal, nil, testLogger()) if err == nil { - t.Fatal("NewSystemTools(cmdVal, nil, logger) expected error, got nil") + t.Fatal("NewSystem(cmdVal, nil, logger) expected error, got nil") } if st != nil { - t.Errorf("NewSystemTools(cmdVal, nil, logger) = %v, want nil", st) + t.Errorf("NewSystem(cmdVal, nil, logger) = %v, want nil", st) } if !strings.Contains(err.Error(), "env validator is required") { - t.Errorf("NewSystemTools(cmdVal, nil, logger) error = %q, want contains %q", err.Error(), "env validator is required") + t.Errorf("NewSystem(cmdVal, nil, logger) error = %q, want contains %q", err.Error(), "env validator is required") } }) @@ -210,24 +198,20 @@ func TestNewSystemTools(t *testing.T) { cmdVal := security.NewCommand() envVal := security.NewEnv() - st, err := NewSystemTools(cmdVal, envVal, nil) + st, err := NewSystem(cmdVal, envVal, nil) if err == nil { - t.Fatal("NewSystemTools(cmdVal, envVal, nil) expected error, got nil") + t.Fatal("NewSystem(cmdVal, envVal, nil) expected error, got nil") } if st != nil { - t.Errorf("NewSystemTools(cmdVal, envVal, nil) = %v, want nil", st) + t.Errorf("NewSystem(cmdVal, envVal, nil) = %v, want nil", st) } if !strings.Contains(err.Error(), "logger is required") { - t.Errorf("NewSystemTools(cmdVal, envVal, nil) error = %q, want contains %q", err.Error(), "logger is required") + t.Errorf("NewSystem(cmdVal, envVal, nil) error = %q, want contains %q", err.Error(), "logger is required") } }) } -// ============================================================================ -// RegisterSystemTools Tests -// ============================================================================ - -func TestRegisterSystemTools(t *testing.T) { +func TestRegisterSystem(t *testing.T) { t.Parallel() t.Run("successful registration", func(t *testing.T) { @@ -236,24 +220,24 @@ func TestRegisterSystemTools(t *testing.T) { cmdVal := security.NewCommand() envVal := security.NewEnv() - st, err := NewSystemTools(cmdVal, envVal, testLogger()) + st, err := NewSystem(cmdVal, envVal, testLogger()) if err != nil { - t.Fatalf("NewSystemTools() unexpected error: %v", err) + t.Fatalf("NewSystem() unexpected error: %v", err) } - tools, err := RegisterSystemTools(g, st) + tools, err := RegisterSystem(g, st) if err != nil { - t.Fatalf("RegisterSystemTools() unexpected error: %v", err) + t.Fatalf("RegisterSystem() unexpected error: %v", err) } if got, want := len(tools), 3; got != want { - t.Errorf("RegisterSystemTools() tool count = %d, want %d (should register 3 system tools)", got, want) + t.Errorf("RegisterSystem() tool count = %d, want %d (should register 3 system tools)", got, want) } // Verify tool names - expectedNames := []string{ToolCurrentTime, ToolExecuteCommand, ToolGetEnv} + expectedNames := []string{CurrentTimeName, ExecuteCommandName, GetEnvName} actualNames := extractToolNames(tools) if !slicesEqual(expectedNames, actualNames) { - t.Errorf("RegisterSystemTools() tool names = %v, want %v", actualNames, expectedNames) + t.Errorf("RegisterSystem() tool names = %v, want %v", actualNames, expectedNames) } }) @@ -262,235 +246,219 @@ func TestRegisterSystemTools(t *testing.T) { cmdVal := security.NewCommand() envVal := security.NewEnv() - st, err := NewSystemTools(cmdVal, envVal, testLogger()) + st, err := NewSystem(cmdVal, envVal, testLogger()) if err != nil { - t.Fatalf("NewSystemTools() unexpected error: %v", err) + t.Fatalf("NewSystem() unexpected error: %v", err) } - tools, err := RegisterSystemTools(nil, st) + tools, err := RegisterSystem(nil, st) if err == nil { - t.Fatal("RegisterSystemTools(nil, st) expected error, got nil") + t.Fatal("RegisterSystem(nil, st) expected error, got nil") } if tools != nil { - t.Errorf("RegisterSystemTools(nil, st) = %v, want nil", tools) + t.Errorf("RegisterSystem(nil, st) = %v, want nil", tools) } if !strings.Contains(err.Error(), "genkit instance is required") { - t.Errorf("RegisterSystemTools(nil, st) error = %q, want contains %q", err.Error(), "genkit instance is required") + t.Errorf("RegisterSystem(nil, st) error = %q, want contains %q", err.Error(), "genkit instance is required") } }) - t.Run("nil SystemTools", func(t *testing.T) { + t.Run("nil System", func(t *testing.T) { t.Parallel() g := setupTestGenkit(t) - tools, err := RegisterSystemTools(g, nil) + tools, err := RegisterSystem(g, nil) if err == nil { - t.Fatal("RegisterSystemTools(g, nil) expected error, got nil") + t.Fatal("RegisterSystem(g, nil) expected error, got nil") } if tools != nil { - t.Errorf("RegisterSystemTools(g, nil) = %v, want nil", tools) + t.Errorf("RegisterSystem(g, nil) = %v, want nil", tools) } - if !strings.Contains(err.Error(), "SystemTools is required") { - t.Errorf("RegisterSystemTools(g, nil) error = %q, want contains %q", err.Error(), "SystemTools is required") + if !strings.Contains(err.Error(), "System is required") { + t.Errorf("RegisterSystem(g, nil) error = %q, want contains %q", err.Error(), "System is required") } }) } -// ============================================================================ -// NewNetworkTools Tests -// ============================================================================ - -func TestNewNetworkTools(t *testing.T) { +func TestNewNetwork(t *testing.T) { t.Parallel() t.Run("successful creation", func(t *testing.T) { t.Parallel() - cfg := NetworkConfig{ + cfg := NetConfig{ SearchBaseURL: "http://localhost:8080", } - nt, err := NewNetworkTools(cfg, testLogger()) + nt, err := NewNetwork(cfg, testLogger()) if err != nil { - t.Fatalf("NewNetworkTools() unexpected error: %v", err) + t.Fatalf("NewNetwork() unexpected error: %v", err) } if nt == nil { - t.Fatal("NewNetworkTools() = nil, want non-nil") + t.Fatal("NewNetwork() = nil, want non-nil") } }) t.Run("empty search base URL", func(t *testing.T) { t.Parallel() - cfg := NetworkConfig{} + cfg := NetConfig{} - nt, err := NewNetworkTools(cfg, testLogger()) + nt, err := NewNetwork(cfg, testLogger()) if err == nil { - t.Fatal("NewNetworkTools(empty config) expected error, got nil") + t.Fatal("NewNetwork(empty config) expected error, got nil") } if nt != nil { - t.Errorf("NewNetworkTools(empty config) = %v, want nil", nt) + t.Errorf("NewNetwork(empty config) = %v, want nil", nt) } if !strings.Contains(err.Error(), "search base URL is required") { - t.Errorf("NewNetworkTools(empty config) error = %q, want contains %q", err.Error(), "search base URL is required") + t.Errorf("NewNetwork(empty config) error = %q, want contains %q", err.Error(), "search base URL is required") } }) t.Run("nil logger", func(t *testing.T) { t.Parallel() - cfg := NetworkConfig{ + cfg := NetConfig{ SearchBaseURL: "http://localhost:8080", } - nt, err := NewNetworkTools(cfg, nil) + nt, err := NewNetwork(cfg, nil) if err == nil { - t.Fatal("NewNetworkTools(cfg, nil) expected error, got nil") + t.Fatal("NewNetwork(cfg, nil) expected error, got nil") } if nt != nil { - t.Errorf("NewNetworkTools(cfg, nil) = %v, want nil", nt) + t.Errorf("NewNetwork(cfg, nil) = %v, want nil", nt) } if !strings.Contains(err.Error(), "logger is required") { - t.Errorf("NewNetworkTools(cfg, nil) error = %q, want contains %q", err.Error(), "logger is required") + t.Errorf("NewNetwork(cfg, nil) error = %q, want contains %q", err.Error(), "logger is required") } }) t.Run("default values applied", func(t *testing.T) { t.Parallel() - cfg := NetworkConfig{ + cfg := NetConfig{ SearchBaseURL: "http://localhost:8080/", // Leave other values at zero - should get defaults } - nt, err := NewNetworkTools(cfg, testLogger()) + nt, err := NewNetwork(cfg, testLogger()) if err != nil { - t.Fatalf("NewNetworkTools() unexpected error: %v", err) + t.Fatalf("NewNetwork() unexpected error: %v", err) } if nt == nil { - t.Fatal("NewNetworkTools() = nil, want non-nil") + t.Fatal("NewNetwork() = nil, want non-nil") } }) } -// ============================================================================ -// RegisterNetworkTools Tests -// ============================================================================ - -func TestRegisterNetworkTools(t *testing.T) { +func TestRegisterNetwork(t *testing.T) { t.Parallel() t.Run("successful registration", func(t *testing.T) { t.Parallel() g := setupTestGenkit(t) - cfg := NetworkConfig{ + cfg := NetConfig{ SearchBaseURL: "http://localhost:8080", } - nt, err := NewNetworkTools(cfg, testLogger()) + nt, err := NewNetwork(cfg, testLogger()) if err != nil { - t.Fatalf("NewNetworkTools() unexpected error: %v", err) + t.Fatalf("NewNetwork() unexpected error: %v", err) } - tools, err := RegisterNetworkTools(g, nt) + tools, err := RegisterNetwork(g, nt) if err != nil { - t.Fatalf("RegisterNetworkTools() unexpected error: %v", err) + t.Fatalf("RegisterNetwork() unexpected error: %v", err) } if got, want := len(tools), 2; got != want { - t.Errorf("RegisterNetworkTools() tool count = %d, want %d (should register 2 network tools)", got, want) + t.Errorf("RegisterNetwork() tool count = %d, want %d (should register 2 network tools)", got, want) } // Verify tool names - expectedNames := []string{ToolWebSearch, ToolWebFetch} + expectedNames := []string{WebSearchName, WebFetchName} actualNames := extractToolNames(tools) if !slicesEqual(expectedNames, actualNames) { - t.Errorf("RegisterNetworkTools() tool names = %v, want %v", actualNames, expectedNames) + t.Errorf("RegisterNetwork() tool names = %v, want %v", actualNames, expectedNames) } }) t.Run("nil genkit", func(t *testing.T) { t.Parallel() - cfg := NetworkConfig{ + cfg := NetConfig{ SearchBaseURL: "http://localhost:8080", } - nt, err := NewNetworkTools(cfg, testLogger()) + nt, err := NewNetwork(cfg, testLogger()) if err != nil { - t.Fatalf("NewNetworkTools() unexpected error: %v", err) + t.Fatalf("NewNetwork() unexpected error: %v", err) } - tools, err := RegisterNetworkTools(nil, nt) + tools, err := RegisterNetwork(nil, nt) if err == nil { - t.Fatal("RegisterNetworkTools(nil, nt) expected error, got nil") + t.Fatal("RegisterNetwork(nil, nt) expected error, got nil") } if tools != nil { - t.Errorf("RegisterNetworkTools(nil, nt) = %v, want nil", tools) + t.Errorf("RegisterNetwork(nil, nt) = %v, want nil", tools) } if !strings.Contains(err.Error(), "genkit instance is required") { - t.Errorf("RegisterNetworkTools(nil, nt) error = %q, want contains %q", err.Error(), "genkit instance is required") + t.Errorf("RegisterNetwork(nil, nt) error = %q, want contains %q", err.Error(), "genkit instance is required") } }) - t.Run("nil NetworkTools", func(t *testing.T) { + t.Run("nil Network", func(t *testing.T) { t.Parallel() g := setupTestGenkit(t) - tools, err := RegisterNetworkTools(g, nil) + tools, err := RegisterNetwork(g, nil) if err == nil { - t.Fatal("RegisterNetworkTools(g, nil) expected error, got nil") + t.Fatal("RegisterNetwork(g, nil) expected error, got nil") } if tools != nil { - t.Errorf("RegisterNetworkTools(g, nil) = %v, want nil", tools) + t.Errorf("RegisterNetwork(g, nil) = %v, want nil", tools) } - if !strings.Contains(err.Error(), "NetworkTools is required") { - t.Errorf("RegisterNetworkTools(g, nil) error = %q, want contains %q", err.Error(), "NetworkTools is required") + if !strings.Contains(err.Error(), "Network is required") { + t.Errorf("RegisterNetwork(g, nil) error = %q, want contains %q", err.Error(), "Network is required") } }) } -// ============================================================================ -// RegisterKnowledgeTools Tests -// ============================================================================ +// Note: Knowledge validation (nil store, nil logger) is tested in +// TestNewKnowledge in knowledge_test.go. These tests verify +// RegisterKnowledge parameter validation only. -// Note: KnowledgeTools validation (nil store, nil logger) is tested in -// TestNewKnowledgeTools in knowledge_test.go. These tests verify -// RegisterKnowledgeTools parameter validation only. - -func TestRegisterKnowledgeTools(t *testing.T) { +func TestRegisterKnowledge(t *testing.T) { t.Parallel() t.Run("nil genkit", func(t *testing.T) { t.Parallel() - tools, err := RegisterKnowledgeTools(nil, &KnowledgeTools{}) + tools, err := RegisterKnowledge(nil, &Knowledge{}) if err == nil { - t.Fatal("RegisterKnowledgeTools(nil, kt) expected error, got nil") + t.Fatal("RegisterKnowledge(nil, kt) expected error, got nil") } if tools != nil { - t.Errorf("RegisterKnowledgeTools(nil, kt) = %v, want nil", tools) + t.Errorf("RegisterKnowledge(nil, kt) = %v, want nil", tools) } if !strings.Contains(err.Error(), "genkit instance is required") { - t.Errorf("RegisterKnowledgeTools(nil, kt) error = %q, want contains %q", err.Error(), "genkit instance is required") + t.Errorf("RegisterKnowledge(nil, kt) error = %q, want contains %q", err.Error(), "genkit instance is required") } }) - t.Run("nil KnowledgeTools", func(t *testing.T) { + t.Run("nil Knowledge", func(t *testing.T) { t.Parallel() g := setupTestGenkit(t) - tools, err := RegisterKnowledgeTools(g, nil) + tools, err := RegisterKnowledge(g, nil) if err == nil { - t.Fatal("RegisterKnowledgeTools(g, nil) expected error, got nil") + t.Fatal("RegisterKnowledge(g, nil) expected error, got nil") } if tools != nil { - t.Errorf("RegisterKnowledgeTools(g, nil) = %v, want nil", tools) + t.Errorf("RegisterKnowledge(g, nil) = %v, want nil", tools) } - if !strings.Contains(err.Error(), "KnowledgeTools is required") { - t.Errorf("RegisterKnowledgeTools(g, nil) error = %q, want contains %q", err.Error(), "KnowledgeTools is required") + if !strings.Contains(err.Error(), "Knowledge is required") { + t.Errorf("RegisterKnowledge(g, nil) error = %q, want contains %q", err.Error(), "Knowledge is required") } }) } -// ============================================================================ -// Helper Functions -// ============================================================================ - func extractToolNames(tools []ai.Tool) []string { names := make([]string, len(tools)) for i, tool := range tools { diff --git a/internal/tools/setup_test.go b/internal/tools/setup_test.go index b710699..491c48f 100644 --- a/internal/tools/setup_test.go +++ b/internal/tools/setup_test.go @@ -1,25 +1,25 @@ package tools import ( + "log/slog" "testing" - - "github.com/koopa0/koopa/internal/log" ) // testLogger returns a no-op logger for testing. -func testLogger() log.Logger { - return log.NewNop() +func testLogger() *slog.Logger { + return slog.New(slog.DiscardHandler) } -// newNetworkToolsForTesting creates a NetworkTools instance with SSRF protection +// newNetworkForTesting creates a Network instance with SSRF protection // disabled. This allows tests to use httptest.Server (which binds to localhost). // Only accessible within tools package tests (unexported). -func newNetworkToolsForTesting(tb testing.TB, cfg NetworkConfig, logger log.Logger) *NetworkTools { +func newNetworkForTesting(tb testing.TB, cfg NetConfig, logger *slog.Logger) *Network { tb.Helper() - nt, err := NewNetworkTools(cfg, logger) + nt, err := NewNetwork(cfg, logger) if err != nil { - tb.Fatalf("NewNetworkTools() unexpected error: %v", err) + tb.Fatalf("NewNetwork() unexpected error: %v", err) } nt.skipSSRFCheck = true + nt.searchClient.Transport = nil // allow localhost in tests return nt } diff --git a/internal/tools/system.go b/internal/tools/system.go index 3dca4ff..9f6dbb5 100644 --- a/internal/tools/system.go +++ b/internal/tools/system.go @@ -3,6 +3,7 @@ package tools import ( "context" "fmt" + "log/slog" "os" "os/exec" "strings" @@ -11,14 +12,17 @@ import ( "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/genkit" - "github.com/koopa0/koopa/internal/log" "github.com/koopa0/koopa/internal/security" ) +// Tool name constants for system operations registered with Genkit. const ( - ToolCurrentTime = "current_time" - ToolExecuteCommand = "execute_command" - ToolGetEnv = "get_env" + // CurrentTimeName is the Genkit tool name for retrieving the current time. + CurrentTimeName = "current_time" + // ExecuteCommandName is the Genkit tool name for executing shell commands. + ExecuteCommandName = "execute_command" + // GetEnvName is the Genkit tool name for reading environment variables. + GetEnvName = "get_env" ) // ExecuteCommandInput defines input for execute_command tool. @@ -35,18 +39,18 @@ type GetEnvInput struct { // CurrentTimeInput defines input for current_time tool (no input needed). type CurrentTimeInput struct{} -// SystemTools holds dependencies for system operation handlers. -// Use NewSystemTools to create an instance, then either: +// System holds dependencies for system operation handlers. +// Use NewSystem to create an instance, then either: // - Call methods directly (for MCP) -// - Use RegisterSystemTools to register with Genkit -type SystemTools struct { +// - Use RegisterSystem to register with Genkit +type System struct { cmdVal *security.Command envVal *security.Env - logger log.Logger + logger *slog.Logger } -// NewSystemTools creates a SystemTools instance. -func NewSystemTools(cmdVal *security.Command, envVal *security.Env, logger log.Logger) (*SystemTools, error) { +// NewSystem creates a System instance. +func NewSystem(cmdVal *security.Command, envVal *security.Env, logger *slog.Logger) (*System, error) { if cmdVal == nil { return nil, fmt.Errorf("command validator is required") } @@ -56,45 +60,45 @@ func NewSystemTools(cmdVal *security.Command, envVal *security.Env, logger log.L if logger == nil { return nil, fmt.Errorf("logger is required") } - return &SystemTools{cmdVal: cmdVal, envVal: envVal, logger: logger}, nil + return &System{cmdVal: cmdVal, envVal: envVal, logger: logger}, nil } -// RegisterSystemTools registers all system operation tools with Genkit. +// RegisterSystem registers all system operation tools with Genkit. // Tools are registered with event emission wrappers for streaming support. -func RegisterSystemTools(g *genkit.Genkit, st *SystemTools) ([]ai.Tool, error) { +func RegisterSystem(g *genkit.Genkit, st *System) ([]ai.Tool, error) { if g == nil { return nil, fmt.Errorf("genkit instance is required") } if st == nil { - return nil, fmt.Errorf("SystemTools is required") + return nil, fmt.Errorf("System is required") } return []ai.Tool{ - genkit.DefineTool(g, ToolCurrentTime, + genkit.DefineTool(g, CurrentTimeName, "Get the current system date and time. "+ "Returns: formatted time string, Unix timestamp, and ISO 8601 format. "+ "Use this to: check current time, calculate relative times, add timestamps to outputs. "+ "Always returns the server's local time zone.", - WithEvents(ToolCurrentTime, st.CurrentTime)), - genkit.DefineTool(g, ToolExecuteCommand, + WithEvents(CurrentTimeName, st.CurrentTime)), + genkit.DefineTool(g, ExecuteCommandName, "Execute a shell command from the allowed list with security validation. "+ "Allowed commands: git, npm, yarn, go, make, docker, kubectl, ls, cat, grep, find, pwd, echo. "+ "Commands run with a timeout to prevent hanging. "+ "Returns: stdout, stderr, exit code, and execution time. "+ "Use this for: running builds, checking git status, listing processes, viewing file contents. "+ "Security: Dangerous commands (rm -rf, sudo, chmod, etc.) are blocked.", - WithEvents(ToolExecuteCommand, st.ExecuteCommand)), - genkit.DefineTool(g, ToolGetEnv, + WithEvents(ExecuteCommandName, st.ExecuteCommand)), + genkit.DefineTool(g, GetEnvName, "Read an environment variable value from the system. "+ "Returns: the variable name and its value. "+ "Use this to: check configuration, verify paths, read non-sensitive settings. "+ "Security: Sensitive variables containing KEY, SECRET, TOKEN, or PASSWORD in their names are protected and will not be returned.", - WithEvents(ToolGetEnv, st.GetEnv)), + WithEvents(GetEnvName, st.GetEnv)), }, nil } // CurrentTime returns the current system date and time in multiple formats. -func (s *SystemTools) CurrentTime(_ *ai.ToolContext, _ CurrentTimeInput) (Result, error) { +func (s *System) CurrentTime(_ *ai.ToolContext, _ CurrentTimeInput) (Result, error) { s.logger.Info("CurrentTime called") now := time.Now() s.logger.Info("CurrentTime succeeded") @@ -112,11 +116,11 @@ func (s *SystemTools) CurrentTime(_ *ai.ToolContext, _ CurrentTimeInput) (Result // Dangerous commands like rm -rf, sudo, and shutdown are blocked. // Business errors (blocked commands, execution failures) are returned in Result.Error. // Only context cancellation returns a Go error. -func (s *SystemTools) ExecuteCommand(ctx *ai.ToolContext, input ExecuteCommandInput) (Result, error) { +func (s *System) ExecuteCommand(ctx *ai.ToolContext, input ExecuteCommandInput) (Result, error) { s.logger.Info("ExecuteCommand called", "command", input.Command, "args", input.Args) // Command security validation (prevent command injection attacks CWE-78) - if err := s.cmdVal.ValidateCommand(input.Command, input.Args); err != nil { + if err := s.cmdVal.Validate(input.Command, input.Args); err != nil { s.logger.Error("ExecuteCommand dangerous command rejected", "command", input.Command, "args", input.Args, "error", err) return Result{ Status: StatusError, @@ -138,17 +142,16 @@ func (s *SystemTools) ExecuteCommand(ctx *ai.ToolContext, input ExecuteCommandIn if err != nil { // Check if it was canceled by context - this is infrastructure error if execCtx.Err() != nil { - s.logger.Error("ExecuteCommand canceled", "command", input.Command, "error", execCtx.Err()) return Result{}, fmt.Errorf("command execution canceled: %w", execCtx.Err()) } // Command execution failure is a business error - s.logger.Error("ExecuteCommand failed", "command", input.Command, "error", err, "output", string(output)) + s.logger.Error("executing command", "command", input.Command, "error", err, "output", string(output)) return Result{ Status: StatusError, Error: &Error{ Code: ErrCodeExecution, - Message: fmt.Sprintf("command failed: %v", err), + Message: fmt.Sprintf("executing command: %v", err), Details: map[string]any{ "command": input.Command, "args": strings.Join(input.Args, " "), @@ -174,11 +177,11 @@ func (s *SystemTools) ExecuteCommand(ctx *ai.ToolContext, input ExecuteCommandIn // GetEnv reads an environment variable value with security protection. // Sensitive variables containing KEY, SECRET, or TOKEN in the name are blocked. // Business errors (sensitive variable blocked) are returned in Result.Error. -func (s *SystemTools) GetEnv(_ *ai.ToolContext, input GetEnvInput) (Result, error) { +func (s *System) GetEnv(_ *ai.ToolContext, input GetEnvInput) (Result, error) { s.logger.Info("GetEnv called", "key", input.Key) // Environment variable security validation (prevent sensitive information leakage) - if err := s.envVal.ValidateEnvAccess(input.Key); err != nil { + if err := s.envVal.Validate(input.Key); err != nil { s.logger.Error("GetEnv sensitive variable blocked", "key", input.Key, "error", err) return Result{ Status: StatusError, diff --git a/internal/tools/system_integration_test.go b/internal/tools/system_integration_test.go index 5cc593c..8e73a9a 100644 --- a/internal/tools/system_integration_test.go +++ b/internal/tools/system_integration_test.go @@ -11,7 +11,7 @@ import ( "github.com/koopa0/koopa/internal/security" ) -// systemTools provides test utilities for SystemTools. +// systemTools provides test utilities for System. type systemTools struct { t *testing.T } @@ -21,13 +21,13 @@ func newsystemTools(t *testing.T) *systemTools { return &systemTools{t: t} } -func (h *systemTools) createSystemTools() *SystemTools { +func (h *systemTools) createSystem() *System { h.t.Helper() cmdVal := security.NewCommand() envVal := security.NewEnv() - st, err := NewSystemTools(cmdVal, envVal, testLogger()) + st, err := NewSystem(cmdVal, envVal, testLogger()) if err != nil { - h.t.Fatalf("failed to create system tools: %v", err) + h.t.Fatalf("creating system tools: %v", err) } return st } @@ -36,11 +36,7 @@ func (*systemTools) toolContext() *ai.ToolContext { return &ai.ToolContext{Context: context.Background()} } -// ============================================================================ -// ExecuteCommand Integration Tests - Command Injection Prevention -// ============================================================================ - -func TestSystemTools_ExecuteCommand_WhitelistEnforcement(t *testing.T) { +func TestSystem_ExecuteCommand_WhitelistEnforcement(t *testing.T) { t.Parallel() tests := []struct { @@ -138,7 +134,7 @@ func TestSystemTools_ExecuteCommand_WhitelistEnforcement(t *testing.T) { t.Parallel() h := newsystemTools(t) - st := h.createSystemTools() + st := h.createSystem() ctx := h.toolContext() result, err := st.ExecuteCommand(ctx, ExecuteCommandInput{ @@ -176,7 +172,7 @@ func TestSystemTools_ExecuteCommand_WhitelistEnforcement(t *testing.T) { } } -func TestSystemTools_ExecuteCommand_DangerousPatterns(t *testing.T) { +func TestSystem_ExecuteCommand_DangerousPatterns(t *testing.T) { t.Parallel() tests := []struct { @@ -234,7 +230,7 @@ func TestSystemTools_ExecuteCommand_DangerousPatterns(t *testing.T) { t.Parallel() h := newsystemTools(t) - st := h.createSystemTools() + st := h.createSystem() ctx := h.toolContext() result, err := st.ExecuteCommand(ctx, ExecuteCommandInput{ @@ -257,11 +253,11 @@ func TestSystemTools_ExecuteCommand_DangerousPatterns(t *testing.T) { } } -func TestSystemTools_ExecuteCommand_Success(t *testing.T) { +func TestSystem_ExecuteCommand_Success(t *testing.T) { t.Parallel() h := newsystemTools(t) - st := h.createSystemTools() + st := h.createSystem() ctx := h.toolContext() result, err := st.ExecuteCommand(ctx, ExecuteCommandInput{ @@ -298,7 +294,7 @@ func TestSystemTools_ExecuteCommand_Success(t *testing.T) { } } -func TestSystemTools_ExecuteCommand_ContextCancellation(t *testing.T) { +func TestSystem_ExecuteCommand_ContextCancellation(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("sleep command not available on Windows") } @@ -306,7 +302,7 @@ func TestSystemTools_ExecuteCommand_ContextCancellation(t *testing.T) { t.Parallel() h := newsystemTools(t) - st := h.createSystemTools() + st := h.createSystem() // Create a context that's already canceled ctx, cancel := context.WithCancel(context.Background()) @@ -328,11 +324,7 @@ func TestSystemTools_ExecuteCommand_ContextCancellation(t *testing.T) { } } -// ============================================================================ -// GetEnv Integration Tests - Sensitive Variable Protection -// ============================================================================ - -func TestSystemTools_GetEnv_SensitiveVariableBlocked(t *testing.T) { +func TestSystem_GetEnv_SensitiveVariableBlocked(t *testing.T) { t.Parallel() tests := []struct { @@ -440,7 +432,7 @@ func TestSystemTools_GetEnv_SensitiveVariableBlocked(t *testing.T) { t.Parallel() h := newsystemTools(t) - st := h.createSystemTools() + st := h.createSystem() result, err := st.GetEnv(nil, GetEnvInput{Key: tt.envKey}) @@ -459,7 +451,7 @@ func TestSystemTools_GetEnv_SensitiveVariableBlocked(t *testing.T) { } } -func TestSystemTools_GetEnv_SafeVariableAllowed(t *testing.T) { +func TestSystem_GetEnv_SafeVariableAllowed(t *testing.T) { t.Parallel() tests := []struct { @@ -481,7 +473,7 @@ func TestSystemTools_GetEnv_SafeVariableAllowed(t *testing.T) { t.Parallel() h := newsystemTools(t) - st := h.createSystemTools() + st := h.createSystem() result, err := st.GetEnv(nil, GetEnvInput{Key: tt.envKey}) @@ -507,7 +499,7 @@ func TestSystemTools_GetEnv_SafeVariableAllowed(t *testing.T) { } } -func TestSystemTools_GetEnv_CaseInsensitiveBlocking(t *testing.T) { +func TestSystem_GetEnv_CaseInsensitiveBlocking(t *testing.T) { t.Parallel() tests := []struct { @@ -527,7 +519,7 @@ func TestSystemTools_GetEnv_CaseInsensitiveBlocking(t *testing.T) { t.Parallel() h := newsystemTools(t) - st := h.createSystemTools() + st := h.createSystem() result, err := st.GetEnv(nil, GetEnvInput{Key: tt.envKey}) @@ -546,15 +538,11 @@ func TestSystemTools_GetEnv_CaseInsensitiveBlocking(t *testing.T) { } } -// ============================================================================ -// CurrentTime Integration Tests -// ============================================================================ - -func TestSystemTools_CurrentTime_Success(t *testing.T) { +func TestSystem_CurrentTime_Success(t *testing.T) { t.Parallel() h := newsystemTools(t) - st := h.createSystemTools() + st := h.createSystem() result, err := st.CurrentTime(nil, CurrentTimeInput{}) diff --git a/internal/tools/system_test.go b/internal/tools/system_test.go index 2950310..3677005 100644 --- a/internal/tools/system_test.go +++ b/internal/tools/system_test.go @@ -6,41 +6,41 @@ import ( "github.com/koopa0/koopa/internal/security" ) -func TestSystemTools_Constructor(t *testing.T) { +func TestSystem_Constructor(t *testing.T) { t.Run("valid inputs", func(t *testing.T) { cmdVal := security.NewCommand() envVal := security.NewEnv() - st, err := NewSystemTools(cmdVal, envVal, testLogger()) + st, err := NewSystem(cmdVal, envVal, testLogger()) if err != nil { - t.Errorf("NewSystemTools() error = %v, want nil", err) + t.Errorf("NewSystem() error = %v, want nil", err) } if st == nil { - t.Error("NewSystemTools() returned nil, want non-nil") + t.Error("NewSystem() returned nil, want non-nil") } }) t.Run("nil command validator", func(t *testing.T) { envVal := security.NewEnv() - st, err := NewSystemTools(nil, envVal, testLogger()) + st, err := NewSystem(nil, envVal, testLogger()) if err == nil { - t.Error("NewSystemTools() error = nil, want error") + t.Error("NewSystem() error = nil, want error") } if st != nil { - t.Error("NewSystemTools() returned non-nil, want nil") + t.Error("NewSystem() returned non-nil, want nil") } }) t.Run("nil env validator", func(t *testing.T) { cmdVal := security.NewCommand() - st, err := NewSystemTools(cmdVal, nil, testLogger()) + st, err := NewSystem(cmdVal, nil, testLogger()) if err == nil { - t.Error("NewSystemTools() error = nil, want error") + t.Error("NewSystem() error = nil, want error") } if st != nil { - t.Error("NewSystemTools() returned non-nil, want nil") + t.Error("NewSystem() returned non-nil, want nil") } }) @@ -48,59 +48,30 @@ func TestSystemTools_Constructor(t *testing.T) { cmdVal := security.NewCommand() envVal := security.NewEnv() - st, err := NewSystemTools(cmdVal, envVal, nil) + st, err := NewSystem(cmdVal, envVal, nil) if err == nil { - t.Error("NewSystemTools() error = nil, want error") + t.Error("NewSystem() error = nil, want error") } if st != nil { - t.Error("NewSystemTools() returned non-nil, want nil") + t.Error("NewSystem() returned non-nil, want nil") } }) } func TestSystemToolConstants(t *testing.T) { expectedNames := map[string]string{ - "ToolCurrentTime": "current_time", - "ToolExecuteCommand": "execute_command", - "ToolGetEnv": "get_env", + "CurrentTimeName": "current_time", + "ExecuteCommandName": "execute_command", + "GetEnvName": "get_env", } - if ToolCurrentTime != expectedNames["ToolCurrentTime"] { - t.Errorf("ToolCurrentTime = %q, want %q", ToolCurrentTime, expectedNames["ToolCurrentTime"]) + if CurrentTimeName != expectedNames["CurrentTimeName"] { + t.Errorf("CurrentTimeName = %q, want %q", CurrentTimeName, expectedNames["CurrentTimeName"]) } - if ToolExecuteCommand != expectedNames["ToolExecuteCommand"] { - t.Errorf("ToolExecuteCommand = %q, want %q", ToolExecuteCommand, expectedNames["ToolExecuteCommand"]) + if ExecuteCommandName != expectedNames["ExecuteCommandName"] { + t.Errorf("ExecuteCommandName = %q, want %q", ExecuteCommandName, expectedNames["ExecuteCommandName"]) } - if ToolGetEnv != expectedNames["ToolGetEnv"] { - t.Errorf("ToolGetEnv = %q, want %q", ToolGetEnv, expectedNames["ToolGetEnv"]) + if GetEnvName != expectedNames["GetEnvName"] { + t.Errorf("GetEnvName = %q, want %q", GetEnvName, expectedNames["GetEnvName"]) } } - -func TestExecuteCommandInput(t *testing.T) { - input := ExecuteCommandInput{ - Command: "ls", - Args: []string{"-la", "/tmp"}, - } - if input.Command != "ls" { - t.Errorf("ExecuteCommandInput.Command = %q, want %q", input.Command, "ls") - } - if len(input.Args) != 2 { - t.Errorf("ExecuteCommandInput.Args length = %d, want 2", len(input.Args)) - } - if input.Args[0] != "-la" { - t.Errorf("ExecuteCommandInput.Args[0] = %q, want %q", input.Args[0], "-la") - } -} - -func TestGetEnvInput(t *testing.T) { - input := GetEnvInput{Key: "PATH"} - if input.Key != "PATH" { - t.Errorf("GetEnvInput.Key = %q, want %q", input.Key, "PATH") - } -} - -func TestCurrentTimeInput(t *testing.T) { - // CurrentTimeInput is an empty struct - input := CurrentTimeInput{} - _ = input // Just verify it can be created -} diff --git a/internal/tools/types.go b/internal/tools/types.go index bcbe524..32e20cd 100644 --- a/internal/tools/types.go +++ b/internal/tools/types.go @@ -1,15 +1,5 @@ -// Package tools provides tool types and result helpers for agent tool operations. -// -// Error Handling: -// - All tools return Result with structured error information -// - Business errors (validation, not found, etc.) use Result.Error -// - Only infrastructure errors (context cancellation) return Go error package tools -// ============================================================================ -// Status and ErrorCode (for JSON responses to LLM) -// ============================================================================ - // Status represents the execution status of a tool. type Status string @@ -35,10 +25,6 @@ const ( ErrCodeValidation ErrorCode = "ValidationError" ) -// ============================================================================ -// Result Types (for structured JSON responses) -// ============================================================================ - // Result is the standard return format for all tools. type Result struct { Status Status `json:"status" jsonschema_description:"The execution status"` diff --git a/internal/tools/types_test.go b/internal/tools/types_test.go index 8f735c9..436a50b 100644 --- a/internal/tools/types_test.go +++ b/internal/tools/types_test.go @@ -10,17 +10,17 @@ func TestResult_Success(t *testing.T) { result := Result{Status: StatusSuccess, Data: data} if result.Status != StatusSuccess { - t.Errorf("Status = %v, want %v", result.Status, StatusSuccess) + t.Errorf("Result{...}.Status = %v, want %v", result.Status, StatusSuccess) } if result.Data == nil { - t.Fatal("Data is nil, want non-nil") + t.Fatal("Result{...}.Data is nil, want non-nil") } dataMap, ok := result.Data.(map[string]any) if !ok { - t.Fatalf("Data type = %T, want map[string]any", result.Data) + t.Fatalf("Result{...}.Data type = %T, want map[string]any", result.Data) } if dataMap["path"] != "/tmp/test" { - t.Errorf("Data[path] = %v, want /tmp/test", dataMap["path"]) + t.Errorf("Result{...}.Data[\"path\"] = %v, want %q", dataMap["path"], "/tmp/test") } }) @@ -28,10 +28,10 @@ func TestResult_Success(t *testing.T) { result := Result{Status: StatusSuccess} if result.Status != StatusSuccess { - t.Errorf("Status = %v, want %v", result.Status, StatusSuccess) + t.Errorf("Result{...}.Status = %v, want %v", result.Status, StatusSuccess) } if result.Data != nil { - t.Errorf("Data = %v, want nil", result.Data) + t.Errorf("Result{...}.Data = %v, want nil", result.Data) } }) } @@ -42,9 +42,9 @@ func TestResult_Error(t *testing.T) { code ErrorCode message string }{ - {"security error", ErrCodeSecurity, "access denied"}, - {"not found error", ErrCodeNotFound, "file not found"}, - {"execution error", ErrCodeExecution, "command failed"}, + {name: "security error", code: ErrCodeSecurity, message: "access denied"}, + {name: "not found error", code: ErrCodeNotFound, message: "file not found"}, + {name: "execution error", code: ErrCodeExecution, message: "executing command"}, } for _, tt := range tests { @@ -55,19 +55,19 @@ func TestResult_Error(t *testing.T) { } if result.Status != StatusError { - t.Errorf("Status = %v, want %v", result.Status, StatusError) + t.Errorf("Result{...}.Status = %v, want %v", result.Status, StatusError) } if result.Data != nil { - t.Errorf("Data = %v, want nil", result.Data) + t.Errorf("Result{...}.Data = %v, want nil", result.Data) } if result.Error == nil { - t.Fatal("Error is nil, want non-nil") + t.Fatal("Result{...}.Error is nil, want non-nil") } if result.Error.Code != tt.code { - t.Errorf("Error.Code = %v, want %v", result.Error.Code, tt.code) + t.Errorf("Result{...}.Error.Code = %v, want %v", result.Error.Code, tt.code) } if result.Error.Message != tt.message { - t.Errorf("Error.Message = %v, want %v", result.Error.Message, tt.message) + t.Errorf("Result{...}.Error.Message = %q, want %q", result.Error.Message, tt.message) } }) } @@ -83,35 +83,35 @@ func TestResult_ErrorWithDetails(t *testing.T) { Status: StatusError, Error: &Error{ Code: ErrCodeExecution, - Message: "command failed", + Message: "executing command", Details: details, }, } if result.Status != StatusError { - t.Errorf("Status = %v, want %v", result.Status, StatusError) + t.Errorf("Result{...}.Status = %v, want %v", result.Status, StatusError) } if result.Error == nil { - t.Fatal("Error is nil, want non-nil") + t.Fatal("Result{...}.Error is nil, want non-nil") } if result.Error.Details == nil { - t.Error("Error.Details is nil, want non-nil") + t.Error("Result{...}.Error.Details is nil, want non-nil") } detailsMap, ok := result.Error.Details.(map[string]any) if !ok { - t.Fatalf("Error.Details type = %T, want map[string]any", result.Error.Details) + t.Fatalf("Result{...}.Error.Details type = %T, want map[string]any", result.Error.Details) } if detailsMap["command"] != "ls" { - t.Errorf("Error.Details[command] = %v, want ls", detailsMap["command"]) + t.Errorf("Result{...}.Error.Details[\"command\"] = %v, want %q", detailsMap["command"], "ls") } } func TestStatusConstants(t *testing.T) { if StatusSuccess != "success" { - t.Errorf("StatusSuccess = %v, want success", StatusSuccess) + t.Errorf("StatusSuccess = %q, want %q", StatusSuccess, "success") } if StatusError != "error" { - t.Errorf("StatusError = %v, want error", StatusError) + t.Errorf("StatusError = %q, want %q", StatusError, "error") } } @@ -127,9 +127,9 @@ func TestErrorCodeConstants(t *testing.T) { ErrCodeValidation: "ValidationError", } - for code, expected := range codes { - if string(code) != expected { - t.Errorf("ErrorCode %v = %v, want %v", code, string(code), expected) + for code, want := range codes { + if string(code) != want { + t.Errorf("ErrorCode(%q) = %q, want %q", code, string(code), want) } } } diff --git a/internal/tui/benchmark_test.go b/internal/tui/benchmark_test.go index 0e2569a..b898761 100644 --- a/internal/tui/benchmark_test.go +++ b/internal/tui/benchmark_test.go @@ -8,15 +8,15 @@ import ( "charm.land/bubbles/v2/textarea" tea "charm.land/bubbletea/v2" - "github.com/koopa0/koopa/internal/agent/chat" + "github.com/koopa0/koopa/internal/chat" ) -// newBenchmarkTUI creates a TUI for benchmarking with minimal setup. -func newBenchmarkTUI() *TUI { +// newBenchmarkModel creates a Model for benchmarking with minimal setup. +func newBenchmarkModel() *Model { ta := textarea.New() ta.SetHeight(3) ta.ShowLineNumbers = false - return &TUI{ + return &Model{ state: StateInput, input: ta, history: make([]string, 0, maxHistory), @@ -32,7 +32,7 @@ func newBenchmarkTUI() *TUI { // BenchmarkTUI_View measures View rendering performance. func BenchmarkTUI_View(b *testing.B) { b.Run("empty", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() b.ResetTimer() b.ReportAllocs() for b.Loop() { @@ -41,7 +41,7 @@ func BenchmarkTUI_View(b *testing.B) { }) b.Run("10_messages", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() for i := 0; i < 10; i++ { tui.addMessage(Message{Role: "user", Text: "Hello, this is a test message"}) tui.addMessage(Message{Role: "assistant", Text: "This is a response with some content"}) @@ -54,7 +54,7 @@ func BenchmarkTUI_View(b *testing.B) { }) b.Run("50_messages", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() for i := 0; i < 50; i++ { tui.addMessage(Message{Role: "user", Text: "Hello, this is a test message"}) tui.addMessage(Message{Role: "assistant", Text: "This is a response with some content"}) @@ -67,7 +67,7 @@ func BenchmarkTUI_View(b *testing.B) { }) b.Run("max_messages", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() for i := 0; i < maxMessages; i++ { tui.addMessage(Message{Role: "user", Text: "Hello, this is a test message with some longer content to simulate real usage"}) } @@ -79,7 +79,7 @@ func BenchmarkTUI_View(b *testing.B) { }) b.Run("streaming_state", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() tui.state = StateStreaming tui.output.WriteString("This is streaming output that is being written in real-time...") for i := 0; i < 10; i++ { @@ -94,7 +94,7 @@ func BenchmarkTUI_View(b *testing.B) { }) b.Run("thinking_state", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() tui.state = StateThinking b.ResetTimer() b.ReportAllocs() @@ -104,7 +104,7 @@ func BenchmarkTUI_View(b *testing.B) { }) b.Run("large_messages", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() largeText := strings.Repeat("This is a large message with lots of content. ", 100) for i := 0; i < 20; i++ { tui.addMessage(Message{Role: "assistant", Text: largeText}) @@ -120,7 +120,7 @@ func BenchmarkTUI_View(b *testing.B) { // BenchmarkTUI_AddMessage measures message addition performance. func BenchmarkTUI_AddMessage(b *testing.B) { b.Run("single", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() msg := Message{Role: "user", Text: "Hello"} b.ResetTimer() b.ReportAllocs() @@ -131,7 +131,7 @@ func BenchmarkTUI_AddMessage(b *testing.B) { }) b.Run("with_bounds_check", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() // Pre-fill to near capacity for i := 0; i < maxMessages-1; i++ { tui.messages = append(tui.messages, Message{Role: "user", Text: "test"}) @@ -149,7 +149,7 @@ func BenchmarkTUI_AddMessage(b *testing.B) { }) b.Run("at_capacity", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() // Fill to capacity for i := 0; i < maxMessages; i++ { tui.messages = append(tui.messages, Message{Role: "user", Text: "test"}) @@ -166,30 +166,30 @@ func BenchmarkTUI_AddMessage(b *testing.B) { // BenchmarkTUI_Update measures Update loop performance. func BenchmarkTUI_Update(b *testing.B) { b.Run("key_press", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() key := tea.Key{Code: 'a'} msg := tea.KeyPressMsg(key) b.ResetTimer() b.ReportAllocs() for b.Loop() { model, _ := tui.Update(msg) - tui = model.(*TUI) + tui = model.(*Model) } }) b.Run("window_resize", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() msg := tea.WindowSizeMsg{Width: 120, Height: 40} b.ResetTimer() b.ReportAllocs() for b.Loop() { model, _ := tui.Update(msg) - tui = model.(*TUI) + tui = model.(*Model) } }) b.Run("stream_text_msg", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() tui.state = StateStreaming eventCh := make(chan streamEvent, 1) tui.streamEventCh = eventCh @@ -199,7 +199,7 @@ func BenchmarkTUI_Update(b *testing.B) { b.ReportAllocs() for b.Loop() { model, _ := tui.Update(msg) - tui = model.(*TUI) + tui = model.(*Model) tui.output.Reset() // Reset to avoid unbounded growth } }) @@ -208,14 +208,14 @@ func BenchmarkTUI_Update(b *testing.B) { // BenchmarkTUI_NavigateHistory measures history navigation performance. func BenchmarkTUI_NavigateHistory(b *testing.B) { b.Run("small_history", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() tui.history = []string{"one", "two", "three"} tui.historyIdx = 1 b.ResetTimer() b.ReportAllocs() for b.Loop() { model, _ := tui.navigateHistory(-1) - tui = model.(*TUI) + tui = model.(*Model) if tui.historyIdx == 0 { tui.historyIdx = len(tui.history) } @@ -223,7 +223,7 @@ func BenchmarkTUI_NavigateHistory(b *testing.B) { }) b.Run("large_history", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() for i := 0; i < maxHistory; i++ { tui.history = append(tui.history, "history entry "+string(rune('a'+i%26))) } @@ -232,7 +232,7 @@ func BenchmarkTUI_NavigateHistory(b *testing.B) { b.ReportAllocs() for b.Loop() { model, _ := tui.navigateHistory(-1) - tui = model.(*TUI) + tui = model.(*Model) if tui.historyIdx == 0 { tui.historyIdx = len(tui.history) } @@ -376,18 +376,18 @@ func BenchmarkStyles(b *testing.B) { // BenchmarkTUI_HandleSlashCommand measures slash command handling performance. func BenchmarkTUI_HandleSlashCommand(b *testing.B) { b.Run("help", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() b.ResetTimer() b.ReportAllocs() for b.Loop() { tui.messages = tui.messages[:0] // Reset messages model, _ := tui.handleSlashCommand("/help") - tui = model.(*TUI) + tui = model.(*Model) } }) b.Run("clear", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() for i := 0; i < 10; i++ { tui.addMessage(Message{Role: "user", Text: "test"}) } @@ -396,18 +396,18 @@ func BenchmarkTUI_HandleSlashCommand(b *testing.B) { for b.Loop() { tui.messages = []Message{{Role: "user", Text: "test"}} model, _ := tui.handleSlashCommand("/clear") - tui = model.(*TUI) + tui = model.(*Model) } }) b.Run("unknown", func(b *testing.B) { - tui := newBenchmarkTUI() + tui := newBenchmarkModel() b.ResetTimer() b.ReportAllocs() for b.Loop() { tui.messages = tui.messages[:0] model, _ := tui.handleSlashCommand("/unknown") - tui = model.(*TUI) + tui = model.(*Model) } }) } diff --git a/internal/tui/fuzz_test.go b/internal/tui/fuzz_test.go index 5cf1fda..d73e5ac 100644 --- a/internal/tui/fuzz_test.go +++ b/internal/tui/fuzz_test.go @@ -31,12 +31,12 @@ func FuzzTUI_HandleSlashCommand(f *testing.F) { return } - tui := newTestTUI() + tui := newTestModel() tui.messages = []Message{{Role: "user", Text: "hello"}} // Should never panic model, resultCmd := tui.handleSlashCommand(cmd) - result := model.(*TUI) + result := model.(*Model) // Basic invariants if result == nil { @@ -71,13 +71,13 @@ func FuzzTUI_NavigateHistory(f *testing.F) { f.Add(-1000000) f.Fuzz(func(t *testing.T, delta int) { - tui := newTestTUI() + tui := newTestModel() tui.history = []string{"first", "second", "third"} tui.historyIdx = 1 // Should never panic model, _ := tui.navigateHistory(delta) - result := model.(*TUI) + result := model.(*Model) // Index should be within bounds if result.historyIdx < 0 { @@ -104,7 +104,7 @@ func FuzzTUI_AddMessage(f *testing.F) { f.Add("user", "\x00\x01\x02") // Binary f.Fuzz(func(t *testing.T, role, text string) { - tui := newTestTUI() + tui := newTestModel() // Should never panic tui.addMessage(Message{Role: role, Text: text}) @@ -138,7 +138,7 @@ func FuzzTUI_KeyPress(f *testing.F) { f.Add(int32(tea.KeySpace), 0) //nolint:unconvert // f.Add requires exact types f.Fuzz(func(t *testing.T, code int32, mod int) { - tui := newTestTUI() + tui := newTestModel() // Use background context to avoid nil pointer issues tui.ctx = context.Background() @@ -166,7 +166,7 @@ func FuzzTUI_View(f *testing.F) { f.Add(0, 10000, 1) // Very wide f.Fuzz(func(t *testing.T, state, width, height int) { - tui := newTestTUI() + tui := newTestModel() // Set state (bounded to valid values) if state >= 0 && state <= 2 { @@ -219,7 +219,7 @@ func FuzzMarkdownRenderer_Render(f *testing.F) { f.Fuzz(func(t *testing.T, markdown string) { mr := newMarkdownRenderer(80) if mr == nil { - t.Skip("Failed to create markdown renderer") + t.Skip("newMarkdownRenderer() returned nil") } // Should never panic @@ -251,7 +251,7 @@ func FuzzMarkdownRenderer_UpdateWidth(f *testing.F) { f.Fuzz(func(t *testing.T, width int) { mr := newMarkdownRenderer(80) if mr == nil { - t.Skip("Failed to create markdown renderer") + t.Skip("newMarkdownRenderer() returned nil") } // Should never panic diff --git a/internal/tui/i18n.go b/internal/tui/i18n.go new file mode 100644 index 0000000..c443bc6 --- /dev/null +++ b/internal/tui/i18n.go @@ -0,0 +1,27 @@ +package tui + +// toolDisplayNames maps tool names to localized display names. +var toolDisplayNames = map[string]string{ + "web_search": "搜尋網路", + "web_fetch": "讀取網頁", + "read_file": "讀取檔案", + "write_file": "寫入檔案", + "list_files": "瀏覽目錄", + "delete_file": "刪除檔案", + "get_file_info": "取得檔案資訊", + "execute_command": "執行命令", + "current_time": "取得時間", + "get_env": "取得環境變數", + "search_history": "搜尋對話記錄", + "search_documents": "搜尋知識庫", + "search_system_knowledge": "搜尋系統知識", + "knowledge_store": "儲存知識", +} + +// toolDisplayName returns a localized display name for a tool. +func toolDisplayName(name string) string { + if display, ok := toolDisplayNames[name]; ok { + return display + } + return name +} diff --git a/internal/tui/integration_test.go b/internal/tui/integration_test.go index 08f4bde..5a80ea9 100644 --- a/internal/tui/integration_test.go +++ b/internal/tui/integration_test.go @@ -10,7 +10,8 @@ import ( "go.uber.org/goleak" - "github.com/koopa0/koopa/internal/agent/chat" + "github.com/google/uuid" + "github.com/koopa0/koopa/internal/chat" ) func TestMain(m *testing.M) { @@ -25,17 +26,17 @@ func TestMain(m *testing.M) { } // createTestSession creates a session in the database and returns its ID and cleanup function. -func createTestSession(t *testing.T, setup *chatFlowSetup) (string, func()) { +func createTestSession(t *testing.T, setup *chatFlowSetup) (uuid.UUID, func()) { t.Helper() - sess, err := setup.SessionStore.CreateSession(setup.Ctx, "test-session", "gemini-2.0-flash", "") + sess, err := setup.SessionStore.CreateSession(setup.Ctx, "test-session") if err != nil { - t.Fatalf("Failed to create test session: %v", err) + t.Fatalf("CreateSession() error: %v", err) } cleanup := func() { // Use background context for cleanup since test context may be canceled _ = setup.SessionStore.DeleteSession(context.Background(), sess.ID) } - return sess.ID.String(), cleanup + return sess.ID, cleanup } func TestTUI_Integration_StartStream_Success(t *testing.T) { @@ -50,7 +51,7 @@ func TestTUI_Integration_StartStream_Success(t *testing.T) { defer sessionCleanup() tui, err := New(setup.Ctx, setup.Flow, sessionID) if err != nil { - t.Fatalf("Failed to create TUI: %v", err) + t.Fatalf("New() error: %v", err) } // Start a stream with a simple query @@ -132,7 +133,7 @@ func TestTUI_Integration_StartStream_Cancellation(t *testing.T) { defer sessionCleanup() tui, err := New(setup.Ctx, setup.Flow, sessionID) if err != nil { - t.Fatalf("Failed to create TUI: %v", err) + t.Fatalf("New() error: %v", err) } // Start a stream with a long query @@ -198,7 +199,7 @@ func TestTUI_Integration_HandleSubmit_StateTransition(t *testing.T) { defer sessionCleanup() tui, err := New(setup.Ctx, setup.Flow, sessionID) if err != nil { - t.Fatalf("Failed to create TUI: %v", err) + t.Fatalf("New() error: %v", err) } // Set input value @@ -206,7 +207,7 @@ func TestTUI_Integration_HandleSubmit_StateTransition(t *testing.T) { // Call handleSubmit model, cmd := tui.handleSubmit() - result := model.(*TUI) + result := model.(*Model) // Verify state changed to thinking if result.state != StateThinking { @@ -355,13 +356,13 @@ func TestTUI_Integration_ViewDuringStreaming(t *testing.T) { defer sessionCleanup() tui, err := New(setup.Ctx, setup.Flow, sessionID) if err != nil { - t.Fatalf("Failed to create TUI: %v", err) + t.Fatalf("New() error: %v", err) } // Start streaming tui.input.SetValue("Tell me about Go programming") model, cmd := tui.handleSubmit() - tui = model.(*TUI) + tui = model.(*Model) // Call View during different states and verify no panic _ = tui.View() @@ -380,7 +381,7 @@ func TestTUI_Integration_ViewDuringStreaming(t *testing.T) { break } model, cmd = tui.Update(msg) - tui = model.(*TUI) + tui = model.(*Model) // View should work in any state _ = tui.View() diff --git a/internal/tui/keys.go b/internal/tui/keys.go index e281d8b..8ea5f07 100644 --- a/internal/tui/keys.go +++ b/internal/tui/keys.go @@ -42,193 +42,193 @@ func newKeyMap() keyMap { } //nolint:gocyclo // Keyboard handler requires branching for all key combinations -func (t *TUI) handleKey(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) { +func (m *Model) handleKey(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) { k := msg.Key() // Check for Ctrl modifier if k.Mod&tea.ModCtrl != 0 { switch k.Code { case 'c': - return t.handleCtrlC() + return m.handleCtrlC() case 'd': - cmd := t.cleanup() - return t, cmd + cmd := m.cleanup() + return m, cmd } } // Check special keys switch k.Code { case tea.KeyEnter: - if t.state == StateInput { + if m.state == StateInput { // Enter without Shift = submit // Shift+Enter = newline (pass through to textarea) if k.Mod&tea.ModShift == 0 { - return t.handleSubmit() + return m.handleSubmit() } } case tea.KeyUp: // Up at first line navigates history, otherwise pass to textarea - if t.state == StateInput && t.input.Line() == 0 { - return t.navigateHistory(-1) + if m.state == StateInput && m.input.Line() == 0 { + return m.navigateHistory(-1) } case tea.KeyDown: // Down at last line navigates history, otherwise pass to textarea - if t.state == StateInput && t.input.Line() == t.input.LineCount()-1 { - return t.navigateHistory(1) + if m.state == StateInput && m.input.Line() == m.input.LineCount()-1 { + return m.navigateHistory(1) } case tea.KeyEscape: - if t.state == StateStreaming || t.state == StateThinking { - t.cancelStream() - t.state = StateInput - t.output.Reset() - return t, nil + if m.state == StateStreaming || m.state == StateThinking { + m.cancelStream() + m.state = StateInput + m.output.Reset() + return m, nil } case tea.KeyPgUp: - t.viewport.PageUp() - return t, nil + m.viewport.PageUp() + return m, nil case tea.KeyPgDown: - t.viewport.PageDown() - return t, nil + m.viewport.PageDown() + return m, nil } // Pass keys to textarea for typing - ALWAYS allow typing even during streaming // Better UX: users can prepare next message while LLM responds var cmd tea.Cmd - t.input, cmd = t.input.Update(msg) - return t, cmd + m.input, cmd = m.input.Update(msg) + return m, cmd } -func (t *TUI) handleCtrlC() (tea.Model, tea.Cmd) { +func (m *Model) handleCtrlC() (tea.Model, tea.Cmd) { now := time.Now() // Double Ctrl+C within 1 second = quit - if now.Sub(t.lastCtrlC) < time.Second { - cmd := t.cleanup() - return t, cmd + if now.Sub(m.lastCtrlC) < time.Second { + cmd := m.cleanup() + return m, cmd } - t.lastCtrlC = now + m.lastCtrlC = now - switch t.state { + switch m.state { case StateInput: - t.input.Reset() - return t, nil + m.input.Reset() + return m, nil case StateThinking, StateStreaming: - t.cancelStream() - t.state = StateInput - t.output.Reset() - t.addMessage(Message{Role: "system", Text: "(Canceled)"}) - return t, nil + m.cancelStream() + m.state = StateInput + m.output.Reset() + m.addMessage(Message{Role: roleSystem, Text: "(Canceled)"}) + return m, nil } - return t, nil + return m, nil } -func (t *TUI) handleSubmit() (tea.Model, tea.Cmd) { - query := strings.TrimSpace(t.input.Value()) +func (m *Model) handleSubmit() (tea.Model, tea.Cmd) { + query := strings.TrimSpace(m.input.Value()) if query == "" { - return t, nil + return m, nil } // Handle slash commands if strings.HasPrefix(query, "/") { - return t.handleSlashCommand(query) + return m.handleSlashCommand(query) } // Add to history (enforce maxHistory cap) - t.history = append(t.history, query) - if len(t.history) > maxHistory { + m.history = append(m.history, query) + if len(m.history) > maxHistory { // Remove oldest entries to stay within bounds - t.history = t.history[len(t.history)-maxHistory:] + m.history = m.history[len(m.history)-maxHistory:] } - t.historyIdx = len(t.history) + m.historyIdx = len(m.history) // Add user message - t.addMessage(Message{Role: "user", Text: query}) + m.addMessage(Message{Role: roleUser, Text: query}) // Clear input - t.input.Reset() + m.input.Reset() // Start thinking - t.state = StateThinking + m.state = StateThinking - return t, tea.Batch( - t.spinner.Tick, - t.startStream(query), + return m, tea.Batch( + m.spinner.Tick, + m.startStream(query), ) } -func (t *TUI) handleSlashCommand(cmd string) (tea.Model, tea.Cmd) { +func (m *Model) handleSlashCommand(cmd string) (tea.Model, tea.Cmd) { switch cmd { case cmdHelp: - t.addMessage(Message{ + m.addMessage(Message{ Role: roleSystem, Text: "Commands: " + cmdHelp + ", " + cmdClear + ", " + cmdExit + "\nShortcuts:\n Enter: send message\n Shift+Enter: new line\n Ctrl+C: cancel/clear\n Ctrl+D: exit\n Up/Down: history\n PgUp/PgDn: scroll", }) case cmdClear: - t.messages = nil + m.messages = nil case cmdExit, cmdQuit: - cleanupCmd := t.cleanup() - return t, cleanupCmd + cleanupCmd := m.cleanup() + return m, cleanupCmd default: - t.addMessage(Message{ + m.addMessage(Message{ Role: roleError, Text: "Unknown command: " + cmd, }) } - t.input.Reset() - return t, nil + m.input.Reset() + return m, nil } -func (t *TUI) navigateHistory(delta int) (tea.Model, tea.Cmd) { - if len(t.history) == 0 { - return t, nil +func (m *Model) navigateHistory(delta int) (tea.Model, tea.Cmd) { + if len(m.history) == 0 { + return m, nil } - t.historyIdx += delta + m.historyIdx += delta - if t.historyIdx < 0 { - t.historyIdx = 0 + if m.historyIdx < 0 { + m.historyIdx = 0 } - if t.historyIdx > len(t.history) { - t.historyIdx = len(t.history) + if m.historyIdx > len(m.history) { + m.historyIdx = len(m.history) } - if t.historyIdx == len(t.history) { - t.input.SetValue("") + if m.historyIdx == len(m.history) { + m.input.SetValue("") } else { - t.input.SetValue(t.history[t.historyIdx]) + m.input.SetValue(m.history[m.historyIdx]) // Move cursor to end of text - t.input.CursorEnd() + m.input.CursorEnd() } - return t, nil + return m, nil } -func (t *TUI) cancelStream() { - if t.streamCancel != nil { - t.streamCancel() - t.streamCancel = nil +func (m *Model) cancelStream() { + if m.streamCancel != nil { + m.streamCancel() + m.streamCancel = nil } } // cleanup cancels any active stream and returns the quit command. // Waits for goroutine exit with timeout to prevent resource leaks. -func (t *TUI) cleanup() tea.Cmd { - // Cancel main context first - this triggers all goroutines using t.ctx - if t.ctxCancel != nil { - t.ctxCancel() - t.ctxCancel = nil +func (m *Model) cleanup() tea.Cmd { + // Cancel main context first - this triggers all goroutines using m.ctx + if m.ctxCancel != nil { + m.ctxCancel() + m.ctxCancel = nil } // Then cancel stream-specific context (may already be canceled via parent) - t.cancelStream() - t.streamEventCh = nil + m.cancelStream() + m.streamEventCh = nil return tea.Quit } diff --git a/internal/tui/model.go b/internal/tui/model.go new file mode 100644 index 0000000..467dabc --- /dev/null +++ b/internal/tui/model.go @@ -0,0 +1,198 @@ +// Package tui provides Bubble Tea terminal interface for Koopa. +package tui + +import ( + "context" + "errors" + "strings" + "time" + + "charm.land/bubbles/v2/help" + "charm.land/bubbles/v2/spinner" + "charm.land/bubbles/v2/textarea" + "charm.land/bubbles/v2/viewport" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" + + "github.com/google/uuid" + "github.com/koopa0/koopa/internal/chat" +) + +// State represents TUI state machine. +type State int + +// TUI state machine states. +const ( + StateInput State = iota // Awaiting user input + StateThinking // Processing request + StateStreaming // Streaming response +) + +// Memory bounds to prevent unbounded growth. +const ( + maxMessages = 100 // Maximum messages stored + maxHistory = 100 // Maximum command history entries +) + +// Timeout constants for stream operations. +const streamTimeout = 5 * time.Minute // Maximum time for a single stream + +// Message role constants for consistent display. +const ( + roleUser = "user" + roleAssistant = "assistant" + roleSystem = "system" + roleError = "error" +) + +// Layout constants for viewport height calculation. +const ( + separatorLines = 2 // Two separator lines (above and below input) + helpLines = 1 // Help bar height + promptLines = 1 // Prompt prefix line + minViewport = 3 // Minimum viewport height +) + +// Message represents a conversation message for display. +type Message struct { + Role string // "user", "assistant", "system", "error" + Text string +} + +// Model is the Bubble Tea model for Koopa terminal interface. +type Model struct { + // Input (textarea for multi-line support, Shift+Enter for newline) + input textarea.Model + history []string + historyIdx int + + // State + state State + lastCtrlC time.Time + + // Output + spinner spinner.Model + output strings.Builder + viewBuf strings.Builder // Reusable buffer for View() to reduce allocations + messages []Message + + // Scrollable message viewport + viewport viewport.Model + + // Help bar for keyboard shortcuts + help help.Model + keys keyMap + + // Stream management + // Note: No sync.WaitGroup - Bubble Tea's event loop provides synchronization. + // Single union channel with discriminated events simplifies select logic. + streamCancel context.CancelFunc + streamEventCh <-chan streamEvent + toolStatus string // Current tool status (e.g., "搜尋網路..."), empty when idle + + // Dependencies (direct, no interface) + chatFlow *chat.Flow + sessionID uuid.UUID + ctx context.Context + ctxCancel context.CancelFunc // For canceling all operations on exit + + // Dimensions + width int + height int + + // Styles + styles Styles + + // Markdown rendering (nil = graceful degradation to plain text) + markdown *markdownRenderer +} + +// addMessage appends a message and enforces maxMessages bound. +func (m *Model) addMessage(msg Message) { + m.messages = append(m.messages, msg) + if len(m.messages) > maxMessages { + // Remove oldest messages to stay within bounds + m.messages = m.messages[len(m.messages)-maxMessages:] + } +} + +// New creates a Model for chat interaction. +// Returns error if required dependencies are nil. +// +// IMPORTANT: ctx MUST be the same context passed to tea.WithContext() +// to ensure consistent cancellation behavior. +func New(ctx context.Context, flow *chat.Flow, sessionID uuid.UUID) (*Model, error) { + if flow == nil { + return nil, errors.New("tui.New: flow is required") + } + if ctx == nil { + return nil, errors.New("tui.New: ctx is required") + } + if sessionID == uuid.Nil { + return nil, errors.New("tui.New: session ID is required") + } + + // Create cancellable context for cleanup on exit + ctx, cancel := context.WithCancel(ctx) + + // Create textarea for multi-line input + // Enter submits, Shift+Enter adds newline (default behavior) + ta := textarea.New() + ta.Placeholder = "Ask anything..." + ta.SetHeight(1) // Single line by default + ta.SetWidth(120) // Wide enough for long text, updated on WindowSizeMsg + ta.MaxWidth = 0 // No max width limit + ta.ShowLineNumbers = false + + // Clean, minimal styling like Claude Code / Gemini CLI + // No background colors, just simple text + cleanStyle := textarea.StyleState{ + Base: lipgloss.NewStyle(), + Text: lipgloss.NewStyle(), + Placeholder: lipgloss.NewStyle().Foreground(lipgloss.Color("240")), // Gray placeholder + Prompt: lipgloss.NewStyle(), + } + ta.SetStyles(textarea.Styles{ + Focused: cleanStyle, + Blurred: cleanStyle, + }) + ta.Focus() + + sp := spinner.New() + sp.Spinner = spinner.Dot + + // Create viewport for scrollable message history. + // Disable built-in keyboard handling — we route keys explicitly + // in handleKey to avoid conflicts with textarea/history navigation. + vp := viewport.New(viewport.WithWidth(80), viewport.WithHeight(20)) + vp.MouseWheelEnabled = true + vp.SoftWrap = true + vp.KeyMap = viewport.KeyMap{} // Disable default key bindings + + h := help.New() + + return &Model{ + chatFlow: flow, + sessionID: sessionID, + ctx: ctx, + ctxCancel: cancel, + input: ta, + spinner: sp, + viewport: vp, + help: h, + keys: newKeyMap(), + styles: DefaultStyles(), + history: make([]string, 0, maxHistory), + markdown: newMarkdownRenderer(80), + width: 80, // Default width until WindowSizeMsg arrives + }, nil +} + +// Init implements tea.Model. +func (m *Model) Init() tea.Cmd { + return tea.Batch( + textarea.Blink, + m.spinner.Tick, + m.input.Focus(), // Ensure textarea is focused on startup + ) +} diff --git a/internal/tui/setup_test.go b/internal/tui/setup_test.go index 2ee1653..9dbfb63 100644 --- a/internal/tui/setup_test.go +++ b/internal/tui/setup_test.go @@ -12,12 +12,10 @@ package tui import ( "context" - "fmt" "io" "log/slog" "os" "path/filepath" - "runtime" "testing" "time" @@ -26,12 +24,13 @@ import ( "github.com/firebase/genkit/go/plugins/postgresql" "github.com/jackc/pgx/v5/pgxpool" - "github.com/koopa0/koopa/internal/agent/chat" + "github.com/koopa0/koopa/internal/chat" "github.com/koopa0/koopa/internal/config" "github.com/koopa0/koopa/internal/rag" "github.com/koopa0/koopa/internal/security" "github.com/koopa0/koopa/internal/session" "github.com/koopa0/koopa/internal/sqlc" + "github.com/koopa0/koopa/internal/testutil" "github.com/koopa0/koopa/internal/tools" ) @@ -44,27 +43,6 @@ type chatFlowSetup struct { Cancel context.CancelFunc } -// findProjectRoot finds the project root directory by looking for go.mod. -func findProjectRoot() (string, error) { - _, filename, _, ok := runtime.Caller(0) - if !ok { - return "", fmt.Errorf("runtime.Caller failed to get caller info") - } - - dir := filepath.Dir(filename) - for { - goModPath := filepath.Join(dir, "go.mod") - if _, err := os.Stat(goModPath); err == nil { - return dir, nil - } - parent := filepath.Dir(dir) - if parent == dir { - return "", fmt.Errorf("go.mod not found in any parent directory of %s", filename) - } - dir = parent - } -} - // setupChatFlow creates a complete chat flow setup for integration testing. // // This function assembles all dependencies needed for TUI integration tests. @@ -81,7 +59,9 @@ func findProjectRoot() (string, error) { // setup, cleanup := setupChatFlow(t) // defer cleanup() // -// tui := New(setup.Ctx, setup.Flow, "test-session-id") +// sessionID, cleanup := createTestSession(t, setup) +// defer cleanup() +// tui := New(setup.Ctx, setup.Flow, sessionID) // } func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) { t.Helper() @@ -98,10 +78,10 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) - projectRoot, err := findProjectRoot() + projectRoot, err := testutil.FindProjectRoot() if err != nil || projectRoot == "" { cancel() - t.Fatalf("Failed to find project root: %v", err) + t.Fatalf("FindProjectRoot() error: %v", err) } promptsDir := filepath.Join(projectRoot, "prompts") @@ -109,7 +89,7 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) { pool, err := pgxpool.New(ctx, dbURL) if err != nil { cancel() - t.Fatalf("Failed to connect to database: %v", err) + t.Fatalf("pgxpool.New() error: %v", err) } // Create PostgreSQL engine for Genkit @@ -120,7 +100,7 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) { if err != nil { pool.Close() cancel() - t.Fatalf("Failed to create PostgresEngine: %v", err) + t.Fatalf("NewPostgresEngine() error: %v", err) } postgres := &postgresql.Postgres{Engine: pEngine} @@ -133,15 +113,14 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) { if g == nil { pool.Close() cancel() - t.Fatal("Failed to initialize Genkit") + t.Fatal("genkit.Init() returned nil") } logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelWarn})) cfg := &config.Config{ ModelName: "gemini-2.0-flash", - EmbedderModel: "text-embedding-004", - RAGTopK: 5, + EmbedderModel: "gemini-embedding-001", MaxTurns: 10, } @@ -150,7 +129,7 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) { if embedder == nil { pool.Close() cancel() - t.Fatalf("Failed to create embedder for model %q", cfg.EmbedderModel) + t.Fatalf("GoogleAIEmbedder(%q) returned nil", cfg.EmbedderModel) } // Create DocStore and Retriever using shared config factory @@ -159,7 +138,7 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) { if err != nil { pool.Close() cancel() - t.Fatalf("Failed to define retriever: %v", err) + t.Fatalf("DefineRetriever() error: %v", err) } queries := sqlc.New(pool) @@ -169,51 +148,51 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) { if err != nil { pool.Close() cancel() - t.Fatalf("Failed to create path validator: %v", err) + t.Fatalf("NewPath() error: %v", err) } // Create and register file tools - ft, err := tools.NewFileTools(pathValidator, logger) + ft, err := tools.NewFile(pathValidator, logger) if err != nil { pool.Close() cancel() - t.Fatalf("Failed to create file tools: %v", err) + t.Fatalf("NewFile() error: %v", err) } - fileTools, err := tools.RegisterFileTools(g, ft) + fileTools, err := tools.RegisterFile(g, ft) if err != nil { pool.Close() cancel() - t.Fatalf("Failed to register file tools: %v", err) + t.Fatalf("RegisterFile() error: %v", err) } // Create and register system tools cmdValidator := security.NewCommand() envValidator := security.NewEnv() - st, err := tools.NewSystemTools(cmdValidator, envValidator, logger) + st, err := tools.NewSystem(cmdValidator, envValidator, logger) if err != nil { pool.Close() cancel() - t.Fatalf("Failed to create system tools: %v", err) + t.Fatalf("NewSystem() error: %v", err) } - systemTools, err := tools.RegisterSystemTools(g, st) + systemTools, err := tools.RegisterSystem(g, st) if err != nil { pool.Close() cancel() - t.Fatalf("Failed to register system tools: %v", err) + t.Fatalf("RegisterSystem() error: %v", err) } // Create and register knowledge tools - kt, err := tools.NewKnowledgeTools(retriever, nil, logger) + kt, err := tools.NewKnowledge(retriever, nil, logger) if err != nil { pool.Close() cancel() - t.Fatalf("Failed to create knowledge tools: %v", err) + t.Fatalf("NewKnowledge() error: %v", err) } - knowledgeTools, err := tools.RegisterKnowledgeTools(g, kt) + knowledgeTools, err := tools.RegisterKnowledge(g, kt) if err != nil { pool.Close() cancel() - t.Fatalf("Failed to register knowledge tools: %v", err) + t.Fatalf("RegisterKnowledge() error: %v", err) } // Combine all tools @@ -221,27 +200,20 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) { chatAgent, err := chat.New(chat.Config{ Genkit: g, - Retriever: retriever, SessionStore: sessionStore, Logger: logger, Tools: allTools, MaxTurns: cfg.MaxTurns, - RAGTopK: cfg.RAGTopK, }) if err != nil { pool.Close() cancel() - t.Fatalf("Failed to create chat agent: %v", err) + t.Fatalf("chat.New() error: %v", err) } // Initialize Flow singleton (reset first for test isolation) chat.ResetFlowForTesting() - flow, err := chat.InitFlow(g, chatAgent) - if err != nil { - pool.Close() - cancel() - t.Fatalf("Failed to init chat flow: %v", err) - } + flow := chat.NewFlow(g, chatAgent) setup := &chatFlowSetup{ Flow: flow, diff --git a/internal/tui/commands.go b/internal/tui/stream.go similarity index 69% rename from internal/tui/commands.go rename to internal/tui/stream.go index a326079..38761c1 100644 --- a/internal/tui/commands.go +++ b/internal/tui/stream.go @@ -7,7 +7,8 @@ import ( tea "charm.land/bubbletea/v2" - "github.com/koopa0/koopa/internal/agent/chat" + "github.com/koopa0/koopa/internal/chat" + "github.com/koopa0/koopa/internal/tools" ) // streamBufferSize is sized for ~1.5s burst at 60 FPS refresh rate. @@ -20,10 +21,11 @@ const streamBufferSize = 100 // and eliminates complex multi-channel closure handling. type streamEvent struct { // Exactly one of these fields is set per event - text string // Text chunk (when non-empty) - output chat.Output // Final output (when done is true) - err error // Error (when non-nil) - done bool // True when stream completed successfully + text string // Text chunk (when non-empty) + output chat.Output // Final output (when done is true) + err error // Error (when non-nil) + done bool // True when stream completed successfully + toolStatus string // Tool status message (when non-empty, e.g. "搜尋網路中...") } // Stream message types for Bubble Tea @@ -44,6 +46,42 @@ type streamErrorMsg struct { err error } +type streamToolMsg struct { + status string +} + +// tuiToolEmitter implements tools.Emitter for the TUI. +// Sends tool status through the stream event channel so Bubble Tea +// can display tool execution progress (e.g., "搜尋網路中..."). +type tuiToolEmitter struct { + eventCh chan<- streamEvent +} + +func (e *tuiToolEmitter) OnToolStart(name string) { + display := toolDisplayName(name) + select { + case e.eventCh <- streamEvent{toolStatus: display + "..."}: + default: // best-effort: don't block if channel is full + } +} + +func (e *tuiToolEmitter) OnToolComplete(_ string) { + select { + case e.eventCh <- streamEvent{toolStatus: ""}: + default: + } +} + +func (e *tuiToolEmitter) OnToolError(_ string) { + select { + case e.eventCh <- streamEvent{toolStatus: ""}: + default: + } +} + +// Compile-time interface verification. +var _ tools.Emitter = (*tuiToolEmitter)(nil) + // startStream creates a command that initiates streaming. // Directly uses *chat.Flow - no adapter needed. // @@ -53,12 +91,15 @@ type streamErrorMsg struct { // 3. Error occurs // // Channel closure signals completion - no WaitGroup needed. -func (t *TUI) startStream(query string) tea.Cmd { +func (m *Model) startStream(query string) tea.Cmd { return func() tea.Msg { eventCh := make(chan streamEvent, streamBufferSize) // Create context with timeout to prevent indefinite hangs - ctx, cancel := context.WithTimeout(t.ctx, streamTimeout) + ctx, cancel := context.WithTimeout(m.ctx, streamTimeout) + + // Inject tool event emitter so tool status is shown in TUI + ctx = tools.ContextWithEmitter(ctx, &tuiToolEmitter{eventCh: eventCh}) go func() { // Ensure timer resources are released on all exit paths @@ -81,9 +122,9 @@ func (t *TUI) startStream(query string) tea.Cmd { // Directly use chat.Flow's iterator (Go 1.23+ range-over-func) // Genkit's StreamingFlowValue has: {Stream.Text, Output, Done} - for streamValue, err := range t.chatFlow.Stream(ctx, chat.Input{ + for streamValue, err := range m.chatFlow.Stream(ctx, chat.Input{ Query: query, - SessionID: t.sessionID, + SessionID: m.sessionID.String(), }) { if err != nil { select { @@ -154,6 +195,8 @@ func listenForStream(eventCh <-chan streamEvent) tea.Cmd { return streamErrorMsg{err: event.err} case event.done: return streamDoneMsg{output: event.output} + case event.toolStatus != "": + return streamToolMsg{status: event.toolStatus} case event.text != "": return streamTextMsg{text: event.text} default: diff --git a/internal/tui/tui.go b/internal/tui/tui.go deleted file mode 100644 index 44f31b5..0000000 --- a/internal/tui/tui.go +++ /dev/null @@ -1,413 +0,0 @@ -// Package tui provides Bubble Tea terminal interface for Koopa. -package tui - -import ( - "context" - "errors" - "strings" - "time" - - "charm.land/bubbles/v2/help" - "charm.land/bubbles/v2/key" - "charm.land/bubbles/v2/spinner" - "charm.land/bubbles/v2/textarea" - "charm.land/bubbles/v2/viewport" - tea "charm.land/bubbletea/v2" - "charm.land/lipgloss/v2" - - "github.com/koopa0/koopa/internal/agent/chat" -) - -// State represents TUI state machine. -type State int - -// TUI state machine states. -const ( - StateInput State = iota // Awaiting user input - StateThinking // Processing request - StateStreaming // Streaming response -) - -// Memory bounds to prevent unbounded growth. -const ( - maxMessages = 100 // Maximum messages stored - maxHistory = 100 // Maximum command history entries -) - -// Timeout constants for stream operations. -const streamTimeout = 5 * time.Minute // Maximum time for a single stream - -// Message role constants for consistent display. -const ( - roleUser = "user" - roleAssistant = "assistant" - roleSystem = "system" - roleError = "error" -) - -// Layout constants for viewport height calculation. -const ( - separatorLines = 2 // Two separator lines (above and below input) - helpLines = 1 // Help bar height - promptLines = 1 // Prompt prefix line - minViewport = 3 // Minimum viewport height -) - -// Message represents a conversation message for display. -type Message struct { - Role string // "user", "assistant", "system", "error" - Text string -} - -// TUI is the Bubble Tea model for Koopa terminal interface. -type TUI struct { - // Input (textarea for multi-line support, Shift+Enter for newline) - input textarea.Model - history []string - historyIdx int - - // State - state State - lastCtrlC time.Time - - // Output - spinner spinner.Model - output strings.Builder - viewBuf strings.Builder // Reusable buffer for View() to reduce allocations - messages []Message - - // Scrollable message viewport - viewport viewport.Model - - // Help bar for keyboard shortcuts - help help.Model - keys keyMap - - // Stream management - // Note: No sync.WaitGroup - Bubble Tea's event loop provides synchronization. - // Single union channel with discriminated events simplifies select logic. - streamCancel context.CancelFunc - streamEventCh <-chan streamEvent - - // Dependencies (direct, no interface) - chatFlow *chat.Flow - sessionID string - ctx context.Context - ctxCancel context.CancelFunc // For canceling all operations on exit - - // Dimensions - width int - height int - - // Styles - styles Styles - - // Markdown rendering (nil = graceful degradation to plain text) - markdown *markdownRenderer -} - -// addMessage appends a message and enforces maxMessages bound. -func (t *TUI) addMessage(msg Message) { - t.messages = append(t.messages, msg) - if len(t.messages) > maxMessages { - // Remove oldest messages to stay within bounds - t.messages = t.messages[len(t.messages)-maxMessages:] - } -} - -// New creates a TUI model for chat interaction. -// Returns error if required dependencies are nil. -// -// IMPORTANT: ctx MUST be the same context passed to tea.WithContext() -// to ensure consistent cancellation behavior. -func New(ctx context.Context, flow *chat.Flow, sessionID string) (*TUI, error) { - if flow == nil { - return nil, errors.New("tui.New: flow is required") - } - if ctx == nil { - return nil, errors.New("tui.New: ctx is required") - } - if sessionID == "" { - return nil, errors.New("tui.New: session ID is required") - } - - // Create cancellable context for cleanup on exit - ctx, cancel := context.WithCancel(ctx) - - // Create textarea for multi-line input - // Enter submits, Shift+Enter adds newline (default behavior) - ta := textarea.New() - ta.Placeholder = "Ask anything..." - ta.SetHeight(1) // Single line by default - ta.SetWidth(120) // Wide enough for long text, updated on WindowSizeMsg - ta.MaxWidth = 0 // No max width limit - ta.ShowLineNumbers = false - - // Clean, minimal styling like Claude Code / Gemini CLI - // No background colors, just simple text - cleanStyle := textarea.StyleState{ - Base: lipgloss.NewStyle(), - Text: lipgloss.NewStyle(), - Placeholder: lipgloss.NewStyle().Foreground(lipgloss.Color("240")), // Gray placeholder - Prompt: lipgloss.NewStyle(), - } - ta.SetStyles(textarea.Styles{ - Focused: cleanStyle, - Blurred: cleanStyle, - }) - ta.Focus() - - sp := spinner.New() - sp.Spinner = spinner.Dot - - // Create viewport for scrollable message history. - // Disable built-in keyboard handling — we route keys explicitly - // in handleKey to avoid conflicts with textarea/history navigation. - vp := viewport.New(viewport.WithWidth(80), viewport.WithHeight(20)) - vp.MouseWheelEnabled = true - vp.SoftWrap = true - vp.KeyMap = viewport.KeyMap{} // Disable default key bindings - - h := help.New() - - return &TUI{ - chatFlow: flow, - sessionID: sessionID, - ctx: ctx, - ctxCancel: cancel, - input: ta, - spinner: sp, - viewport: vp, - help: h, - keys: newKeyMap(), - styles: DefaultStyles(), - history: make([]string, 0, maxHistory), - markdown: newMarkdownRenderer(80), - width: 80, // Default width until WindowSizeMsg arrives - }, nil -} - -// Init implements tea.Model. -func (t *TUI) Init() tea.Cmd { - return tea.Batch( - textarea.Blink, - t.spinner.Tick, - t.input.Focus(), // Ensure textarea is focused on startup - ) -} - -// Update implements tea.Model. -// -//nolint:gocognit,gocyclo // Bubble Tea Update requires type switch on all message types -func (t *TUI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case tea.KeyPressMsg: - return t.handleKey(msg) - - case tea.WindowSizeMsg: - t.width = msg.Width - t.height = msg.Height - - // Calculate viewport height: total - input - separators - help - inputHeight := t.input.Height() + promptLines - fixedHeight := separatorLines + inputHeight + helpLines - vpHeight := max(msg.Height-fixedHeight, minViewport) - - t.viewport.SetWidth(msg.Width) - t.viewport.SetHeight(vpHeight) - t.input.SetWidth(msg.Width - 4) // Room for "> " prompt - t.help.SetWidth(msg.Width) - t.markdown.UpdateWidth(msg.Width) - - // Rebuild viewport content with new dimensions - t.rebuildViewportContent() - return t, nil - - case tea.MouseWheelMsg: - // Forward mouse wheel to viewport for scrolling - var cmd tea.Cmd - t.viewport, cmd = t.viewport.Update(msg) - return t, cmd - - case spinner.TickMsg: - var cmd tea.Cmd - t.spinner, cmd = t.spinner.Update(msg) - // Rebuild viewport to update spinner animation - if t.state == StateThinking { - t.rebuildViewportContent() - } - return t, cmd - - case streamStartedMsg: - t.streamCancel = msg.cancel - t.streamEventCh = msg.eventCh - t.state = StateStreaming - t.rebuildViewportContent() - t.viewport.GotoBottom() - return t, listenForStream(msg.eventCh) - - case streamTextMsg: - t.output.WriteString(msg.text) - t.rebuildViewportContent() - t.viewport.GotoBottom() - return t, listenForStream(t.streamEventCh) - - case streamDoneMsg: - t.state = StateInput - - // Cancel context to release timer resources - if t.streamCancel != nil { - t.streamCancel() - t.streamCancel = nil - } - t.streamEventCh = nil - - // Prefer msg.output.Response (complete response from Genkit) over accumulated chunks. - // This handles models that don't stream or send final content only in Output. - finalText := msg.output.Response - if finalText == "" { - // Fallback to accumulated chunks if Response is empty - finalText = t.output.String() - } - - t.addMessage(Message{ - Role: roleAssistant, - Text: finalText, - }) - t.output.Reset() - t.rebuildViewportContent() - t.viewport.GotoBottom() - // Re-focus textarea after stream completes - return t, t.input.Focus() - - case streamErrorMsg: - t.state = StateInput - - // Cancel context to release timer resources - if t.streamCancel != nil { - t.streamCancel() - t.streamCancel = nil - } - t.streamEventCh = nil - - switch { - case errors.Is(msg.err, context.Canceled): - t.addMessage(Message{Role: roleSystem, Text: "(Canceled)"}) - case errors.Is(msg.err, context.DeadlineExceeded): - t.addMessage(Message{Role: roleError, Text: "Query timeout (>5 min). Try a simpler query or break it into steps."}) - default: - t.addMessage(Message{Role: roleError, Text: msg.err.Error()}) - } - t.output.Reset() - t.rebuildViewportContent() - t.viewport.GotoBottom() - // Re-focus textarea after error - return t, t.input.Focus() - } - - var cmd tea.Cmd - t.input, cmd = t.input.Update(msg) - return t, cmd -} - -// View implements tea.Model. -// Uses AltScreen with viewport for scrollable message history. -func (t *TUI) View() tea.View { - t.viewBuf.Reset() - - // Viewport (scrollable message area) - _, _ = t.viewBuf.WriteString(t.viewport.View()) - _, _ = t.viewBuf.WriteString("\n") - - // Separator line above input - _, _ = t.viewBuf.WriteString(t.renderSeparator()) - _, _ = t.viewBuf.WriteString("\n") - - // Input prompt - always show and always accept input - // Users can type while LLM is thinking/streaming (better UX) - _, _ = t.viewBuf.WriteString(t.styles.Prompt.Render("> ")) - _, _ = t.viewBuf.WriteString(t.input.View()) - _, _ = t.viewBuf.WriteString("\n") - - // Separator line below input - _, _ = t.viewBuf.WriteString(t.renderSeparator()) - _, _ = t.viewBuf.WriteString("\n") - - // Help bar (keyboard shortcuts) - _, _ = t.viewBuf.WriteString(t.renderStatusBar()) - - v := tea.NewView(t.viewBuf.String()) - v.AltScreen = true - return v -} - -// rebuildViewportContent reconstructs the viewport content from messages and state. -// Called when messages, streaming output, or state changes. -func (t *TUI) rebuildViewportContent() { - var b strings.Builder - - // Banner (ASCII art) and tips - _, _ = b.WriteString(t.styles.RenderBanner()) - _, _ = b.WriteString("\n") - _, _ = b.WriteString(t.styles.RenderWelcomeTips()) - _, _ = b.WriteString("\n") - - // Messages (already bounded by addMessage) - for _, msg := range t.messages { - switch msg.Role { - case roleUser: - _, _ = b.WriteString(t.styles.User.Render("You> ")) - _, _ = b.WriteString(msg.Text) - case roleAssistant: - _, _ = b.WriteString(t.styles.Assistant.Render("Koopa> ")) - _, _ = b.WriteString(t.markdown.Render(msg.Text)) - case roleSystem: - _, _ = b.WriteString(t.styles.System.Render(msg.Text)) - case roleError: - _, _ = b.WriteString(t.styles.Error.Render("Error: " + msg.Text)) - } - _, _ = b.WriteString("\n\n") - } - - // Current streaming output - if t.state == StateStreaming && t.output.Len() > 0 { - _, _ = b.WriteString(t.styles.Assistant.Render("Koopa> ")) - _, _ = b.WriteString(t.output.String()) - _, _ = b.WriteString("\n\n") - } - - // Thinking indicator - if t.state == StateThinking { - _, _ = b.WriteString(t.spinner.View()) - _, _ = b.WriteString(" Thinking...\n\n") - } - - t.viewport.SetContent(b.String()) -} - -// renderSeparator returns a horizontal line separator. -func (t *TUI) renderSeparator() string { - width := t.width - if width <= 0 { - width = 80 // Default width - } - return t.styles.Separator.Render(strings.Repeat("─", width)) -} - -// renderStatusBar returns state-appropriate keyboard shortcut help. -func (t *TUI) renderStatusBar() string { - var bindings []key.Binding - switch t.state { - case StateInput: - bindings = []key.Binding{ - t.keys.Submit, t.keys.NewLine, t.keys.History, - t.keys.Cancel, t.keys.Quit, t.keys.ScrollUp, - } - case StateThinking, StateStreaming: - bindings = []key.Binding{ - t.keys.EscCancel, t.keys.Cancel, - t.keys.ScrollUp, t.keys.ScrollDown, - } - } - return t.help.ShortHelpView(bindings) -} diff --git a/internal/tui/tui_test.go b/internal/tui/tui_test.go index d362468..a1d7ebd 100644 --- a/internal/tui/tui_test.go +++ b/internal/tui/tui_test.go @@ -12,7 +12,8 @@ import ( tea "charm.land/bubbletea/v2" "go.uber.org/goleak" - "github.com/koopa0/koopa/internal/agent/chat" + "github.com/google/uuid" + "github.com/koopa0/koopa/internal/chat" ) // goleakOptions returns standard goleak options for all TUI tests. @@ -27,8 +28,8 @@ func goleakOptions() []goleak.Option { } } -// newTestTUI creates a TUI with properly initialized components for testing. -func newTestTUI() *TUI { +// newTestModel creates a Model with properly initialized components for testing. +func newTestModel() *Model { ta := textarea.New() ta.SetHeight(3) ta.ShowLineNumbers = false @@ -38,7 +39,7 @@ func newTestTUI() *TUI { vp.SoftWrap = true vp.KeyMap = viewport.KeyMap{} - return &TUI{ + return &Model{ state: StateInput, input: ta, viewport: vp, @@ -53,7 +54,7 @@ func newTestTUI() *TUI { } func TestNew_ErrorOnNilFlow(t *testing.T) { - _, err := New(context.Background(), nil, "test") + _, err := New(context.Background(), nil, uuid.New()) if err == nil { t.Error("Expected error for nil flow") } @@ -64,23 +65,23 @@ func TestNew_ErrorOnNilContext(t *testing.T) { // so we're testing that error is returned for nil context var flow *chat.Flow //lint:ignore SA1012 intentionally testing nil context handling - _, err := New(nil, flow, "test") //nolint:staticcheck + _, err := New(nil, flow, uuid.New()) //nolint:staticcheck if err == nil { t.Error("Expected error for nil context") } } -func TestNew_ErrorOnEmptySessionID(t *testing.T) { - _, err := New(context.Background(), nil, "") +func TestNew_ErrorOnNilSessionID(t *testing.T) { + _, err := New(context.Background(), nil, uuid.Nil) if err == nil { - t.Error("Expected error for empty session ID") + t.Error("Expected error for nil session ID") } } func TestTUI_Init(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() cmd := tui.Init() if cmd == nil { t.Error("Init should return a command (blink + spinner tick)") @@ -96,22 +97,22 @@ func TestTUI_HandleSlashCommands(t *testing.T) { wantExit bool wantMsgs int // number of messages added }{ - {"help", "/help", false, 1}, - {"clear", "/clear", false, 0}, // clears messages - {"exit", "/exit", true, 0}, - {"quit", "/quit", true, 0}, - {"unknown", "/unknown", false, 1}, // error message + {name: "help", cmd: "/help", wantExit: false, wantMsgs: 1}, + {name: "clear", cmd: "/clear", wantExit: false, wantMsgs: 0}, // clears messages + {name: "exit", cmd: "/exit", wantExit: true, wantMsgs: 0}, + {name: "quit", cmd: "/quit", wantExit: true, wantMsgs: 0}, + {name: "unknown", cmd: "/unknown", wantExit: false, wantMsgs: 1}, // error message } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tui := newTestTUI() + tui := newTestModel() // Pre-populate with a message for /clear test tui.messages = []Message{{Role: "user", Text: "hello"}} model, cmd := tui.handleSlashCommand(tt.cmd) - result := model.(*TUI) + result := model.(*Model) if tt.wantExit { if cmd == nil { @@ -135,29 +136,29 @@ func TestTUI_HandleSlashCommands(t *testing.T) { func TestTUI_HistoryNavigation(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() tui.history = []string{"first", "second", "third"} tui.historyIdx = 3 tests := []struct { - delta int - expected string + delta int + want string }{ - {-1, "third"}, - {-1, "second"}, - {-1, "first"}, - {-1, "first"}, // Should stay at first - {1, "second"}, - {1, "third"}, - {1, ""}, // Past end = empty - {1, ""}, // Should stay empty + {delta: -1, want: "third"}, + {delta: -1, want: "second"}, + {delta: -1, want: "first"}, + {delta: -1, want: "first"}, // Should stay at first + {delta: 1, want: "second"}, + {delta: 1, want: "third"}, + {delta: 1, want: ""}, // Past end = empty + {delta: 1, want: ""}, // Should stay empty } for i, tt := range tests { model, _ := tui.navigateHistory(tt.delta) - tui = model.(*TUI) - if tui.input.Value() != tt.expected { - t.Errorf("Step %d: got %q, want %q", i, tui.input.Value(), tt.expected) + tui = model.(*Model) + if tui.input.Value() != tt.want { + t.Errorf("navigateHistory(%d) step %d: got %q, want %q", tt.delta, i, tui.input.Value(), tt.want) } } } @@ -165,11 +166,11 @@ func TestTUI_HistoryNavigation(t *testing.T) { func TestTUI_CtrlC_ClearsInput(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() tui.input.SetValue("some input") model, _ := tui.handleCtrlC() - result := model.(*TUI) + result := model.(*Model) if result.input.Value() != "" { t.Error("First Ctrl+C should clear input") @@ -179,7 +180,7 @@ func TestTUI_CtrlC_ClearsInput(t *testing.T) { func TestTUI_DoubleCtrlC_Exits(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() tui.lastCtrlC = time.Now() _, cmd := tui.handleCtrlC() @@ -192,7 +193,7 @@ func TestTUI_DoubleCtrlC_Exits(t *testing.T) { func TestTUI_Update_KeyPress(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() tui.input.SetValue("test") // Simulate Ctrl+C (should clear input) @@ -200,7 +201,7 @@ func TestTUI_Update_KeyPress(t *testing.T) { msg := tea.KeyPressMsg(key) model, _ := tui.Update(msg) - result := model.(*TUI) + result := model.(*Model) if result.input.Value() != "" { t.Error("Ctrl+C should clear input") @@ -210,7 +211,7 @@ func TestTUI_Update_KeyPress(t *testing.T) { func TestTUI_View_ReturnsAltScreen(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() view := tui.View() if !view.AltScreen { @@ -224,7 +225,7 @@ func TestTUI_View_ReturnsAltScreen(t *testing.T) { func TestTUI_RebuildViewportContent(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() tui.addMessage(Message{Role: roleUser, Text: "hello"}) tui.addMessage(Message{Role: roleAssistant, Text: "world"}) @@ -239,7 +240,7 @@ func TestTUI_RebuildViewportContent(t *testing.T) { func TestTUI_ViewportScrolling(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() // Add enough messages to exceed viewport height for i := 0; i < 50; i++ { @@ -262,12 +263,12 @@ func TestTUI_StreamMessageTypes(t *testing.T) { t.Run("streamTextMsg", func(t *testing.T) { eventCh := make(chan streamEvent, 1) - tui := newTestTUI() + tui := newTestModel() tui.state = StateStreaming tui.streamEventCh = eventCh model, _ := tui.Update(streamTextMsg{text: "Hello"}) - result := model.(*TUI) + result := model.(*Model) if result.output.String() != "Hello" { t.Errorf("Expected 'Hello', got %q", result.output.String()) @@ -275,12 +276,12 @@ func TestTUI_StreamMessageTypes(t *testing.T) { }) t.Run("streamDoneMsg", func(t *testing.T) { - tui := newTestTUI() + tui := newTestModel() tui.state = StateStreaming _, _ = tui.output.WriteString("Hello World") model, _ := tui.Update(streamDoneMsg{output: chat.Output{Response: "Hello World"}}) - result := model.(*TUI) + result := model.(*Model) if result.state != StateInput { t.Error("Should return to StateInput after stream done") @@ -294,11 +295,11 @@ func TestTUI_StreamMessageTypes(t *testing.T) { }) t.Run("streamErrorMsg", func(t *testing.T) { - tui := newTestTUI() + tui := newTestModel() tui.state = StateStreaming model, _ := tui.Update(streamErrorMsg{err: context.Canceled}) - result := model.(*TUI) + result := model.(*Model) if result.state != StateInput { t.Error("Should return to StateInput after error") @@ -399,7 +400,7 @@ func TestListenForStream_UnionChannel(t *testing.T) { func TestTUI_AddMessage_BoundsEnforcement(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() // Add more than maxMessages for i := 0; i < maxMessages+50; i++ { @@ -418,7 +419,7 @@ func TestTUI_AddMessage_BoundsEnforcement(t *testing.T) { func TestTUI_HandleSubmit_AddsToHistory(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() tui.ctx = context.Background() tui.input.SetValue("test query") @@ -445,7 +446,7 @@ func TestTUI_HandleSubmit_AddsToHistory(t *testing.T) { func TestTUI_HandleSubmit_HistoryBounds(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() // Pre-fill history to max for i := 0; i < maxHistory; i++ { @@ -474,17 +475,17 @@ func TestMarkdownRenderer_UpdateWidth(t *testing.T) { t.Run("creates renderer with correct width", func(t *testing.T) { mr := newMarkdownRenderer(100) if mr == nil { - t.Fatal("Failed to create markdown renderer") + t.Fatal("newMarkdownRenderer() returned nil") } if mr.width != 100 { - t.Errorf("Expected width 100, got %d", mr.width) + t.Errorf("newMarkdownRenderer(100).width = %d, want 100", mr.width) } }) t.Run("UpdateWidth changes width", func(t *testing.T) { mr := newMarkdownRenderer(80) if mr == nil { - t.Fatal("Failed to create markdown renderer") + t.Fatal("newMarkdownRenderer() returned nil") } updated := mr.UpdateWidth(120) @@ -492,14 +493,14 @@ func TestMarkdownRenderer_UpdateWidth(t *testing.T) { t.Error("UpdateWidth should return true when width changes") } if mr.width != 120 { - t.Errorf("Expected width 120, got %d", mr.width) + t.Errorf("UpdateWidth(120) width = %d, want 120", mr.width) } }) t.Run("UpdateWidth no-op for same width", func(t *testing.T) { mr := newMarkdownRenderer(80) if mr == nil { - t.Fatal("Failed to create markdown renderer") + t.Fatal("newMarkdownRenderer() returned nil") } updated := mr.UpdateWidth(80) @@ -519,7 +520,7 @@ func TestMarkdownRenderer_UpdateWidth(t *testing.T) { t.Run("UpdateWidth handles invalid width", func(t *testing.T) { mr := newMarkdownRenderer(80) if mr == nil { - t.Fatal("Failed to create markdown renderer") + t.Fatal("newMarkdownRenderer() returned nil") } updated := mr.UpdateWidth(0) @@ -540,7 +541,7 @@ func TestMarkdownRenderer_Render(t *testing.T) { t.Run("renders markdown", func(t *testing.T) { mr := newMarkdownRenderer(80) if mr == nil { - t.Fatal("Failed to create markdown renderer") + t.Fatal("newMarkdownRenderer() returned nil") } result := mr.Render("**bold**") @@ -562,7 +563,7 @@ func TestMarkdownRenderer_Render(t *testing.T) { func TestTUI_Cleanup(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() // Setup stream state eventCh := make(chan streamEvent, 1) @@ -582,7 +583,7 @@ func TestTUI_Cleanup(t *testing.T) { func TestTUI_CancelStream(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() canceled := false tui.streamCancel = func() { canceled = true } @@ -600,14 +601,14 @@ func TestTUI_CancelStream(t *testing.T) { func TestTUI_CtrlC_CancelsStream(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() tui.state = StateStreaming canceled := false tui.streamCancel = func() { canceled = true } model, _ := tui.handleCtrlC() - result := model.(*TUI) + result := model.(*Model) if !canceled { t.Error("Ctrl+C during streaming should cancel") @@ -623,7 +624,7 @@ func TestTUI_CtrlC_CancelsStream(t *testing.T) { func TestTUI_RenderStatusBar_StateInput(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() tui.state = StateInput bar := tui.renderStatusBar() @@ -635,7 +636,7 @@ func TestTUI_RenderStatusBar_StateInput(t *testing.T) { func TestTUI_RenderStatusBar_StateStreaming(t *testing.T) { defer goleak.VerifyNone(t, goleakOptions()...) - tui := newTestTUI() + tui := newTestModel() tui.state = StateStreaming bar := tui.renderStatusBar() diff --git a/internal/tui/update.go b/internal/tui/update.go new file mode 100644 index 0000000..1f65575 --- /dev/null +++ b/internal/tui/update.go @@ -0,0 +1,132 @@ +package tui + +import ( + "context" + "errors" + + "charm.land/bubbles/v2/spinner" + tea "charm.land/bubbletea/v2" +) + +// Update implements tea.Model. +// +//nolint:gocognit,gocyclo // Bubble Tea Update requires type switch on all message types +func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyPressMsg: + return m.handleKey(msg) + + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + + // Calculate viewport height: total - input - separators - help + inputHeight := m.input.Height() + promptLines + fixedHeight := separatorLines + inputHeight + helpLines + vpHeight := max(msg.Height-fixedHeight, minViewport) + + m.viewport.SetWidth(msg.Width) + m.viewport.SetHeight(vpHeight) + m.input.SetWidth(msg.Width - 4) // Room for "> " prompt + m.help.SetWidth(msg.Width) + m.markdown.UpdateWidth(msg.Width) + + // Rebuild viewport content with new dimensions + m.rebuildViewportContent() + return m, nil + + case tea.MouseWheelMsg: + // Forward mouse wheel to viewport for scrolling + var cmd tea.Cmd + m.viewport, cmd = m.viewport.Update(msg) + return m, cmd + + case spinner.TickMsg: + var cmd tea.Cmd + m.spinner, cmd = m.spinner.Update(msg) + // Rebuild viewport to update spinner animation during thinking or tool execution + if m.state == StateThinking || (m.state == StateStreaming && m.toolStatus != "") { + m.rebuildViewportContent() + } + return m, cmd + + case streamStartedMsg: + m.streamCancel = msg.cancel + m.streamEventCh = msg.eventCh + m.state = StateStreaming + m.rebuildViewportContent() + m.viewport.GotoBottom() + return m, listenForStream(msg.eventCh) + + case streamToolMsg: + m.toolStatus = msg.status + m.rebuildViewportContent() + m.viewport.GotoBottom() + return m, listenForStream(m.streamEventCh) + + case streamTextMsg: + m.toolStatus = "" // Clear tool status when text arrives + m.output.WriteString(msg.text) + m.rebuildViewportContent() + m.viewport.GotoBottom() + return m, listenForStream(m.streamEventCh) + + case streamDoneMsg: + m.state = StateInput + m.toolStatus = "" + + // Cancel context to release timer resources + if m.streamCancel != nil { + m.streamCancel() + m.streamCancel = nil + } + m.streamEventCh = nil + + // Prefer msg.output.Response (complete response from Genkit) over accumulated chunks. + // This handles models that don't stream or send final content only in Output. + finalText := msg.output.Response + if finalText == "" { + // Fallback to accumulated chunks if Response is empty + finalText = m.output.String() + } + + m.addMessage(Message{ + Role: roleAssistant, + Text: finalText, + }) + m.output.Reset() + m.rebuildViewportContent() + m.viewport.GotoBottom() + // Re-focus textarea after stream completes + return m, m.input.Focus() + + case streamErrorMsg: + m.state = StateInput + m.toolStatus = "" + + // Cancel context to release timer resources + if m.streamCancel != nil { + m.streamCancel() + m.streamCancel = nil + } + m.streamEventCh = nil + + switch { + case errors.Is(msg.err, context.Canceled): + m.addMessage(Message{Role: roleSystem, Text: "(Canceled)"}) + case errors.Is(msg.err, context.DeadlineExceeded): + m.addMessage(Message{Role: roleError, Text: "Query timeout (>5 min). Try a simpler query or break it into steps."}) + default: + m.addMessage(Message{Role: roleError, Text: msg.err.Error()}) + } + m.output.Reset() + m.rebuildViewportContent() + m.viewport.GotoBottom() + // Re-focus textarea after error + return m, m.input.Focus() + } + + var cmd tea.Cmd + m.input, cmd = m.input.Update(msg) + return m, cmd +} diff --git a/internal/tui/view.go b/internal/tui/view.go new file mode 100644 index 0000000..0567309 --- /dev/null +++ b/internal/tui/view.go @@ -0,0 +1,118 @@ +package tui + +import ( + "strings" + + "charm.land/bubbles/v2/key" + tea "charm.land/bubbletea/v2" +) + +// View implements tea.Model. +// Uses AltScreen with viewport for scrollable message history. +func (m *Model) View() tea.View { + m.viewBuf.Reset() + + // Viewport (scrollable message area) + _, _ = m.viewBuf.WriteString(m.viewport.View()) + _, _ = m.viewBuf.WriteString("\n") + + // Separator line above input + _, _ = m.viewBuf.WriteString(m.renderSeparator()) + _, _ = m.viewBuf.WriteString("\n") + + // Input prompt - always show and always accept input + // Users can type while LLM is thinking/streaming (better UX) + _, _ = m.viewBuf.WriteString(m.styles.Prompt.Render("> ")) + _, _ = m.viewBuf.WriteString(m.input.View()) + _, _ = m.viewBuf.WriteString("\n") + + // Separator line below input + _, _ = m.viewBuf.WriteString(m.renderSeparator()) + _, _ = m.viewBuf.WriteString("\n") + + // Help bar (keyboard shortcuts) + _, _ = m.viewBuf.WriteString(m.renderStatusBar()) + + v := tea.NewView(m.viewBuf.String()) + v.AltScreen = true + return v +} + +// rebuildViewportContent reconstructs the viewport content from messages and state. +// Called when messages, streaming output, or state changes. +func (m *Model) rebuildViewportContent() { + var b strings.Builder + + // Banner (ASCII art) and tips + _, _ = b.WriteString(m.styles.RenderBanner()) + _, _ = b.WriteString("\n") + _, _ = b.WriteString(m.styles.RenderWelcomeTips()) + _, _ = b.WriteString("\n") + + // Messages (already bounded by addMessage) + for _, msg := range m.messages { + switch msg.Role { + case roleUser: + _, _ = b.WriteString(m.styles.User.Render("You> ")) + _, _ = b.WriteString(msg.Text) + case roleAssistant: + _, _ = b.WriteString(m.styles.Assistant.Render("Koopa> ")) + _, _ = b.WriteString(m.markdown.Render(msg.Text)) + case roleSystem: + _, _ = b.WriteString(m.styles.System.Render(msg.Text)) + case roleError: + _, _ = b.WriteString(m.styles.Error.Render("Error: " + msg.Text)) + } + _, _ = b.WriteString("\n\n") + } + + // Current streaming output + if m.state == StateStreaming && m.output.Len() > 0 { + _, _ = b.WriteString(m.styles.Assistant.Render("Koopa> ")) + _, _ = b.WriteString(m.output.String()) + _, _ = b.WriteString("\n\n") + } + + // Tool status indicator (shown during streaming when a tool is executing) + if m.state == StateStreaming && m.toolStatus != "" { + _, _ = b.WriteString(m.spinner.View()) + _, _ = b.WriteString(" ") + _, _ = b.WriteString(m.styles.System.Render(m.toolStatus)) + _, _ = b.WriteString("\n\n") + } + + // Thinking indicator + if m.state == StateThinking { + _, _ = b.WriteString(m.spinner.View()) + _, _ = b.WriteString(" Thinking...\n\n") + } + + m.viewport.SetContent(b.String()) +} + +// renderSeparator returns a horizontal line separator. +func (m *Model) renderSeparator() string { + width := m.width + if width <= 0 { + width = 80 // Default width + } + return m.styles.Separator.Render(strings.Repeat("─", width)) +} + +// renderStatusBar returns state-appropriate keyboard shortcut help. +func (m *Model) renderStatusBar() string { + var bindings []key.Binding + switch m.state { + case StateInput: + bindings = []key.Binding{ + m.keys.Submit, m.keys.NewLine, m.keys.History, + m.keys.Cancel, m.keys.Quit, m.keys.ScrollUp, + } + case StateThinking, StateStreaming: + bindings = []key.Binding{ + m.keys.EscCancel, m.keys.Cancel, + m.keys.ScrollUp, m.keys.ScrollDown, + } + } + return m.help.ShortHelpView(bindings) +} diff --git a/prompts/koopa.prompt b/prompts/koopa.prompt index 8dfcdfc..08e565e 100644 --- a/prompts/koopa.prompt +++ b/prompts/koopa.prompt @@ -1,4 +1,5 @@ --- +# Model below is a dotprompt default; overridden at runtime by chat.Config.ModelName. model: googleai/gemini-2.5-flash config: temperature: 0.7