diff --git a/.golangci.yml b/.golangci.yml
index f52fd17..7875193 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -381,7 +381,7 @@ linters:
- contextcheck
# contextcheck: flow.go uses InvocationContext which wraps context
- - path: internal/agent/chat/flow\.go
+ - path: internal/chat/flow\.go
linters:
- contextcheck
diff --git a/cmd/addr.go b/cmd/addr.go
new file mode 100644
index 0000000..5893230
--- /dev/null
+++ b/cmd/addr.go
@@ -0,0 +1,74 @@
+package cmd
+
+import (
+ "flag"
+ "fmt"
+ "net"
+ "os"
+ "strconv"
+ "strings"
+)
+
+// parseServeAddr parses and validates the server address from command line arguments.
+// Uses flag.FlagSet for standard Go flag parsing, supporting:
+// - koopa serve :8080 (positional)
+// - koopa serve --addr :8080 (flag)
+// - koopa serve -addr :8080 (single dash)
+func parseServeAddr() (string, error) {
+ const defaultAddr = "127.0.0.1:3400"
+
+ serveFlags := flag.NewFlagSet("serve", flag.ContinueOnError)
+ serveFlags.SetOutput(os.Stderr)
+
+ addr := serveFlags.String("addr", defaultAddr, "Server address (host:port)")
+
+ args := []string{}
+ if len(os.Args) > 2 {
+ args = os.Args[2:]
+ }
+
+ // Check for positional argument first (koopa serve :8080)
+ if len(args) > 0 && !strings.HasPrefix(args[0], "-") {
+ *addr = args[0]
+ args = args[1:]
+ }
+
+ if err := serveFlags.Parse(args); err != nil {
+ return "", fmt.Errorf("parsing serve flags: %w", err)
+ }
+
+ if err := validateAddr(*addr); err != nil {
+ return "", fmt.Errorf("invalid address %q: %w", *addr, err)
+ }
+
+ return *addr, nil
+}
+
+// validateAddr validates the server address format.
+func validateAddr(addr string) error {
+ host, port, err := net.SplitHostPort(addr)
+ if err != nil {
+ return fmt.Errorf("must be in host:port format: %w", err)
+ }
+
+ if host != "" && host != "localhost" {
+ if ip := net.ParseIP(host); ip == nil {
+ if strings.ContainsAny(host, " \t\n") {
+ return fmt.Errorf("invalid host: %s", host)
+ }
+ }
+ }
+
+ if port == "" {
+ return fmt.Errorf("port is required")
+ }
+ portNum, err := strconv.Atoi(port)
+ if err != nil {
+ return fmt.Errorf("port must be numeric: %w", err)
+ }
+ if portNum < 0 || portNum > 65535 {
+ return fmt.Errorf("port must be 0-65535 (0 = auto-assign), got %d", portNum)
+ }
+
+ return nil
+}
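
Taken together, the rules above mean the default applies only when no argument is given, a positional argument wins over the default, and any remaining flags are still parsed. A small illustrative sketch (the helper name exampleParseServeAddr is hypothetical; os.Args is overridden only because parseServeAddr reads it directly, and fmt/os are the same imports addr.go already uses):

// Illustrative sketch only; not part of the change itself.
func exampleParseServeAddr() {
	for _, argv := range [][]string{
		{"koopa", "serve"},                           // -> "127.0.0.1:3400" (default)
		{"koopa", "serve", ":8080"},                  // positional -> ":8080"
		{"koopa", "serve", "--addr", "0.0.0.0:9090"}, // flag -> "0.0.0.0:9090"
	} {
		os.Args = argv
		addr, err := parseServeAddr()
		fmt.Println(addr, err)
	}
}
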
diff --git a/cmd/addr_test.go b/cmd/addr_test.go
new file mode 100644
index 0000000..fe12772
--- /dev/null
+++ b/cmd/addr_test.go
@@ -0,0 +1,68 @@
+package cmd
+
+import "testing"
+
+func TestValidateAddr(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ addr string
+ wantErr bool
+ }{
+ // Valid addresses
+ {name: "port only", addr: ":8080", wantErr: false},
+ {name: "localhost", addr: "localhost:3400", wantErr: false},
+ {name: "loopback", addr: "127.0.0.1:3400", wantErr: false},
+ {name: "all interfaces", addr: "0.0.0.0:80", wantErr: false},
+ {name: "ipv6 loopback", addr: "[::1]:8080", wantErr: false},
+ {name: "port zero", addr: ":0", wantErr: false},
+ {name: "port max", addr: ":65535", wantErr: false},
+ {name: "hostname", addr: "myhost:9090", wantErr: false},
+
+ // Invalid: bad format
+ {name: "no port", addr: "localhost", wantErr: true},
+ {name: "port alone", addr: "8080", wantErr: true},
+ {name: "empty string", addr: "", wantErr: true},
+
+ // Invalid: bad port
+ {name: "port non-numeric", addr: ":abc", wantErr: true},
+ {name: "port negative", addr: ":-1", wantErr: true},
+ {name: "port too high", addr: ":65536", wantErr: true},
+ {name: "port empty after colon", addr: "localhost:", wantErr: true},
+
+ // Invalid: bad host
+ {name: "host with space", addr: "my host:8080", wantErr: true},
+ {name: "host with tab", addr: "my\thost:8080", wantErr: true},
+ {name: "host with newline", addr: "my\nhost:8080", wantErr: true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ err := validateAddr(tt.addr)
+ if tt.wantErr && err == nil {
+ t.Errorf("validateAddr(%q) = nil, want error", tt.addr)
+ }
+ if !tt.wantErr && err != nil {
+ t.Errorf("validateAddr(%q) = %v, want nil", tt.addr, err)
+ }
+ })
+ }
+}
+
+func FuzzValidateAddr(f *testing.F) {
+ f.Add(":8080")
+ f.Add("localhost:3400")
+ f.Add("127.0.0.1:80")
+ f.Add("")
+ f.Add("abc")
+ f.Add(":0")
+ f.Add(":99999")
+ f.Add("[::1]:8080")
+ f.Add("host with space:80")
+
+ f.Fuzz(func(t *testing.T, addr string) {
+ _ = validateAddr(addr) // must not panic
+ })
+}
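
Given the seed corpus above, the fuzz target can be exercised locally with the standard Go fuzzing driver; a typical invocation (the time budget is only an example) is:

    go test -run='^$' -fuzz=FuzzValidateAddr -fuzztime=30s ./cmd
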
diff --git a/cmd/cli.go b/cmd/cli.go
index fcfede2..6356ebc 100644
--- a/cmd/cli.go
+++ b/cmd/cli.go
@@ -2,7 +2,6 @@ package cmd
import (
"context"
- "errors"
"fmt"
"log/slog"
"os/signal"
@@ -11,8 +10,8 @@ import (
tea "charm.land/bubbletea/v2"
"github.com/koopa0/koopa/internal/app"
+ "github.com/koopa0/koopa/internal/chat"
"github.com/koopa0/koopa/internal/config"
- "github.com/koopa0/koopa/internal/session"
"github.com/koopa0/koopa/internal/tui"
)
@@ -20,63 +19,42 @@ import (
func runCLI() error {
cfg, err := config.Load()
if err != nil {
- return err
+ return fmt.Errorf("loading config: %w", err)
}
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
- runtime, err := app.NewChatRuntime(ctx, cfg)
+ a, err := app.Setup(ctx, cfg)
if err != nil {
- return fmt.Errorf("failed to initialize runtime: %w", err)
+ return fmt.Errorf("initializing application: %w", err)
}
defer func() {
- if closeErr := runtime.Close(); closeErr != nil {
- slog.Warn("runtime close error", "error", closeErr)
+ if closeErr := a.Close(); closeErr != nil {
+ slog.Warn("shutdown error", "error", closeErr)
}
}()
- sessionID, err := getOrCreateSessionID(ctx, runtime.App.SessionStore, cfg)
+ agent, err := a.CreateAgent()
if err != nil {
- return fmt.Errorf("failed to get session: %w", err)
+ return fmt.Errorf("creating agent: %w", err)
}
- model, err := tui.New(ctx, runtime.Flow, sessionID)
- if err != nil {
- return fmt.Errorf("failed to create TUI: %w", err)
- }
- program := tea.NewProgram(model, tea.WithContext(ctx))
-
- if _, err = program.Run(); err != nil {
- return fmt.Errorf("TUI exited: %w", err)
- }
- return nil
-}
+ flow := chat.NewFlow(a.Genkit, agent)
-// getOrCreateSessionID returns a valid session ID, creating a new session if needed.
-func getOrCreateSessionID(ctx context.Context, store *session.Store, cfg *config.Config) (string, error) {
- currentID, err := session.LoadCurrentSessionID()
+ sessionID, err := a.SessionStore.ResolveCurrentSession(ctx)
if err != nil {
- return "", fmt.Errorf("failed to load session: %w", err)
+ return fmt.Errorf("resolving session: %w", err)
}
- if currentID != nil {
- if _, err = store.Session(ctx, *currentID); err == nil {
- return currentID.String(), nil
- }
- if !errors.Is(err, session.ErrSessionNotFound) {
- return "", fmt.Errorf("failed to validate session: %w", err)
- }
- }
-
- newSess, err := store.CreateSession(ctx, "New Session", cfg.ModelName, "You are a helpful assistant.")
+ model, err := tui.New(ctx, flow, sessionID)
if err != nil {
- return "", fmt.Errorf("failed to create session: %w", err)
+ return fmt.Errorf("creating TUI: %w", err)
}
+ program := tea.NewProgram(model, tea.WithContext(ctx))
- if err := session.SaveCurrentSessionID(newSess.ID); err != nil {
- slog.Warn("failed to save session state", "error", err)
+ if _, err = program.Run(); err != nil {
+ return fmt.Errorf("running TUI: %w", err)
}
-
- return newSess.ID.String(), nil
+ return nil
}
diff --git a/cmd/cmd.go b/cmd/cmd.go
new file mode 100644
index 0000000..3dfb5f5
--- /dev/null
+++ b/cmd/cmd.go
@@ -0,0 +1,76 @@
+// Package cmd provides CLI commands for Koopa.
+//
+// Commands:
+// - cli: Interactive terminal chat with Bubble Tea TUI
+// - serve: HTTP API server with SSE streaming
+// - mcp: Model Context Protocol server for IDE integration
+//
+// Signal handling and graceful shutdown are implemented
+// for all commands via context cancellation.
+package cmd
+
+import (
+ "fmt"
+ "log/slog"
+ "os"
+)
+
+// Execute is the main entry point for the Koopa CLI application.
+func Execute() error {
+ // Initialize logger once at entry point
+ level := slog.LevelInfo
+ if os.Getenv("DEBUG") != "" {
+ level = slog.LevelDebug
+ }
+ slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})))
+
+ if len(os.Args) < 2 {
+ runHelp()
+ return nil
+ }
+
+ switch os.Args[1] {
+ case "cli":
+ return runCLI()
+ case "serve":
+ return runServe()
+ case "mcp":
+ return runMCP()
+ case "version", "--version", "-v":
+ runVersion()
+ return nil
+ case "help", "--help", "-h":
+ runHelp()
+ return nil
+ default:
+ return fmt.Errorf("unknown command: %s", os.Args[1])
+ }
+}
+
+// runHelp displays the help message.
+func runHelp() {
+ fmt.Println("Koopa - Your terminal AI personal assistant")
+ fmt.Println()
+ fmt.Println("Usage:")
+ fmt.Println(" koopa cli Start interactive chat mode")
+ fmt.Println(" koopa serve [addr] Start HTTP API server (default: 127.0.0.1:3400)")
+ fmt.Println(" koopa mcp Start MCP server (for Claude Desktop/Cursor)")
+ fmt.Println(" koopa --version Show version information")
+ fmt.Println(" koopa --help Show this help")
+ fmt.Println()
+ fmt.Println("CLI Commands (in interactive mode):")
+ fmt.Println(" /help Show available commands")
+ fmt.Println(" /version Show version")
+ fmt.Println(" /clear Clear conversation history")
+ fmt.Println(" /exit, /quit Exit Koopa")
+ fmt.Println()
+ fmt.Println("Shortcuts:")
+ fmt.Println(" Ctrl+D Exit Koopa")
+ fmt.Println(" Ctrl+C Cancel current input")
+ fmt.Println()
+ fmt.Println("Environment Variables:")
+ fmt.Println(" GEMINI_API_KEY Required: Gemini API key")
+ fmt.Println(" DEBUG Optional: Enable debug logging")
+ fmt.Println()
+ fmt.Println("Learn more: https://github.com/koopa0/koopa")
+}
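
Execute returns an error rather than exiting, so the process exit code is decided by the caller. A minimal sketch of an entry point that forwards to it (illustrative; the repository's actual main.go may differ):

package main

import (
	"fmt"
	"os"

	"github.com/koopa0/koopa/cmd"
)

func main() {
	if err := cmd.Execute(); err != nil {
		fmt.Fprintln(os.Stderr, "Error:", err)
		os.Exit(1)
	}
}
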
diff --git a/cmd/e2e_test.go b/cmd/e2e_test.go
index 2e3dd93..b3bf767 100644
--- a/cmd/e2e_test.go
+++ b/cmd/e2e_test.go
@@ -9,7 +9,7 @@ import (
"os"
"os/exec"
"path/filepath"
- goruntime "runtime" // Alias to avoid confusion with app.ChatRuntime
+ goruntime "runtime" // Alias to avoid conflict with runtime package name
"strings"
"testing"
"time"
@@ -101,7 +101,7 @@ func findOrBuildKoopa(t *testing.T) string {
cmd := exec.Command("go", "build", "-o", binName, ".")
cmd.Dir = projectRoot
if output, err := cmd.CombinedOutput(); err != nil {
- t.Fatalf("Failed to build koopa: %v\nOutput: %s", err, output)
+ t.Fatalf("go build error: %v\nOutput: %s", err, output)
}
return koopaBin
@@ -141,14 +141,14 @@ func TestE2E_VersionCommand(t *testing.T) {
output, err := ctx.runKoopaCommand(shortTimeout, "version")
if err != nil {
- t.Fatalf("version command unexpected error: %v", err)
+ t.Fatalf("running version command: %v", err)
}
if !strings.Contains(output, "Koopa") {
t.Errorf("version command output = %q, want to contain %q", output, "Koopa")
}
- if !strings.Contains(output, "v0.") {
- t.Errorf("version command output = %q, want to contain %q", output, "v0.")
+ if !strings.Contains(output, "v") {
+ t.Errorf("version command output = %q, want to contain %q", output, "v")
}
}
@@ -159,7 +159,7 @@ func TestE2E_ErrorRecovery(t *testing.T) {
t.Run("help command works", func(t *testing.T) {
output, err := ctx.runKoopaCommand(shortTimeout, "help")
if err != nil {
- t.Errorf("help command unexpected error: %v", err)
+ t.Errorf("running help command: %v", err)
}
if !strings.Contains(strings.ToLower(output), "koopa") {
t.Errorf("help command output = %q, want to contain %q", output, "koopa")
@@ -178,38 +178,10 @@ func TestE2E_ErrorRecovery(t *testing.T) {
// Version command should still work without API key
if err != nil {
- t.Errorf("version command without API key unexpected error: %v", err)
+ t.Errorf("running version command without API key: %v", err)
}
if !strings.Contains(output, "Koopa") {
t.Errorf("version command output = %q, want to contain %q", output, "Koopa")
}
})
}
-
-// TestE2E_IntegrationTestHelper verifies E2E test infrastructure
-func TestE2E_IntegrationTestHelper(t *testing.T) {
- ctx := setupE2ETest(t)
-
- // Verify binary exists
- if _, err := os.Stat(ctx.koopaBin); err != nil {
- t.Errorf("koopa binary should exist at %q, but got error: %v", ctx.koopaBin, err)
- }
-
- // Verify working directory
- if info, err := os.Stat(ctx.workDir); err != nil || !info.IsDir() {
- t.Errorf("working directory should exist at %q, but got error: %v", ctx.workDir, err)
- }
-
- // Verify environment
- if ctx.databaseURL == "" {
- t.Error("DATABASE_URL should be set")
- }
- if ctx.apiKey == "" {
- t.Error("GEMINI_API_KEY should be set")
- }
-
- t.Logf("E2E test infrastructure:")
- t.Logf(" Binary: %s", ctx.koopaBin)
- t.Logf(" WorkDir: %s", ctx.workDir)
- t.Logf(" Database: %s", ctx.databaseURL)
-}
diff --git a/cmd/execute.go b/cmd/execute.go
deleted file mode 100644
index 9bf77eb..0000000
--- a/cmd/execute.go
+++ /dev/null
@@ -1,58 +0,0 @@
-// Package cmd provides CLI commands for Koopa.
-//
-// Commands:
-// - cli: Interactive terminal chat with Bubble Tea TUI
-// - serve: HTTP API server with SSE streaming
-// - mcp: Model Context Protocol server for IDE integration
-//
-// Signal handling and graceful shutdown are implemented
-// for all commands via context cancellation.
-package cmd
-
-import (
- "fmt"
- "log/slog"
- "os"
-
- "github.com/koopa0/koopa/internal/log"
-)
-
-// Version is set at build time via ldflags:
-//
-// go build -ldflags "-X github.com/koopa0/koopa/cmd.Version=1.0.0"
-//
-// Default value "dev" indicates a development build.
-var Version = "dev"
-
-// Execute is the main entry point for the Koopa CLI application.
-func Execute() error {
- // Initialize logger once at entry point
- level := slog.LevelInfo
- if os.Getenv("DEBUG") != "" {
- level = slog.LevelDebug
- }
- slog.SetDefault(log.New(log.Config{Level: level}))
-
- if len(os.Args) < 2 {
- runHelp()
- return nil
- }
-
- switch os.Args[1] {
- case "cli":
- return runCLI()
- case "serve":
- return runServe()
- case "mcp":
- return runMCP()
- case "version", "--version", "-v":
- runVersion()
- return nil
- case "help", "--help", "-h":
- runHelp()
- return nil
- default:
- fmt.Fprintf(os.Stderr, "Error: unknown command %q\n", os.Args[1])
- return fmt.Errorf("unknown command: %s", os.Args[1])
- }
-}
diff --git a/cmd/integration_test.go b/cmd/integration_test.go
index e2e8618..0b0deef 100644
--- a/cmd/integration_test.go
+++ b/cmd/integration_test.go
@@ -5,133 +5,81 @@ package cmd
import (
"context"
- "fmt"
"os"
"path/filepath"
- "runtime"
"testing"
"time"
- "github.com/koopa0/koopa/internal/agent/chat"
+ "github.com/google/uuid"
"github.com/koopa0/koopa/internal/app"
+ "github.com/koopa0/koopa/internal/chat"
"github.com/koopa0/koopa/internal/config"
+ "github.com/koopa0/koopa/internal/testutil"
"github.com/koopa0/koopa/internal/tui"
)
-// findProjectRoot finds the project root directory by looking for go.mod.
-func findProjectRoot() (string, error) {
- _, filename, _, ok := runtime.Caller(0)
- if !ok {
- return "", fmt.Errorf("runtime.Caller failed to get caller info")
- }
-
- dir := filepath.Dir(filename)
- for {
- goModPath := filepath.Join(dir, "go.mod")
- if _, err := os.Stat(goModPath); err == nil {
- return dir, nil
- }
- parent := filepath.Dir(dir)
- if parent == dir {
- return "", fmt.Errorf("go.mod not found in any parent directory of %s", filename)
- }
- dir = parent
- }
-}
-
-// TestTUI_Integration tests the TUI can be created with real runtime.
-// Note: Bubble Tea TUI cannot be fully tested without a real TTY.
-// These tests verify initialization and component wiring.
-func TestTUI_Integration(t *testing.T) {
- if os.Getenv("GEMINI_API_KEY") == "" {
- t.Skip("GEMINI_API_KEY not set - skipping integration test")
- }
+// setupApp is a test helper that creates an App instance.
+func setupApp(t *testing.T) *app.App {
+ t.Helper()
- // Reset Flow singleton for test isolation
chat.ResetFlowForTesting()
cfg, err := config.Load()
if err != nil {
- t.Fatalf("Failed to load config: %v", err)
+ t.Fatalf("config.Load() error: %v", err)
}
- // Set absolute path for prompts directory (required for tests running from different directories)
- projectRoot, err := findProjectRoot()
+ projectRoot, err := testutil.FindProjectRoot()
if err != nil {
- t.Fatalf("Failed to find project root: %v", err)
+ t.Fatalf("testutil.FindProjectRoot() error: %v", err)
}
cfg.PromptDir = filepath.Join(projectRoot, "prompts")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
- defer cancel()
- // Initialize runtime
- runtime, err := app.NewChatRuntime(ctx, cfg)
+ a, err := app.Setup(ctx, cfg)
if err != nil {
- t.Fatalf("Failed to initialize runtime: %v", err)
+ cancel()
+ t.Fatalf("app.Setup() error: %v", err)
}
+
t.Cleanup(func() {
- if err := runtime.Close(); err != nil {
- t.Logf("runtime close error: %v", err)
+ if err := a.Close(); err != nil {
+ t.Logf("app close error: %v", err)
}
+ cancel()
})
- // Verify TUI can be created with real Flow
- tuiModel, err := tui.New(ctx, runtime.Flow, "test-session-id")
- if err != nil {
- t.Fatalf("Failed to create TUI: %v", err)
- }
-
- // Verify Init returns a command
- cmd := tuiModel.Init()
- if cmd == nil {
- t.Error("Init should return a command (blink + spinner)")
- }
+ return a
}
-// TestTUI_SlashCommands tests slash command handling.
-func TestTUI_SlashCommands(t *testing.T) {
+// TestTUI_Integration tests the TUI can be created with real dependencies.
+// Note: Bubble Tea TUI cannot be fully tested without a real TTY.
+// These tests verify initialization and component wiring.
+func TestTUI_Integration(t *testing.T) {
if os.Getenv("GEMINI_API_KEY") == "" {
t.Skip("GEMINI_API_KEY not set - skipping integration test")
}
- // Reset Flow singleton for test isolation
- chat.ResetFlowForTesting()
+ a := setupApp(t)
- cfg, err := config.Load()
+ agent, err := a.CreateAgent()
if err != nil {
- t.Fatalf("Failed to load config: %v", err)
+ t.Fatalf("CreateAgent() error: %v", err)
}
- // Set absolute path for prompts directory (required for tests running from different directories)
- projectRoot, err := findProjectRoot()
- if err != nil {
- t.Fatalf("Failed to find project root: %v", err)
- }
- cfg.PromptDir = filepath.Join(projectRoot, "prompts")
+ flow := chat.NewFlow(a.Genkit, agent)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
- runtime, err := app.NewChatRuntime(ctx, cfg)
+ tuiModel, err := tui.New(ctx, flow, uuid.New())
if err != nil {
- t.Fatalf("Failed to initialize runtime: %v", err)
+ t.Fatalf("tui.New() error: %v", err)
}
- t.Cleanup(func() {
- if err := runtime.Close(); err != nil {
- t.Logf("runtime close error: %v", err)
- }
- })
- tuiModel, err := tui.New(ctx, runtime.Flow, "test-session-id")
- if err != nil {
- t.Fatalf("Failed to create TUI: %v", err)
- }
-
- // Test /help command by simulating the message flow
- // Note: Full TUI testing requires teatest or similar framework
- view := tuiModel.View()
- if view.Content == nil {
- t.Error("View content should not be nil")
+ cmd := tuiModel.Init()
+ if cmd == nil {
+ t.Error("Init should return a command (blink + spinner)")
}
}
diff --git a/cmd/mcp.go b/cmd/mcp.go
index c0948c6..5b8a0c6 100644
--- a/cmd/mcp.go
+++ b/cmd/mcp.go
@@ -4,125 +4,52 @@ import (
"context"
"fmt"
"log/slog"
- "os"
"os/signal"
"syscall"
- "time"
"github.com/koopa0/koopa/internal/app"
"github.com/koopa0/koopa/internal/config"
"github.com/koopa0/koopa/internal/mcp"
- "github.com/koopa0/koopa/internal/security"
- "github.com/koopa0/koopa/internal/tools"
mcpSdk "github.com/modelcontextprotocol/go-sdk/mcp"
)
-// runMCP initializes and starts the MCP server.
-// This is called when the user runs `koopa mcp`.
+// runMCP initializes and starts the MCP server on stdio transport.
func runMCP() error {
cfg, err := config.Load()
if err != nil {
- fmt.Fprintln(os.Stderr, err)
- return err
+ return fmt.Errorf("loading config: %w", err)
}
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
- return RunMCP(ctx, cfg, Version)
-}
-
-// RunMCP starts the MCP server on stdio transport
-//
-// Architecture:
-// - Creates all toolsets with necessary dependencies
-// - Creates MCP Server wrapping the toolsets
-// - Connects to stdio transport for Claude Desktop/Cursor
-// - Signal handling is done by caller (executeMCP)
-//
-// Error handling:
-// - Returns error if initialization fails (App, Toolsets, Server)
-// - Returns error if server connection fails
-// - Graceful shutdown on context cancellation
-func RunMCP(ctx context.Context, cfg *config.Config, version string) error {
- slog.Info("starting MCP server", "version", version)
+ slog.Info("starting MCP server", "version", Version)
- // Initialize application
- application, cleanup, err := app.InitializeApp(ctx, cfg)
+ a, err := app.Setup(ctx, cfg)
if err != nil {
- return fmt.Errorf("failed to initialize application: %w", err)
+ return fmt.Errorf("initializing application: %w", err)
}
- // Cleanup order: App.Close (goroutines) first, then cleanup (DB pool, OTel)
- defer cleanup()
defer func() {
- if closeErr := application.Close(); closeErr != nil {
- slog.Warn("app close error", "error", closeErr)
+ if closeErr := a.Close(); closeErr != nil {
+ slog.Warn("shutdown error", "error", closeErr)
}
}()
- // Create all required tools with logger
- logger := slog.Default()
-
- // 1. FileTools
- fileTools, err := tools.NewFileTools(application.PathValidator, logger)
- if err != nil {
- return fmt.Errorf("failed to create file tools: %w", err)
- }
-
- // 2. SystemTools
- cmdValidator := security.NewCommand()
- envValidator := security.NewEnv()
- systemTools, err := tools.NewSystemTools(cmdValidator, envValidator, logger)
- if err != nil {
- return fmt.Errorf("failed to create system tools: %w", err)
- }
-
- // 3. NetworkTools
- networkTools, err := tools.NewNetworkTools(tools.NetworkConfig{
- SearchBaseURL: cfg.SearXNG.BaseURL,
- FetchParallelism: cfg.WebScraper.Parallelism,
- FetchDelay: time.Duration(cfg.WebScraper.DelayMs) * time.Millisecond,
- FetchTimeout: time.Duration(cfg.WebScraper.TimeoutMs) * time.Millisecond,
- }, logger)
- if err != nil {
- return fmt.Errorf("failed to create network tools: %w", err)
- }
-
- // 4. KnowledgeTools (optional - requires retriever from App)
- var knowledgeTools *tools.KnowledgeTools
- toolCategories := []string{"file", "system", "network"}
- if application.Retriever != nil {
- kt, ktErr := tools.NewKnowledgeTools(application.Retriever, application.DocStore, logger)
- if ktErr != nil {
- slog.Warn("knowledge tools unavailable", "error", ktErr)
- } else {
- knowledgeTools = kt
- toolCategories = append(toolCategories, "knowledge")
- }
- }
-
- // Create MCP Server with all tools
mcpServer, err := mcp.NewServer(mcp.Config{
- Name: "koopa",
- Version: version,
- FileTools: fileTools,
- SystemTools: systemTools,
- NetworkTools: networkTools,
- KnowledgeTools: knowledgeTools,
+ Name: "koopa",
+ Version: Version,
+ Logger: slog.Default(),
+ File: a.File,
+ System: a.System,
+ Network: a.Network,
+ Knowledge: a.Knowledge,
})
if err != nil {
- return fmt.Errorf("failed to create MCP server: %w", err)
+ return fmt.Errorf("creating MCP server: %w", err)
}
- slog.Info("MCP server initialized",
- "name", "koopa",
- "version", version,
- "tools", toolCategories)
- slog.Info("starting MCP server on stdio transport")
+ slog.Info("MCP server ready", "name", "koopa", "version", Version, "transport", "stdio")
- // Run server on stdio transport
- // This is a blocking call that handles all MCP protocol communication
- // The server will run until ctx is canceled or an error occurs
if err := mcpServer.Run(ctx, &mcpSdk.StdioTransport{}); err != nil {
return fmt.Errorf("MCP server error: %w", err)
}
diff --git a/cmd/serve.go b/cmd/serve.go
index 55f0673..2e19997 100644
--- a/cmd/serve.go
+++ b/cmd/serve.go
@@ -3,87 +3,87 @@ package cmd
import (
"context"
"errors"
- "flag"
"fmt"
"log/slog"
- "net"
"net/http"
- "os"
"os/signal"
- "strconv"
- "strings"
"syscall"
"time"
"github.com/koopa0/koopa/internal/api"
"github.com/koopa0/koopa/internal/app"
+ "github.com/koopa0/koopa/internal/chat"
"github.com/koopa0/koopa/internal/config"
)
// Server timeout configuration.
const (
- ReadHeaderTimeout = 10 * time.Second
- ReadTimeout = 30 * time.Second
- WriteTimeout = 2 * time.Minute // SSE streaming needs longer timeout
- IdleTimeout = 120 * time.Second
- ShutdownTimeout = 30 * time.Second
+ readHeaderTimeout = 10 * time.Second
+ readTimeout = 30 * time.Second
+ writeTimeout = 2 * time.Minute // SSE streaming needs longer timeout
+ idleTimeout = 2 * time.Minute
+ shutdownTimeout = 30 * time.Second
)
-// RunServe starts the HTTP API server (JSON REST + Health checks).
-//
-// Architecture:
-// - Validates required configuration (HMAC_SECRET)
-// - Initializes the application runtime
-// - Creates the API server with all routes
-// - Signal handling is done by caller (executeServe)
-func RunServe(ctx context.Context, cfg *config.Config, version, addr string) error {
- logger := slog.Default()
-
- // Validate HMAC_SECRET for serve mode
- if cfg.HMACSecret == "" {
- return errors.New("HMAC_SECRET environment variable is required for serve mode (min 32 characters)")
+// runServe initializes and starts the HTTP API server.
+func runServe() error {
+ cfg, err := config.Load()
+ if err != nil {
+ return fmt.Errorf("loading config: %w", err)
}
- if len(cfg.HMACSecret) < 32 {
- return fmt.Errorf("HMAC_SECRET must be at least 32 characters, got %d", len(cfg.HMACSecret))
+ if err = cfg.ValidateServe(); err != nil {
+ return fmt.Errorf("validating config: %w", err)
}
- logger.Info("starting HTTP API server", "version", version)
+ addr, err := parseServeAddr()
+ if err != nil {
+ return fmt.Errorf("parsing address: %w", err)
+ }
+
+ ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+ defer cancel()
- // Initialize runtime with all components
- runtime, err := app.NewChatRuntime(ctx, cfg)
+ logger := slog.Default()
+ logger.Info("starting HTTP API server", "version", Version)
+
+ a, err := app.Setup(ctx, cfg)
if err != nil {
- return fmt.Errorf("failed to initialize runtime: %w", err)
+ return fmt.Errorf("initializing application: %w", err)
}
defer func() {
- if closeErr := runtime.Close(); closeErr != nil {
- logger.Warn("runtime close error", "error", closeErr)
+ if closeErr := a.Close(); closeErr != nil {
+ logger.Warn("shutdown error", "error", closeErr)
}
}()
- // Create API server (JSON REST + Health checks)
+ agent, err := a.CreateAgent()
+ if err != nil {
+ return fmt.Errorf("creating agent: %w", err)
+ }
+
+ flow := chat.NewFlow(a.Genkit, agent)
+
apiServer, err := api.NewServer(api.ServerConfig{
Logger: logger,
- Genkit: runtime.App.Genkit,
- ModelName: cfg.FullModelName(),
- ChatFlow: runtime.Flow,
- SessionStore: runtime.App.SessionStore,
+ ChatAgent: agent,
+ ChatFlow: flow,
+ SessionStore: a.SessionStore,
CSRFSecret: []byte(cfg.HMACSecret),
CORSOrigins: cfg.CORSOrigins,
IsDev: cfg.PostgresSSLMode == "disable",
TrustProxy: cfg.TrustProxy,
})
if err != nil {
- return fmt.Errorf("failed to create API server: %w", err)
+ return fmt.Errorf("creating API server: %w", err)
}
- // Create HTTP server
srv := &http.Server{
Addr: addr,
Handler: apiServer.Handler(),
- ReadHeaderTimeout: ReadHeaderTimeout,
- ReadTimeout: ReadTimeout,
- WriteTimeout: WriteTimeout,
- IdleTimeout: IdleTimeout,
+ ReadHeaderTimeout: readHeaderTimeout,
+ ReadTimeout: readTimeout,
+ WriteTimeout: writeTimeout,
+ IdleTimeout: idleTimeout,
}
logger.Info("HTTP server ready",
@@ -100,10 +100,10 @@ func RunServe(ctx context.Context, cfg *config.Config, version, addr string) err
select {
case <-ctx.Done():
logger.Info("shutting down HTTP server")
- shutdownCtx, cancel := context.WithTimeout(context.Background(), ShutdownTimeout)
- defer cancel()
+ shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
+ defer shutdownCancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
- return fmt.Errorf("server shutdown failed: %w", err)
+ return fmt.Errorf("shutting down server: %w", err)
}
<-errCh
return nil
@@ -111,90 +111,6 @@ func RunServe(ctx context.Context, cfg *config.Config, version, addr string) err
if errors.Is(err, http.ErrServerClosed) {
return nil
}
- return err
+ return fmt.Errorf("HTTP server: %w", err)
}
}
-
-// runServe initializes and starts the HTTP API server.
-// This is called when the user runs `koopa serve`.
-func runServe() error {
- cfg, err := config.Load()
- if err != nil {
- fmt.Fprintln(os.Stderr, err)
- return err
- }
-
- addr, err := parseServeAddr()
- if err != nil {
- return err
- }
-
- ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
- defer cancel()
-
- return RunServe(ctx, cfg, Version, addr)
-}
-
-// parseServeAddr parses and validates the server address from command line arguments.
-// Uses flag.FlagSet for standard Go flag parsing, supporting:
-// - koopa serve :8080 (positional)
-// - koopa serve --addr :8080 (flag)
-// - koopa serve -addr :8080 (single dash)
-func parseServeAddr() (string, error) {
- const defaultAddr = "127.0.0.1:3400"
-
- serveFlags := flag.NewFlagSet("serve", flag.ContinueOnError)
- serveFlags.SetOutput(os.Stderr)
-
- addr := serveFlags.String("addr", defaultAddr, "Server address (host:port)")
-
- args := []string{}
- if len(os.Args) > 2 {
- args = os.Args[2:]
- }
-
- // Check for positional argument first (koopa serve :8080)
- if len(args) > 0 && !strings.HasPrefix(args[0], "-") {
- *addr = args[0]
- args = args[1:]
- }
-
- if err := serveFlags.Parse(args); err != nil {
- return "", fmt.Errorf("failed to parse serve flags: %w", err)
- }
-
- if err := validateAddr(*addr); err != nil {
- return "", fmt.Errorf("invalid address %q: %w", *addr, err)
- }
-
- return *addr, nil
-}
-
-// validateAddr validates the server address format.
-func validateAddr(addr string) error {
- host, port, err := net.SplitHostPort(addr)
- if err != nil {
- return fmt.Errorf("must be in host:port format: %w", err)
- }
-
- if host != "" && host != "localhost" {
- if ip := net.ParseIP(host); ip == nil {
- if strings.ContainsAny(host, " \t\n") {
- return fmt.Errorf("invalid host: %s", host)
- }
- }
- }
-
- if port == "" {
- return fmt.Errorf("port is required")
- }
- portNum, err := strconv.Atoi(port)
- if err != nil {
- return fmt.Errorf("port must be numeric: %w", err)
- }
- if portNum < 0 || portNum > 65535 {
- return fmt.Errorf("port must be 0-65535 (0 = auto-assign), got %d", portNum)
- }
-
- return nil
-}
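
The shutdown hunk above selects between ctx.Done() and errCh; the unchanged surrounding code is assumed to start the listener in a goroutine that feeds errCh. Under that assumption, the overall pattern looks roughly like this sketch (only the identifiers visible in the diff are taken from the code):

// Sketch of the serve/shutdown pattern, assuming errCh is fed by ListenAndServe.
errCh := make(chan error, 1)
go func() { errCh <- srv.ListenAndServe() }()

select {
case <-ctx.Done():
	logger.Info("shutting down HTTP server")
	shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
	defer shutdownCancel()
	if err := srv.Shutdown(shutdownCtx); err != nil {
		return fmt.Errorf("shutting down server: %w", err)
	}
	<-errCh // wait for ListenAndServe to return after Shutdown
	return nil
case err := <-errCh:
	if errors.Is(err, http.ErrServerClosed) {
		return nil
	}
	return fmt.Errorf("HTTP server: %w", err)
}
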
diff --git a/cmd/version.go b/cmd/version.go
index b0511cc..00c630b 100644
--- a/cmd/version.go
+++ b/cmd/version.go
@@ -2,35 +2,14 @@ package cmd
import "fmt"
+// Version is set at build time via ldflags:
+//
+// go build -ldflags "-X github.com/koopa0/koopa/cmd.Version=1.0.0"
+//
+// Default value "dev" indicates a development build.
+var Version = "dev"
+
// runVersion displays version information.
func runVersion() {
fmt.Printf("Koopa v%s\n", Version)
}
-
-// runHelp displays the help message.
-func runHelp() {
- fmt.Println("Koopa - Your terminal AI personal assistant")
- fmt.Println()
- fmt.Println("Usage:")
- fmt.Println(" koopa cli Start interactive chat mode")
- fmt.Println(" koopa serve [addr] Start HTTP API server (default: 127.0.0.1:3400)")
- fmt.Println(" koopa mcp Start MCP server (for Claude Desktop/Cursor)")
- fmt.Println(" koopa --version Show version information")
- fmt.Println(" koopa --help Show this help")
- fmt.Println()
- fmt.Println("CLI Commands (in interactive mode):")
- fmt.Println(" /help Show available commands")
- fmt.Println(" /version Show version")
- fmt.Println(" /clear Clear conversation history")
- fmt.Println(" /exit, /quit Exit Koopa")
- fmt.Println()
- fmt.Println("Shortcuts:")
- fmt.Println(" Ctrl+D Exit Koopa")
- fmt.Println(" Ctrl+C Cancel current input")
- fmt.Println()
- fmt.Println("Environment Variables:")
- fmt.Println(" GEMINI_API_KEY Required: Gemini API key")
- fmt.Println(" DEBUG Optional: Enable debug logging")
- fmt.Println()
- fmt.Println("Learn more: https://github.com/koopa0/koopa")
-}
diff --git a/db/migrate.go b/db/migrate.go
index bbd2469..06ceb71 100644
--- a/db/migrate.go
+++ b/db/migrate.go
@@ -31,44 +31,38 @@ func Migrate(connURL string) error {
// Create source driver from embedded filesystem
source, err := iofs.New(migrationsFS, "migrations")
if err != nil {
- slog.Error("failed to create migration source", "error", err)
- return fmt.Errorf("failed to create migration source: %w", err)
+ return fmt.Errorf("creating migration source: %w", err)
}
// Convert postgres:// or postgresql:// to pgx5:// scheme for golang-migrate pgx v5 driver
dbURL, err := convertToMigrateURL(connURL)
if err != nil {
- slog.Error("invalid database URL", "error", err)
return err
}
// Create migrate instance with pgx5 driver
m, err := migrate.NewWithSourceInstance("iofs", source, dbURL)
if err != nil {
- slog.Error("failed to connect to database for migrations", "error", err)
- return fmt.Errorf("failed to create migrate instance: %w", err)
+ return fmt.Errorf("creating migrate instance: %w", err)
}
+ // best-effort: close errors are non-fatal during migration teardown
defer func() {
srcErr, dbErr := m.Close()
if srcErr != nil {
- slog.Warn("failed to close migration source", "error", srcErr)
+ slog.Warn("closing migration source", "error", srcErr)
}
if dbErr != nil {
- slog.Warn("failed to close migration database connection", "error", dbErr)
+ slog.Warn("closing migration database connection", "error", dbErr)
}
}()
// Check for dirty state before running migrations
version, dirty, verErr := m.Version()
if verErr != nil && !errors.Is(verErr, migrate.ErrNilVersion) {
- slog.Error("failed to check migration version", "error", verErr)
- return fmt.Errorf("failed to check migration version: %w", verErr)
+ return fmt.Errorf("checking migration version: %w", verErr)
}
if dirty {
- slog.Error("database is in dirty migration state - manual intervention required",
- "version", version,
- "hint", fmt.Sprintf("inspect schema and run: migrate force %d", version))
- return fmt.Errorf("database in dirty state (version=%d), manual cleanup required", version)
+ return fmt.Errorf("database in dirty state (version=%d): inspect schema and run: migrate force %d", version, version)
}
// Run migrations
@@ -78,16 +72,13 @@ func Migrate(connURL string) error {
return nil
}
- // Check for dirty state after failure
+ // Include dirty state info in error if migration left database dirty
postVersion, postDirty, postErr := m.Version()
if postErr == nil && postDirty {
- slog.Error("migration failed - database now in dirty state",
- "version", postVersion,
- "hint", fmt.Sprintf("fix the migration and run: migrate force %d", postVersion))
+ return fmt.Errorf("running migrations (database now dirty at version %d, fix and run: migrate force %d): %w", postVersion, postVersion, err)
}
- slog.Error("failed to run migrations", "error", err)
- return fmt.Errorf("failed to run migrations: %w", err)
+ return fmt.Errorf("running migrations: %w", err)
}
finalVersion, finalDirty, verErr := m.Version()
@@ -108,7 +99,7 @@ func convertToMigrateURL(connURL string) (string, error) {
// Parse URL to validate and extract components
u, err := url.Parse(connURL)
if err != nil {
- return "", fmt.Errorf("failed to parse database URL: %w", err)
+ return "", fmt.Errorf("parsing database URL: %w", err)
}
// Validate scheme
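
With logging removed from the error paths, callers decide how failures are surfaced. Migrate is exported and takes only the connection URL, so startup code can call it directly before opening the connection pool; a hedged sketch (the DATABASE_URL name mirrors the E2E tests, the call site itself is illustrative):

// Illustrative startup call; error wrapping follows the style used elsewhere in the diff.
connURL := os.Getenv("DATABASE_URL")
if err := db.Migrate(connURL); err != nil {
	return fmt.Errorf("migrating database: %w", err)
}
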
diff --git a/db/migrations/000001_init_schema.down.sql b/db/migrations/000001_init_schema.down.sql
index a715372..18a34c2 100644
--- a/db/migrations/000001_init_schema.down.sql
+++ b/db/migrations/000001_init_schema.down.sql
@@ -2,31 +2,18 @@
-- Drops all objects created by 000001_init_schema.up.sql in reverse order
-- ============================================================================
--- Drop Messages Table (including triggers and indexes)
+-- Drop Messages Table
-- ============================================================================
-DROP TRIGGER IF EXISTS update_message_updated_at ON message;
-DROP INDEX IF EXISTS idx_message_content_gin;
-DROP INDEX IF EXISTS idx_message_status;
-DROP INDEX IF EXISTS idx_incomplete_messages;
-DROP INDEX IF EXISTS idx_message_session_seq;
-DROP INDEX IF EXISTS idx_message_session_id;
-DROP TABLE IF EXISTS message;
+DROP TABLE IF EXISTS messages;
-- ============================================================================
--- Drop Sessions Table (including triggers and indexes)
+-- Drop Sessions Table (including indexes)
-- ============================================================================
-DROP TRIGGER IF EXISTS update_sessions_updated_at ON sessions;
DROP INDEX IF EXISTS idx_sessions_updated_at;
DROP TABLE IF EXISTS sessions;
--- ============================================================================
--- Drop Helper Functions
--- ============================================================================
-
-DROP FUNCTION IF EXISTS update_updated_at_column();
-
-- ============================================================================
-- Drop Documents Table (including indexes)
-- ============================================================================
diff --git a/db/migrations/000001_init_schema.up.sql b/db/migrations/000001_init_schema.up.sql
index 49cc5a5..45b6705 100644
--- a/db/migrations/000001_init_schema.up.sql
+++ b/db/migrations/000001_init_schema.up.sql
@@ -13,7 +13,7 @@ CREATE EXTENSION IF NOT EXISTS vector;
CREATE TABLE IF NOT EXISTS documents (
id TEXT PRIMARY KEY,
content TEXT NOT NULL,
- embedding vector(768) NOT NULL, -- text-embedding-004 dimension
+ embedding vector(768) NOT NULL, -- gemini-embedding-001 truncated via OutputDimensionality
source_type TEXT, -- Metadata column for filtering
metadata JSONB -- Additional metadata in JSON format
);
@@ -30,19 +30,6 @@ CREATE INDEX IF NOT EXISTS idx_documents_source_type ON documents(source_type);
CREATE INDEX IF NOT EXISTS idx_documents_metadata_gin
ON documents USING GIN (metadata jsonb_path_ops);
--- ============================================================================
--- Helper Functions
--- ============================================================================
-
--- Auto-update updated_at timestamp
-CREATE OR REPLACE FUNCTION update_updated_at_column()
-RETURNS TRIGGER AS $$
-BEGIN
- NEW.updated_at = NOW();
- RETURN NEW;
-END;
-$$ LANGUAGE plpgsql;
-
-- ============================================================================
-- Sessions Table
-- ============================================================================
@@ -51,65 +38,24 @@ CREATE TABLE IF NOT EXISTS sessions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
title TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- model_name TEXT,
- system_prompt TEXT,
- message_count INTEGER DEFAULT 0
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON sessions(updated_at DESC);
--- Use DO block for trigger (no IF NOT EXISTS syntax for triggers)
-DO $$
-BEGIN
- IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_sessions_updated_at') THEN
- CREATE TRIGGER update_sessions_updated_at
- BEFORE UPDATE ON sessions
- FOR EACH ROW
- EXECUTE FUNCTION update_updated_at_column();
- END IF;
-END $$;
-
-- ============================================================================
-- Messages Table
-- ============================================================================
-CREATE TABLE IF NOT EXISTS message (
+CREATE TABLE IF NOT EXISTS messages (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
session_id UUID NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
role TEXT NOT NULL,
content JSONB NOT NULL,
sequence_number INTEGER NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- status TEXT NOT NULL DEFAULT 'completed'
- CHECK (status IN ('streaming', 'completed', 'failed')),
- updated_at TIMESTAMPTZ DEFAULT NOW(),
+ -- UNIQUE constraint automatically creates index on (session_id, sequence_number)
CONSTRAINT unique_message_sequence UNIQUE (session_id, sequence_number),
CONSTRAINT message_role_check CHECK (role IN ('user', 'assistant', 'system', 'tool'))
);
-
-CREATE INDEX IF NOT EXISTS idx_message_session_id ON message(session_id);
-CREATE INDEX IF NOT EXISTS idx_message_session_seq ON message(session_id, sequence_number);
-CREATE INDEX IF NOT EXISTS idx_incomplete_messages ON message(session_id, updated_at)
- WHERE status IN ('streaming', 'failed');
-
--- Index for querying failed/streaming messages
-CREATE INDEX IF NOT EXISTS idx_message_status ON message(session_id, status)
- WHERE status != 'completed';
-
--- Index for message.content (ai.Part array stored as JSONB)
--- Enables fast queries like: WHERE content @> '[{"text": "search term"}]'
-CREATE INDEX IF NOT EXISTS idx_message_content_gin
- ON message USING GIN (content jsonb_path_ops);
-
--- Use DO block for trigger (no IF NOT EXISTS syntax for triggers)
-DO $$
-BEGIN
- IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_message_updated_at') THEN
- CREATE TRIGGER update_message_updated_at
- BEFORE UPDATE ON message
- FOR EACH ROW
- EXECUTE FUNCTION update_updated_at_column();
- END IF;
-END $$;
diff --git a/db/queries/documents.sql b/db/queries/documents.sql
deleted file mode 100644
index 3773879..0000000
--- a/db/queries/documents.sql
+++ /dev/null
@@ -1,63 +0,0 @@
--- Documents queries for sqlc
--- Generated code will be in internal/sqlc/documents.sql.go
-
--- name: UpsertDocument :exec
-INSERT INTO documents (id, content, embedding, source_type, metadata)
-VALUES ($1, $2, $3, $4, $5)
-ON CONFLICT (id) DO UPDATE SET
- content = EXCLUDED.content,
- embedding = EXCLUDED.embedding,
- source_type = EXCLUDED.source_type,
- metadata = EXCLUDED.metadata;
-
--- name: SearchDocuments :many
-SELECT id, content, metadata,
- (1 - (embedding <=> sqlc.arg(query_embedding)::vector))::float8 AS similarity
-FROM documents
-WHERE metadata @> sqlc.arg(filter_metadata)::jsonb
-ORDER BY similarity DESC
-LIMIT sqlc.arg(result_limit);
-
--- name: SearchDocumentsAll :many
-SELECT id, content, metadata,
- (1 - (embedding <=> sqlc.arg(query_embedding)::vector))::float8 AS similarity
-FROM documents
-ORDER BY similarity DESC
-LIMIT sqlc.arg(result_limit);
-
--- name: CountDocuments :one
-SELECT COUNT(*)
-FROM documents
-WHERE metadata @> $1::jsonb;
-
--- name: CountDocumentsAll :one
-SELECT COUNT(*)
-FROM documents;
-
--- name: GetDocument :one
-SELECT id, content, metadata
-FROM documents
-WHERE id = $1;
-
--- name: DeleteDocument :exec
-DELETE FROM documents
-WHERE id = $1;
-
--- ===== Optimized RAG Queries (SQL-level filtering) =====
-
--- name: SearchBySourceType :many
--- Generic search by source_type using dedicated indexed column
-SELECT id, content, metadata,
- (1 - (embedding <=> sqlc.arg(query_embedding)::vector))::float8 AS similarity
-FROM documents
-WHERE source_type = sqlc.arg(source_type)::text
-ORDER BY similarity DESC
-LIMIT sqlc.arg(result_limit);
-
--- name: ListDocumentsBySourceType :many
--- List all documents by source_type using dedicated indexed column
--- Used for listing indexed files without needing embeddings
-SELECT id, content, metadata
-FROM documents
-WHERE source_type = sqlc.arg(source_type)::text
-LIMIT sqlc.arg(result_limit);
diff --git a/db/queries/sessions.sql b/db/queries/sessions.sql
index 2258b3e..849204b 100644
--- a/db/queries/sessions.sql
+++ b/db/queries/sessions.sql
@@ -1,37 +1,26 @@
--- Sessions queries for sqlc
+-- Sessions and messages queries for sqlc
-- Generated code will be in internal/sqlc/sessions.sql.go
-- name: CreateSession :one
-INSERT INTO sessions (title, model_name, system_prompt)
-VALUES ($1, $2, $3)
+INSERT INTO sessions (title)
+VALUES ($1)
RETURNING *;
-- name: Session :one
-SELECT id, title, created_at, updated_at, model_name, system_prompt, message_count
+SELECT id, title, created_at, updated_at
FROM sessions
WHERE id = $1;
--- name: ListSessions :many
-SELECT id, title, created_at, updated_at, model_name, system_prompt, message_count
+-- name: Sessions :many
+SELECT id, title, created_at, updated_at
FROM sessions
ORDER BY updated_at DESC
LIMIT sqlc.arg(result_limit)
OFFSET sqlc.arg(result_offset);
--- name: ListSessionsWithMessages :many
--- Only list sessions that have messages or titles (not empty sessions)
--- This is used for sidebar to hide "New Chat" placeholder sessions
-SELECT id, title, created_at, updated_at, model_name, system_prompt, message_count
-FROM sessions
-WHERE message_count > 0 OR title IS NOT NULL
-ORDER BY updated_at DESC
-LIMIT sqlc.arg(result_limit)
-OFFSET sqlc.arg(result_offset);
-
-- name: UpdateSessionUpdatedAt :exec
UPDATE sessions
-SET updated_at = NOW(),
- message_count = sqlc.arg(message_count)
+SET updated_at = NOW()
WHERE id = sqlc.arg(session_id);
-- name: UpdateSessionTitle :exec
@@ -47,57 +36,24 @@ WHERE id = $1;
-- name: AddMessage :exec
-- Add a message to a session
-INSERT INTO message (session_id, role, content, sequence_number)
+INSERT INTO messages (session_id, role, content, sequence_number)
VALUES ($1, $2, $3, $4);
-- name: Messages :many
-- Get all messages for a session ordered by sequence
-SELECT *
-FROM message
+SELECT id, session_id, role, content, sequence_number, created_at
+FROM messages
WHERE session_id = sqlc.arg(session_id)
ORDER BY sequence_number ASC
LIMIT sqlc.arg(result_limit)
OFFSET sqlc.arg(result_offset);
--- name: GetMaxSequenceNumber :one
--- Get max sequence number for a session
+-- name: MaxSequenceNumber :one
+-- Get max sequence number for a session (returns 0 if no messages)
SELECT COALESCE(MAX(sequence_number), 0)::integer AS max_seq
-FROM message
+FROM messages
WHERE session_id = $1;
--- name: CountMessages :one
--- Count messages in a session
-SELECT COUNT(*)::integer AS count
-FROM message
-WHERE session_id = sqlc.arg(session_id);
-
-- name: LockSession :one
-- Locks the session row to prevent concurrent modifications
SELECT id FROM sessions WHERE id = $1 FOR UPDATE;
-
--- name: DeleteMessages :exec
--- Delete all messages in a session
-DELETE FROM message
-WHERE session_id = sqlc.arg(session_id);
-
--- name: AddMessageWithID :one
--- Add message with pre-assigned ID and status (for streaming)
-INSERT INTO message (id, session_id, role, content, status, sequence_number)
-VALUES ($1, $2, $3, $4, $5, $6)
-RETURNING *;
-
--- name: UpdateMessageContent :exec
--- Update message content and mark as completed
-UPDATE message
-SET content = $2,
- status = 'completed',
- updated_at = NOW()
-WHERE id = $1;
-
--- name: UpdateMessageStatus :exec
--- Update message status (streaming/completed/failed)
-UPDATE message
-SET status = $2,
- updated_at = NOW()
-WHERE id = $1;
-
diff --git a/docs/architecture-report.md b/docs/architecture-report.md
deleted file mode 100644
index 6a75196..0000000
--- a/docs/architecture-report.md
+++ /dev/null
@@ -1,1654 +0,0 @@
-# Koopa Project Architecture Report
-
-> Generated: 2026-02-04 | Branch: main | Commit: 26e635b
-
----
-
-## 1. Project Overview
-
-### Module Information
-
-| Field | Value |
-|-------|-------|
-| **Go Module** | `github.com/koopa0/koopa` |
-| **Go Version** | 1.25.1 |
-| **License** | MIT |
-| **Entry Point** | `main.go` → `cmd.Execute()` |
-
-### Directory Structure (3 Levels)
-
-```
-koopa/
-├── main.go # Entry point → cmd.Execute()
-├── go.mod / go.sum # Dependencies
-├── Taskfile.yml # Build tasks (task CLI)
-├── docker-compose.yml # PostgreSQL + SearXNG + Redis
-├── .golangci.yml # Linter config
-├── .mcp.json # MCP server config (Genkit)
-├── .env.example # Environment template
-├── config.example.yaml # Advanced config template (~/.koopa/config.yaml)
-├── GENKIT.md # Genkit framework guidelines
-├── CLAUDE.md # Claude Code project rules
-│
-├── cmd/ # CLI Commands (Cobra-like, manual)
-│ ├── execute.go # Command dispatcher
-│ ├── cli.go # Interactive TUI mode (BubbleTea)
-│ ├── serve.go # HTTP API server
-│ ├── mcp.go # MCP server (stdio)
-│ ├── version.go # Version display
-│ ├── e2e_test.go # CLI E2E tests
-│ └── integration_test.go # CLI integration tests
-│
-├── internal/
-│ ├── agent/ # AI Agent abstraction
-│ │ ├── errors.go # Sentinel errors
-│ │ └── chat/ # Chat agent (Genkit Flow)
-│ │ ├── chat.go # Core agent (Execute, ExecuteStream)
-│ │ ├── flow.go # Genkit StreamingFlow definition
-│ │ ├── retry.go # Exponential backoff retry
-│ │ ├── circuit.go # Circuit breaker pattern
-│ │ ├── tokens.go # Token budget management
-│ │ └── *_test.go # Tests
-│ │
-│ ├── app/ # Application lifecycle (DI)
-│ │ ├── app.go # App container struct
-│ │ ├── runtime.go # Runtime wrapper (App + Flow + cleanup)
-│ │ ├── wire.go # Wire DI providers (10-step chain)
-│ │ ├── wire_gen.go # Wire generated code
-│ │ └── *_test.go # Tests
-│ │
-│ ├── config/ # Configuration
-│ │ ├── config.go # Main Config struct + Load() + Validate()
-│ │ ├── tools.go # MCP/SearXNG/WebScraper config
-│ │ └── *_test.go # Tests
-│ │
-│ ├── knowledge/ # Knowledge store (placeholder)
-│ │
-│ ├── log/ # Logger setup
-│ │ └── log.go # slog wrapper (Logger = *slog.Logger)
-│ │
-│ ├── mcp/ # MCP Server implementation
-│ │ ├── doc.go # Architecture documentation
-│ │ ├── server.go # MCP server (10 tools)
-│ │ ├── file.go # File tool MCP handlers
-│ │ ├── system.go # System tool MCP handlers
-│ │ ├── network.go # Network tool MCP handlers
-│ │ ├── util.go # Result → MCP conversion
-│ │ └── *_test.go # Tests
-│ │
-│ ├── observability/ # OpenTelemetry setup
-│ │ └── datadog.go # OTLP HTTP exporter → Datadog Agent
-│ │
-│ ├── rag/ # RAG retriever/indexer
-│ │ ├── constants.go # Source types, schema config
-│ │ ├── system.go # System knowledge indexing (6 docs)
-│ │ └── doc.go # Package doc
-│ │
-│ ├── security/ # Input validators (5 modules)
-│ │ ├── path.go # Path traversal prevention
-│ │ ├── command.go # Command injection prevention
-│ │ ├── env.go # Env variable access control
-│ │ ├── url.go # SSRF prevention
-│ │ ├── prompt.go # Prompt injection detection
-│ │ └── *_test.go # Tests (incl. fuzz)
-│ │
-│ ├── session/ # Session persistence (PostgreSQL)
-│ │ ├── types.go # Session, Message, History structs
-│ │ ├── store.go # Store (CRUD, transactions)
-│ │ ├── errors.go # Sentinel errors, status constants
-│ │ ├── state.go # ~/.koopa/current_session persistence
-│ │ └── *_test.go # Tests
-│ │
-│ ├── sqlc/ # Generated SQL code (sqlc)
-│ │ ├── db.go # DBTX interface, Queries struct
-│ │ ├── models.go # Document, Message, Session models
-│ │ ├── documents.sql.go # Document queries
-│ │ └── sessions.sql.go # Session/message queries
-│ │
-│ ├── testutil/ # Test utilities
-│ │ ├── db.go # Testcontainer PostgreSQL setup
-│ │ ├── embedder.go # Deterministic test embedder
-│ │ └── logger.go # No-op logger
-│ │
-│ ├── tools/ # Tool implementations
-│ │ ├── types.go # Result, Error, Status types
-│ │ ├── metadata.go # DangerLevel, ToolMetadata registry
-│ │ ├── emitter.go # ToolEventEmitter interface
-│ │ ├── events.go # WithEvents wrapper
-│ │ ├── file.go # File tools (5 tools)
-│ │ ├── system.go # System tools (3 tools)
-│ │ ├── network.go # Network tools (2 tools)
-│ │ ├── knowledge.go # Knowledge tools (3 tools)
-│ │ └── *_test.go # Tests (incl. fuzz)
-│ │
-│ ├── tui/ # Terminal UI (BubbleTea)
-│ │ ├── tui.go # Model + Init + Update + View
-│ │ ├── keys.go # Key bindings
-│ │ ├── commands.go # Streaming tea.Cmd
-│ │ ├── styles.go # Lipgloss styles
-│ │ └── *_test.go # Tests
-│ │
-│ └── web/ # HTTP server + Web UI
-│ ├── server.go # Server, route registration, security headers
-│ ├── middleware.go # Recovery, Logging, MethodOverride, Session, CSRF
-│ ├── handlers/
-│ │ ├── chat.go # POST /genui/send, GET /genui/stream
-│ │ ├── pages.go # GET /genui (main page)
-│ │ ├── sessions.go # Session/CSRF management
-│ │ ├── health.go # Health/ready probes
-│ │ └── *_test.go # Tests (unit + integration + fuzz)
-│ ├── sse/
-│ │ └── writer.go # SSE writer (chunks, done, error, tools)
-│ ├── page/
-│ │ └── chat.templ # Chat page template
-│ ├── layout/
-│ │ └── app.templ # Base HTML layout
-│ ├── component/
-│ │ ├── message_bubble.templ
-│ │ ├── sidebar.templ
-│ │ ├── chat_input.templ
-│ │ ├── empty_state.templ
-│ │ ├── session_placeholders.templ
-│ │ └── _reference/ # templUI Pro blocks (222 blocks)
-│ ├── static/ # CSS (Tailwind), JS (HTMX, Elements, Prism)
-│ └── e2e/ # Playwright E2E tests
-│
-├── db/ # Database layer
-│ ├── migrate.go # Embedded migration runner
-│ ├── migrations/
-│ │ ├── 000001_init_schema.up.sql
-│ │ └── 000001_init_schema.down.sql
-│ └── queries/
-│ ├── documents.sql # 8 document queries
-│ └── sessions.sql # 28 session/message queries
-│
-├── build/
-│ ├── frontend/ # Tailwind CSS build config
-│ └── sql/
-│ └── sqlc.yaml # SQLC code generation config
-│
-├── prompts/
-│ └── koopa.prompt # Dotprompt system prompt (665 lines)
-│
-├── scripts/ # Utility scripts
-├── searxng/ # SearXNG Docker config
-└── docs/ # Documentation
-```
-
-### Third-Party Dependencies
-
-| Dependency | Version | Purpose |
-|---|---|---|
-| **AI/LLM** |
-| `github.com/firebase/genkit/go` | v1.2.0 | LLM orchestration, Flow, tools, Dotprompt |
-| **Database** |
-| `github.com/jackc/pgx/v5` | v5.7.6 | PostgreSQL driver (connection pooling) |
-| `github.com/pgvector/pgvector-go` | v0.3.0 | pgvector extension (vector embeddings) |
-| `github.com/golang-migrate/migrate/v4` | v4.19.1 | Database schema migrations |
-| **MCP** |
-| `github.com/modelcontextprotocol/go-sdk` | v1.1.0 | MCP server SDK (official) |
-| **TUI/CLI** |
-| `charm.land/bubbletea/v2` | v2.0.0-rc.2 | Interactive terminal UI framework |
-| `charm.land/bubbles/v2` | v2.0.0-rc.1 | BubbleTea components (textarea, spinner) |
-| `charm.land/lipgloss/v2` | v2.0.0-beta.3 | Terminal styling/formatting |
-| `github.com/charmbracelet/glamour` | v0.10.0 | Markdown → terminal rendering |
-| **Web** |
-| `github.com/a-h/templ` | v0.3.960 | Go HTML template compiler (SSR) |
-| **Web Scraping** |
-| `github.com/gocolly/colly/v2` | v2.2.0 | Web scraping framework |
-| `github.com/PuerkitoBio/goquery` | v1.11.0 | HTML parsing (jQuery-like) |
-| `github.com/go-shiori/go-readability` | v0.0.0-20250217 | Article extraction (Readability) |
-| **Configuration** |
-| `github.com/spf13/viper` | v1.21.0 | Config file + env var management |
-| `github.com/google/wire` | v0.7.0 | Compile-time dependency injection |
-| **Utilities** |
-| `github.com/google/uuid` | v1.6.0 | UUID generation |
-| `github.com/google/jsonschema-go` | v0.3.0 | JSON Schema inference for tool inputs |
-| `github.com/gofrs/flock` | v0.13.0 | File-based locking |
-| `golang.org/x/sync` | v0.18.0 | errgroup for lifecycle management |
-| `golang.org/x/time` | v0.14.0 | rate.Limiter for rate limiting |
-| **Observability** |
-| `go.opentelemetry.io/otel/sdk` | v1.38.0 | OpenTelemetry tracing SDK |
-| `go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp` | v1.38.0 | OTLP HTTP exporter (Datadog) |
-| **Testing** |
-| `github.com/stretchr/testify` | v1.11.1 | Assertions and mocking |
-| `github.com/testcontainers/testcontainers-go` | v0.40.0 | Docker containers for integration tests |
-| `github.com/testcontainers/testcontainers-go/modules/postgres` | v0.40.0 | PostgreSQL testcontainer |
-| `github.com/playwright-community/playwright-go` | v0.5200.1 | Browser automation (E2E tests) |
-| `go.uber.org/goleak` | v1.3.0 | Goroutine leak detector |
-
----
-
-## 2. Core Type Definitions
-
-### 2.1 Agent Types
-
-#### `internal/agent/chat/chat.go`
-
-```go
-// Chat is Koopa's main conversational agent - FULLY IMPLEMENTED
-type Chat struct {
- languagePrompt string // Resolved language for prompt template
- maxTurns int // Agentic loop iterations (default: 5)
- ragTopK int // RAG documents to retrieve
- retryConfig RetryConfig // LLM retry settings
- circuitBreaker *CircuitBreaker // Failure handling
- rateLimiter *rate.Limiter // Proactive rate limiting (default: 10/s, burst 30)
- tokenBudget TokenBudget // Context window limits
- g *genkit.Genkit
- retriever ai.Retriever // For RAG retrieval
- sessions *session.Store // Session persistence
- logger log.Logger
- tools []ai.Tool // Pre-registered tools
- toolRefs []ai.ToolRef // Cached for efficiency
- toolNames string // Comma-separated for logging
- prompt ai.Prompt // Cached Dotprompt instance
-}
-// Methods: New(Config), Execute(ctx, sessionID, input), ExecuteStream(ctx, sessionID, input, callback)
-
-// Config contains all required parameters for Chat agent
-type Config struct {
- Genkit *genkit.Genkit
- Retriever ai.Retriever
- SessionStore *session.Store
- Logger log.Logger
- Tools []ai.Tool
- MaxTurns int
- RAGTopK int
- Language string
- RetryConfig RetryConfig
- CircuitBreakerConfig CircuitBreakerConfig
- RateLimiter *rate.Limiter
- TokenBudget TokenBudget
-}
-
-// Response represents the complete result of an agent execution
-type Response struct {
- FinalText string
- ToolRequests []*ai.ToolRequest
-}
-
-// StreamCallback is called for each chunk of streaming response
-type StreamCallback func(ctx context.Context, chunk *ai.ModelResponseChunk) error
-```
-
-#### `internal/agent/chat/flow.go`
-
-```go
-// Flow type alias for Genkit Streaming Flow
-type Flow = core.Flow[Input, Output, StreamChunk]
-
-const FlowName = "koopa/chat"
-
-type Input struct {
- Query string `json:"query"`
- SessionID string `json:"sessionId"`
-}
-
-type Output struct {
- Response string `json:"response"`
- SessionID string `json:"sessionId"`
-}
-
-type StreamChunk struct {
- Text string `json:"text"`
-}
-
-// Singleton management:
-// InitFlow(g, chat) - Initialize once via sync.Once
-// GetFlow() *Flow - Get initialized flow (panics if not initialized)
-// ResetFlowForTesting() - Reset for test isolation
-```
-
-#### `internal/agent/chat/retry.go`
-
-```go
-type RetryConfig struct {
- MaxRetries int // Default: 3
- InitialInterval time.Duration // Default: 500ms
- MaxInterval time.Duration // Default: 10s
-}
-// Methods: DefaultRetryConfig(), executeWithRetry(), retryableError()
-```
-
-#### `internal/agent/chat/circuit.go`
-
-```go
-type CircuitState int // CircuitClosed=0, CircuitOpen=1, CircuitHalfOpen=2
-
-type CircuitBreakerConfig struct {
- FailureThreshold int // Default: 5
- SuccessThreshold int // Default: 2
- Timeout time.Duration // Default: 30s
-}
-
-type CircuitBreaker struct { // Thread-safe (sync.RWMutex)
- mu sync.RWMutex
- state CircuitState
- failures int
- successes int
- lastFailure time.Time
- failureThreshold int
- successThreshold int
- timeout time.Duration
-}
-// Methods: NewCircuitBreaker(), Allow(), Success(), Failure(), State(), Reset()
-```
-
-#### `internal/agent/chat/tokens.go`
-
-```go
-type TokenBudget struct {
- MaxHistoryTokens int // Default: 8000
- MaxInputTokens int // Default: 2000
- ReservedTokens int // Default: 4000
-}
-// Methods: DefaultTokenBudget(), estimateTokens(), truncateHistory()
-```
-
-#### `internal/agent/errors.go`
-
-```go
-var (
- ErrInvalidSession = errors.New("invalid session")
- ErrExecutionFailed = errors.New("execution failed")
-)
-```
-
-### 2.2 Session / Message Types
-
-#### `internal/session/types.go`
-
-```go
-// History - Thread-safe conversation history (sync.RWMutex) - FULLY IMPLEMENTED
-type History struct {
- mu sync.RWMutex
- messages []*ai.Message
-}
-// Methods: NewHistory(), SetMessages(), Messages(), Add(), AddMessage(), Count(), Clear()
-
-// Session represents a conversation session
-type Session struct {
- ID uuid.UUID
- Title string
- CreatedAt time.Time
- UpdatedAt time.Time
- ModelName string
- SystemPrompt string
- MessageCount int
-}
-
-// Message represents a single conversation message
-type Message struct {
- ID uuid.UUID
- SessionID uuid.UUID
- Role string // "user" | "assistant" | "system" | "tool"
- Content []*ai.Part // Genkit Part slice (stored as JSONB)
- Status string // "streaming" | "completed" | "failed"
- SequenceNumber int
- CreatedAt time.Time
-}
-```
-
-#### `internal/session/store.go`
-
-```go
-// Store manages session persistence with PostgreSQL backend - FULLY IMPLEMENTED
-type Store struct {
- queries *sqlc.Queries
- pool *pgxpool.Pool
- logger *slog.Logger
-}
-// Methods: New(), CreateSession(), GetSession(), ListSessions(),
-// ListSessionsWithMessages(), DeleteSession(), UpdateSessionTitle(),
-// AddMessages(), GetMessages(), AppendMessages(), GetHistory(),
-// CreateMessagePair(), GetUserMessageBefore(), GetMessageByID(),
-// UpdateMessageContent(), UpdateMessageStatus()
-
-// MessagePair represents a user-assistant message pair for streaming
-type MessagePair struct {
- UserMsgID uuid.UUID
- AssistantMsgID uuid.UUID
- UserSeq int32
- AssistantSeq int32
-}
-```
-
-#### `internal/session/errors.go`
-
-```go
-const (
- StatusStreaming = "streaming"
- StatusCompleted = "completed"
- StatusFailed = "failed"
-)
-
-var (
- ErrSessionNotFound = errors.New("session not found")
- ErrMessageNotFound = errors.New("message not found")
-)
-
-const (
- DefaultHistoryLimit int32 = 100
- MaxHistoryLimit int32 = 10000
- MinHistoryLimit int32 = 10
-)
-```
-
-### 2.3 Tool / MCP Types
-
-#### `internal/tools/types.go`
-
-```go
-type Status string
-const (
- StatusSuccess Status = "success"
- StatusError Status = "error"
-)
-
-type ErrorCode string
-const (
- ErrCodeSecurity ErrorCode = "SecurityError"
- ErrCodeNotFound ErrorCode = "NotFound"
- ErrCodePermission ErrorCode = "PermissionDenied"
- ErrCodeIO ErrorCode = "IOError"
- ErrCodeExecution ErrorCode = "ExecutionError"
- ErrCodeTimeout ErrorCode = "TimeoutError"
- ErrCodeNetwork ErrorCode = "NetworkError"
- ErrCodeValidation ErrorCode = "ValidationError"
-)
-
-type Result struct {
- Status Status `json:"status"`
- Data any `json:"data,omitempty"`
- Error *Error `json:"error,omitempty"`
-}
-
-type Error struct {
- Code ErrorCode `json:"code"`
- Message string `json:"message"`
- Details any `json:"details,omitempty"`
-}
-```
-
-#### `internal/tools/metadata.go`
-
-```go
-type DangerLevel int
-const (
- DangerLevelSafe DangerLevel = iota // read_file, list_files, get_env, web_fetch
- DangerLevelWarning // write_file (reversible)
- DangerLevelDangerous // delete_file, execute_command (irreversible)
- DangerLevelCritical // Reserved for future
-)
-
-type ToolMetadata struct {
- Name string
- Description string
- RequiresConfirmation bool
- DangerLevel DangerLevel
- IsDangerousFunc func(params map[string]any) bool
- Category string
-}
-// Functions: GetToolMetadata(), GetAllToolMetadata(), IsDangerous(),
-// RequiresConfirmation(), GetDangerLevel(), ListToolsByDangerLevel()
-```
-
-#### `internal/tools/emitter.go`
-
-```go
-// ToolEventEmitter receives tool lifecycle events - INTERFACE
-type ToolEventEmitter interface {
- OnToolStart(name string)
- OnToolComplete(name string)
- OnToolError(name string)
-}
-// Functions: EmitterFromContext(ctx), ContextWithEmitter(ctx, emitter)
-```
-
-#### `internal/tools/file.go`
-
-```go
-type FileTools struct { // FULLY IMPLEMENTED
- pathVal *security.Path
- logger log.Logger
-}
-
-type ReadFileInput struct {
- Path string `json:"path" jsonschema_description:"The file path to read"`
-}
-type WriteFileInput struct {
- Path string `json:"path"`
- Content string `json:"content"`
-}
-type ListFilesInput struct {
- Path string `json:"path"`
-}
-type DeleteFileInput struct {
- Path string `json:"path"`
-}
-type GetFileInfoInput struct {
- Path string `json:"path"`
-}
-type FileEntry struct {
- Name string `json:"name"`
- Type string `json:"type"` // "file" | "directory"
-}
-```
-
-#### `internal/tools/system.go`
-
-```go
-type SystemTools struct { // FULLY IMPLEMENTED
- cmdVal *security.Command
- envVal *security.Env
- logger log.Logger
-}
-
-type ExecuteCommandInput struct {
- Command string `json:"command"`
- Args []string `json:"args,omitempty"`
-}
-type GetEnvInput struct {
- Key string `json:"key"`
-}
-type CurrentTimeInput struct{}
-```
-
-#### `internal/tools/network.go`
-
-```go
-type NetworkTools struct { // FULLY IMPLEMENTED
- searchBaseURL string
- searchClient *http.Client
- fetchParallelism int
- fetchDelay time.Duration
- fetchTimeout time.Duration
- urlValidator *security.URL
- skipSSRFCheck bool
- logger log.Logger
-}
-
-type NetworkConfig struct {
- SearchBaseURL string
- FetchParallelism int
- FetchDelay time.Duration
- FetchTimeout time.Duration
-}
-```
-
-#### `internal/tools/knowledge.go`
-
-```go
-type KnowledgeTools struct { // FULLY IMPLEMENTED
- retriever ai.Retriever
- logger log.Logger
-}
-
-type KnowledgeSearchInput struct {
- Query string `json:"query"`
- TopK int `json:"topK,omitempty"`
-}
-```
-
-#### `internal/mcp/server.go`
-
-```go
-type Server struct { // FULLY IMPLEMENTED
- mcpServer *mcp.Server
- fileTools *tools.FileTools
- systemTools *tools.SystemTools
- networkTools *tools.NetworkTools
- name string
- version string
-}
-
-type Config struct {
- Name string
- Version string
- FileTools *tools.FileTools
- SystemTools *tools.SystemTools
- NetworkTools *tools.NetworkTools
-}
-// Methods: NewServer(Config), Run(ctx, transport), registerTools()
-```
-
-### 2.4 Security Types
-
-#### `internal/security/path.go`
-
-```go
-type Path struct { // FULLY IMPLEMENTED
- allowedDirs []string
- workDir string
-}
-// Methods: NewPath(allowedDirs), Validate(path) (string, error)
-
-var (
- ErrPathOutsideAllowed = errors.New("path is outside allowed directories")
- ErrSymlinkOutsideAllowed = errors.New("symbolic link points outside allowed directories")
- ErrPathNullByte = errors.New("path contains null byte")
-)
-```
-
-#### `internal/security/command.go`
-
-```go
-type Command struct { // FULLY IMPLEMENTED - WHITELIST mode
- blacklist []string
- whitelist []string
-}
-// Methods: NewCommand(), ValidateCommand(cmd, args), QuoteCommandArgs(args)
-// 53 whitelisted commands: ls, cat, grep, git, go, npm, etc.
-```
-
-#### `internal/security/env.go`
-
-```go
-type Env struct { // FULLY IMPLEMENTED
- sensitivePatterns []string // 87 patterns blocked
-}
-// Methods: NewEnv(), Validate(key), GetAllowedEnvNames()
-// 23 allowed variables: PATH, HOME, GOPATH, etc.
-```
-
-#### `internal/security/url.go`
-
-```go
-type URL struct { // FULLY IMPLEMENTED
- allowedSchemes map[string]struct{}
- blockedHosts map[string]struct{}
-}
-// Methods: NewURL(), Validate(rawURL), SafeTransport(), ValidateRedirect()
-// Blocks: localhost, metadata endpoints, private IPs, link-local
-```
-
-#### `internal/security/prompt.go`
-
-```go
-type PromptValidator struct { // FULLY IMPLEMENTED
- patterns []*regexp.Regexp // 13 patterns
-}
-
-type PromptInjectionResult struct {
- Safe bool
- Patterns []string
-}
-// Methods: NewPromptValidator(), Validate(input), IsSafe(input)
-```
-
-### 2.5 Config Types
-
-#### `internal/config/config.go`
-
-```go
-type Config struct { // FULLY IMPLEMENTED
- // AI
- ModelName string
- Temperature float32
- MaxTokens int
- Language string
- PromptDir string
- // History
- MaxHistoryMessages int32
- MaxTurns int
- // Database
- DatabasePath string
- PostgresHost string
- PostgresPort int
- PostgresUser string
- PostgresPassword string // Masked in MarshalJSON
- PostgresDBName string
- PostgresSSLMode string
- // RAG
- RAGTopK int
- EmbedderModel string
- // MCP
- MCP MCPConfig
- MCPServers map[string]MCPServer
- // Tools
- SearXNG SearXNGConfig
- WebScraper WebScraperConfig
- // Observability
- Datadog DatadogConfig
- // Security
- HMACSecret string // Masked in MarshalJSON
-}
-// Methods: Load(), Validate(), MarshalJSON()
-```
-
-#### `internal/config/tools.go`
-
-```go
-type MCPConfig struct {
- Allowed []string
- Excluded []string
- Timeout int
-}
-
-type MCPServer struct {
- Command string
- Args []string
- Env map[string]string
- Timeout int
- IncludeTools []string
- ExcludeTools []string
-}
-
-type SearXNGConfig struct {
- BaseURL string
-}
-
-type WebScraperConfig struct {
- Parallelism int
- DelayMs int
- TimeoutMs int
-}
-```
-
-### 2.6 App / Runtime Types
-
-#### `internal/app/app.go`
-
-```go
-type App struct { // FULLY IMPLEMENTED
- Config *config.Config
- Genkit *genkit.Genkit
- Embedder ai.Embedder
- DBPool *pgxpool.Pool
- DocStore *postgresql.DocStore
- Retriever ai.Retriever
- SessionStore *session.Store
- PathValidator *security.Path
- Tools []ai.Tool
- ctx context.Context
- cancel context.CancelFunc
- eg *errgroup.Group
- egCtx context.Context
-}
-// Methods: Close(), Wait(), Go(func() error), CreateAgent(ctx)
-```
-
-#### `internal/app/runtime.go`
-
-```go
-type Runtime struct { // FULLY IMPLEMENTED
- App *App
- Flow *chat.Flow
- cleanup func()
-}
-// Factory: NewRuntime(ctx, cfg) → single init point for all entry points
-// Methods: Close() error (App.Close() → Wire cleanup)
-```
-
-### 2.7 Web Handler Types
-
-#### `internal/web/handlers/chat.go`
-
-```go
-// SSEWriter - INTERFACE
-type SSEWriter interface {
- WriteChunkRaw(msgID, htmlContent string) error
- WriteDone(ctx context.Context, msgID string, comp templ.Component) error
- WriteError(msgID, code, message string) error
- WriteSidebarRefresh(sessionID, title string) error
-}
-
-type ChatConfig struct {
- Logger *slog.Logger
- Genkit *genkit.Genkit
- Flow *chat.Flow
- Sessions *Sessions
- SSEWriterFn func(w http.ResponseWriter) (SSEWriter, error)
-}
-
-type Chat struct { // FULLY IMPLEMENTED
- logger *slog.Logger
- genkit *genkit.Genkit
- flow *chat.Flow
- sessions *Sessions
- sseWriterFn func(w http.ResponseWriter) (SSEWriter, error)
-}
-// Methods: NewChat(ChatConfig), Send(w, r), Stream(w, r)
-
-type streamState struct {
- msgID string
- sessionID string
- buffer strings.Builder
-}
-```
-
-#### `internal/web/handlers/sessions.go`
-
-```go
-type Sessions struct { // FULLY IMPLEMENTED
- store *session.Store
- hmacSecret []byte
- isDev bool
-}
-// Methods: NewSessions(), GetOrCreate(r), ID(r), NewCSRFToken(sessionID),
-// CheckCSRF(r), RegisterRoutes(mux)
-```
-
-#### `internal/web/sse/writer.go`
-
-```go
-type Writer struct { // FULLY IMPLEMENTED
- w io.Writer
- flusher http.Flusher
-}
-// Methods: NewWriter(), WriteChunk(), WriteChunkRaw(), WriteDone(),
-// WriteError(), WriteSidebarRefresh(),
-// WriteToolStart(), WriteToolComplete(), WriteToolError()
-```
-
-### 2.8 TUI Types
-
-#### `internal/tui/tui.go`
-
-```go
-type State int
-const (
- StateInput State = iota // Awaiting input
- StateThinking // Processing request
- StateStreaming // Streaming response
-)
-
-type Message struct {
- Role string // "user" | "assistant" | "system" | "error"
- Content string
-}
-
-type TUI struct { // FULLY IMPLEMENTED - BubbleTea Model
- input textarea.Model
- history []string
- historyIdx int
- state State
- lastCtrlC time.Time
- spinner spinner.Model
- output strings.Builder
- viewBuf strings.Builder
- messages []Message
- streamCancel context.CancelFunc
- streamEventCh <-chan streamEvent
- chatFlow *chat.Flow
- sessionID string
- ctx context.Context
- ctxCancel context.CancelFunc
- width int
- height int
- styles Styles
- markdown *markdownRenderer
-}
-// Implements tea.Model: Init(), Update(msg), View()
-```
-
----
-
-## 3. Genkit Integration
-
-### 3.1 Initialization
-
-**Location**: `internal/app/wire.go` lines 107-130
-
-```go
-g := genkit.Init(ctx,
- genkit.WithPlugins(&googlegenai.GoogleAI{}, postgres),
- genkit.WithPromptDir(promptDir), // Default: "prompts"
-)
-```
-
-**Plugins Used**:
-| Plugin | Purpose |
-|--------|---------|
-| `googlegenai.GoogleAI{}` | Gemini model access (chat + embeddings) |
-| `postgresql.Postgres{}` | DocStore + Retriever for RAG (pgvector) |
-
-**Initialization Order** (Wire DI):
-1. OpenTelemetry setup (tracing before Genkit)
-2. Database pool (migrations run)
-3. PostgreSQL plugin creation
-4. Genkit Init with plugins
-5. Embedder provision
-6. RAG components (DocStore + Retriever)
-7. Session store
-8. Security validators
-9. Tool registration (all 13 tools)
-10. App construction
-
-### 3.2 Flow Definitions
-
-**One flow defined**: `koopa/chat`
-
-| Flow Name | Type | Input | Output | Stream Type |
-|-----------|------|-------|--------|-------------|
-| `koopa/chat` | `genkit.DefineStreamingFlow` | `Input{Query, SessionID}` | `Output{Response, SessionID}` | `StreamChunk{Text}` |
-
-**Flow lifecycle**:
-- Singleton via `sync.Once` in `InitFlow(g, chat)`
-- Access via `GetFlow()` (panics if not initialized)
-- `ResetFlowForTesting()` for test isolation
-
-**Flow implementation** (`internal/agent/chat/flow.go` lines 87-149):
-1. Parse session UUID from `input.SessionID`
-2. Wrap `streamCb` into `StreamCallback` (adapts `StreamChunk` → `ai.ModelResponseChunk`)
-3. Call `chat.ExecuteStream(ctx, sessionID, query, callback)`
-4. Return `Output{Response, SessionID}`
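-
-A compact sketch of that body as a plain function; the real code passes an equivalent closure to `genkit.DefineStreamingFlow` inside `InitFlow`'s `sync.Once`. The session-ID parameter type and the `chunk.Text()` accessor are assumptions based on the types listed above, not a copy of the actual `flow.go`:
-
-```go
-func runFlow(ctx context.Context, agent *Chat, in Input,
-	stream func(context.Context, StreamChunk) error) (Output, error) {
-	// 1. Parse the session UUID carried in the flow input.
-	sessionID, err := uuid.Parse(in.SessionID)
-	if err != nil {
-		return Output{}, fmt.Errorf("invalid session id %q: %w", in.SessionID, err)
-	}
-
-	// 2. Adapt the flow's StreamChunk callback into the agent's StreamCallback.
-	var cb StreamCallback
-	if stream != nil {
-		cb = func(ctx context.Context, chunk *ai.ModelResponseChunk) error {
-			return stream(ctx, StreamChunk{Text: chunk.Text()})
-		}
-	}
-
-	// 3. Run the agent with streaming enabled.
-	resp, err := agent.ExecuteStream(ctx, sessionID, in.Query, cb)
-	if err != nil {
-		return Output{}, err
-	}
-
-	// 4. Echo the session ID back alongside the final text.
-	return Output{Response: resp.FinalText, SessionID: in.SessionID}, nil
-}
-```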
-
-### 3.3 Tool Calling
-
-**13 tools registered** at app initialization via `genkit.DefineTool()`:
-
-| Category | Tool Name | Danger Level |
-|----------|-----------|-------------|
-| **File** | `read_file` | Safe |
-| | `write_file` | Warning |
-| | `list_files` | Safe |
-| | `delete_file` | Dangerous |
-| | `get_file_info` | Safe |
-| **System** | `current_time` | Safe |
-| | `execute_command` | Dangerous |
-| | `get_env` | Safe |
-| **Network** | `web_search` | Safe |
-| | `web_fetch` | Safe |
-| **Knowledge** | `search_history` | Safe |
-| | `search_documents` | Safe |
-| | `search_system_knowledge` | Safe |
-
-**Tool middleware**: `WithEvents` wrapper — emits `OnToolStart`/`OnToolComplete`/`OnToolError` lifecycle events via `ToolEventEmitter` interface (used for SSE tool status display).
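-
-A simplified generic sketch of that wrapper; only the `ToolEventEmitter` interface and `EmitterFromContext` helper (both listed later in this document) are taken from the project, the rest is an assumption about how the middleware is shaped:
-
-```go
-// withEvents wraps a tool's execute function so lifecycle events reach the
-// emitter carried on the context (if any).
-func withEvents[In, Out any](name string, fn func(context.Context, In) (Out, error)) func(context.Context, In) (Out, error) {
-	return func(ctx context.Context, in In) (Out, error) {
-		emitter := EmitterFromContext(ctx) // assumed to return nil when no emitter is attached
-		if emitter != nil {
-			emitter.OnToolStart(name)
-		}
-		out, err := fn(ctx, in)
-		if emitter == nil {
-			return out, err
-		}
-		if err != nil {
-			emitter.OnToolError(name)
-		} else {
-			emitter.OnToolComplete(name)
-		}
-		return out, err
-	}
-}
-```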
-
-**No other middleware** (no rate limiting, no caching, no request/response interceptors).
-
-**Agentic loop**: `ai.WithMaxTurns(maxTurns)` — default 5 turns. LLM can call tools iteratively.
-
-### 3.4 Structured Output
-
-**Not used**. The flow uses plain string output:
-- Input: JSON object `{query, sessionId}`
-- Output: JSON object `{response, sessionId}`
-- Streaming: JSON chunks `{text}`
-
-Tool inputs use `jsonschema_description` tags for schema inference, but no output schema validation.
-
-### 3.5 Session Management
-
-**Custom** (NOT Genkit built-in). PostgreSQL-backed via `session.Store`:
-
-```
-Agent.ExecuteStream()
- → sessions.GetHistory(ctx, sessionID) // Load []*ai.Message from DB
- → generateResponse(ctx, input, messages) // LLM call
- → sessions.AppendMessages(ctx, sessionID) // Persist new messages
-```
-
-Messages stored as JSONB in `message.content` column, serialized from `[]*ai.Part`.
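-
-Because the stored content is just Genkit parts, the JSONB payload is a plain JSON round-trip; a minimal illustration (the real store goes through sqlc-generated queries, so the helper below is purely for clarity):
-
-```go
-// roundTripContent shows what is written to message.content (JSONB) and what
-// GetHistory rebuilds per message.
-func roundTripContent(parts []*ai.Part) ([]*ai.Part, error) {
-	payload, err := json.Marshal(parts) // serialized into the JSONB column
-	if err != nil {
-		return nil, err
-	}
-	var restored []*ai.Part // deserialized when loading history
-	if err := json.Unmarshal(payload, &restored); err != nil {
-		return nil, err
-	}
-	return restored, nil
-}
-```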
-
-### 3.6 Streaming
-
-**Two-layer architecture**:
-
-1. **Agent → Flow** (internal): `ai.WithStreaming(callback)` where callback receives `*ai.ModelResponseChunk`
-2. **Flow → HTTP** (SSE): Go 1.23 `range-over-func` iterator over `flow.Stream(ctx, input)`
-
-**SSE endpoint**: `GET /genui/stream?msgId=X&session_id=Y`
-- 5-minute timeout
-- HTML escaping (XSS prevention)
-- OOB swaps for HTMX (sidebar refresh, final message replacement)
-
----
-
-## 4. Data Layer
-
-### 4.1 Database
-
-| Field | Value |
-|-------|-------|
-| **DBMS** | PostgreSQL 17 with pgvector extension |
-| **Driver** | `pgx/v5` (connection pooling) |
-| **Pool Config** | MaxConns=10, MinConns=2, MaxLifetime=30m, MaxIdle=5m, HealthCheck=1m |
-| **Vector Extension** | pgvector (vector(768), HNSW index, cosine distance) |
-| **Query Generation** | sqlc (type-safe Go from SQL) |
-| **Migration Tool** | golang-migrate v4 |
-| **Docker Image** | `pgvector/pgvector:pg17` |
-
-### 4.2 SQLC Configuration
-
-**Config**: `build/sql/sqlc.yaml`
-- Engine: PostgreSQL
-- Driver: `pgx/v5`
-- Output: `internal/sqlc/`
-- Features: JSON tags, empty slices (not nil), UUID type override
-
-**Query Files**:
-
-| File | Queries | Purpose |
-|------|---------|---------|
-| `db/queries/documents.sql` | 8 | Document CRUD + vector search |
-| `db/queries/sessions.sql` | 28 | Session/message CRUD + streaming lifecycle |
-
-### 4.3 Migration Files
-
-Single migration: `db/migrations/000001_init_schema.up.sql`
-
-Migrations are **embedded** via `//go:embed` and run automatically at startup in `provideDBPool()`.
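-
-The general pattern looks roughly like this (a sketch of the `go:embed` + golang-migrate wiring, not the project's actual `provideDBPool`; the embed path in particular depends on where the Go file lives relative to `db/migrations`):
-
-```go
-// imports: embed, errors, fmt,
-//   github.com/golang-migrate/migrate/v4
-//   github.com/golang-migrate/migrate/v4/source/iofs
-//   plus a blank import of the postgres database driver
-
-//go:embed db/migrations/*.sql
-var migrationsFS embed.FS
-
-// runMigrations applies the embedded migrations before the pool is handed out.
-func runMigrations(databaseURL string) error {
-	src, err := iofs.New(migrationsFS, "db/migrations")
-	if err != nil {
-		return fmt.Errorf("load embedded migrations: %w", err)
-	}
-	m, err := migrate.NewWithSourceInstance("iofs", src, databaseURL)
-	if err != nil {
-		return fmt.Errorf("init migrate: %w", err)
-	}
-	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
-		return fmt.Errorf("apply migrations: %w", err)
-	}
-	return nil
-}
-```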
-
-### 4.4 Complete Schema (CREATE TABLE)
-
-```sql
--- Extension
-CREATE EXTENSION IF NOT EXISTS vector;
-
--- Helper function
-CREATE OR REPLACE FUNCTION update_updated_at_column()
-RETURNS TRIGGER AS $$
-BEGIN
- NEW.updated_at = NOW();
- RETURN NEW;
-END;
-$$ LANGUAGE plpgsql;
-
--- Table: documents (RAG/Knowledge Store)
-CREATE TABLE IF NOT EXISTS documents (
- id TEXT PRIMARY KEY,
- content TEXT NOT NULL,
- embedding vector(768) NOT NULL,
- source_type TEXT,
- metadata JSONB
-);
-CREATE INDEX IF NOT EXISTS idx_documents_embedding
- ON documents USING hnsw (embedding vector_cosine_ops)
- WITH (m = 16, ef_construction = 64);
-CREATE INDEX IF NOT EXISTS idx_documents_source_type ON documents(source_type);
-CREATE INDEX IF NOT EXISTS idx_documents_metadata_gin
- ON documents USING GIN (metadata jsonb_path_ops);
-
--- Table: sessions
-CREATE TABLE IF NOT EXISTS sessions (
- id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
- title TEXT,
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- model_name TEXT,
- system_prompt TEXT,
- message_count INTEGER DEFAULT 0
-);
-CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON sessions(updated_at DESC);
-CREATE TRIGGER update_sessions_updated_at
- BEFORE UPDATE ON sessions FOR EACH ROW
- EXECUTE FUNCTION update_updated_at_column();
-
--- Table: message
-CREATE TABLE IF NOT EXISTS message (
- id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
- session_id UUID NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
- role TEXT NOT NULL,
- content JSONB NOT NULL,
- sequence_number INTEGER NOT NULL,
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- status TEXT NOT NULL DEFAULT 'completed'
- CHECK (status IN ('streaming', 'completed', 'failed')),
- updated_at TIMESTAMPTZ DEFAULT NOW(),
- CONSTRAINT unique_message_sequence UNIQUE (session_id, sequence_number),
- CONSTRAINT message_role_check CHECK (role IN ('user', 'assistant', 'system', 'tool'))
-);
-CREATE INDEX IF NOT EXISTS idx_message_session_id ON message(session_id);
-CREATE INDEX IF NOT EXISTS idx_message_session_seq ON message(session_id, sequence_number);
-CREATE INDEX IF NOT EXISTS idx_incomplete_messages ON message(session_id, updated_at)
- WHERE status IN ('streaming', 'failed');
-CREATE INDEX IF NOT EXISTS idx_message_status ON message(session_id, status)
- WHERE status != 'completed';
-CREATE INDEX IF NOT EXISTS idx_message_content_gin
- ON message USING GIN (content jsonb_path_ops);
-CREATE TRIGGER update_message_updated_at
- BEFORE UPDATE ON message FOR EACH ROW
- EXECUTE FUNCTION update_updated_at_column();
-```
-
-### 4.5 pgvector Usage
-
-| Aspect | Detail |
-|--------|--------|
-| **Dimension** | 768 (text-embedding-004 model) |
-| **Index Type** | HNSW (m=16, ef_construction=64) |
-| **Distance Metric** | Cosine (`<=>` operator) |
-| **Similarity Formula** | `(1 - (embedding <=> query_embedding))::float8` |
-| **Table** | `documents` |
-| **Source Types** | `conversation`, `file`, `system` |
-| **System Knowledge** | 6 pre-indexed documents at startup |
-
----
-
-## 5. Channel Adapter Implementation
-
-### Current State: No External Channel Adapters
-
-The project has **three access channels**, all built-in:
-
-| Channel | Implementation | Transport |
-|---------|---------------|-----------|
-| **CLI/TUI** | `internal/tui/` (BubbleTea) | Terminal stdin/stdout |
-| **Web Chat** | `internal/web/` (Templ + HTMX) | HTTP + SSE |
-| **MCP** | `internal/mcp/` (MCP SDK) | stdio |
-
-**No Telegram, LINE, Discord, Slack, or WhatsApp adapters exist.**
-
-### Web Chat Details
-
-| Aspect | Detail |
-|--------|--------|
-| **Main Files** | `internal/web/server.go`, `handlers/chat.go`, `handlers/pages.go`, `handlers/sessions.go`, `sse/writer.go` |
-| **Template Engine** | Templ (Go SSR) |
-| **CSS Framework** | Tailwind CSS (compiled, embedded) |
-| **Interactivity** | HTMX + Tailwind Plus Elements |
-| **Streaming** | Server-Sent Events (SSE) |
-| **Session** | HTTP cookie (30-day expiry) + CSRF tokens |
-| **Routes** | `GET /genui`, `POST /genui/send`, `GET /genui/stream`, `POST /genui/sessions/*` |
-
-### CLI/TUI Details
-
-| Aspect | Detail |
-|--------|--------|
-| **Framework** | BubbleTea v2 |
-| **State Machine** | 3 states: Input → Thinking → Streaming |
-| **Streaming** | Go 1.23 range-over-func iterator, discriminated union channel |
-| **Key Bindings** | Enter=submit, Shift+Enter=newline, Esc=cancel, Ctrl+C×2=quit |
-| **Markdown** | Glamour rendering (graceful degradation to plain text) |
-| **History** | In-memory (max 100 entries), Up/Down navigation |
-
----
-
-## 6. MCP Integration
-
-### 6.1 SDK
-
-| Field | Value |
-|-------|-------|
-| **SDK** | `github.com/modelcontextprotocol/go-sdk` v1.1.0 (official) |
-| **Role** | **MCP Server** only (not client) |
-| **Transport** | Stdio |
-| **Entry Point** | `cmd/mcp.go` → `mcp.NewServer(config).Run(ctx, &StdioTransport{})` |
-
-### 6.2 MCP Server Tools (10 tools)
-
-Same as the Genkit tools minus the 3 knowledge tools (knowledge tools require the Genkit retriever, and the MCP server does not initialize the full Genkit stack):
-
-| Tool | Description |
-|------|-------------|
-| `read_file` | Read file content (max 10MB) |
-| `write_file` | Create/overwrite files |
-| `list_files` | List directory contents |
-| `delete_file` | Delete files |
-| `get_file_info` | File metadata |
-| `current_time` | System time |
-| `execute_command` | Whitelisted commands |
-| `get_env` | Non-sensitive env vars |
-| `web_search` | SearXNG search |
-| `web_fetch` | URL fetch with SSRF protection |
-
-### 6.3 Tool Registry Management
-
-- Tools registered via `mcp.AddTool()` in `Server.registerTools()`
-- Input schemas auto-generated using `jsonschema.For[T](nil)`
-- Results converted via `resultToMCP()` with error detail sanitization (blocks stack traces, file paths, API keys)
-
-### 6.4 Connected MCP Servers
-
-The `.mcp.json` configures one MCP server for **development** use with Claude Code:
-
-```json
-{
- "mcpServers": {
- "genkit": {
- "command": "genkit",
- "args": ["mcp", "--no-update-notification"]
- }
- }
-}
-```
-
-`config.example.yaml` shows additional configurable MCP servers (fetch, filesystem, github) but these are **configuration templates**, not actively connected.
-
----
-
-## 7. Permission / Security
-
-### 7.1 Security Validators (5 modules)
-
-| Validator | File | Prevents | Mechanism |
-|-----------|------|----------|-----------|
-| **Path** | `security/path.go` | Directory traversal (CWE-22) | Whitelist allowed dirs, symlink resolution, null byte rejection |
-| **Command** | `security/command.go` | Command injection (CWE-78) | Whitelist 53 safe commands, argument validation |
-| **Env** | `security/env.go` | Info leakage | Block 87 sensitive patterns, allow 23 safe vars |
-| **URL** | `security/url.go` | SSRF (CWE-918) | Block private IPs, metadata endpoints, DNS rebinding protection |
-| **Prompt** | `security/prompt.go` | Prompt injection | 13 regex patterns, Unicode normalization |
-
-### 7.2 Tool Call Permission Check
-
-Security validation is **inline** — each tool validates its inputs before execution:
-- `FileTools.ReadFile()` → `pathVal.Validate(input.Path)`
-- `SystemTools.ExecuteCommand()` → `cmdVal.ValidateCommand(cmd, args)`
-- `SystemTools.GetEnv()` → `envVal.Validate(key)`
-- `NetworkTools.WebFetch()` → `urlValidator.Validate(url)` + `SafeTransport()`
-
-Security failures return `Result{Status: StatusError, Error: &Error{Code: ErrCodeSecurity}}` — business errors, not Go errors. This allows the LLM to handle rejections gracefully.
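-
-A sketch of that pattern for `read_file`; the method signature and plumbing are assumptions, while the `Result`/`Error` shapes and error codes come from `internal/tools/types.go` above:
-
-```go
-func (ft *FileTools) ReadFile(ctx context.Context, in ReadFileInput) (*Result, error) {
-	cleanPath, err := ft.pathVal.Validate(in.Path) // security check first
-	if err != nil {
-		// Security rejections are business errors the LLM can react to,
-		// not Go errors that would abort the agentic loop.
-		return &Result{
-			Status: StatusError,
-			Error:  &Error{Code: ErrCodeSecurity, Message: err.Error()},
-		}, nil
-	}
-
-	data, err := os.ReadFile(cleanPath)
-	if err != nil {
-		return &Result{
-			Status: StatusError,
-			Error:  &Error{Code: ErrCodeIO, Message: err.Error()},
-		}, nil
-	}
-	return &Result{Status: StatusSuccess, Data: string(data)}, nil
-}
-```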
-
-### 7.3 Approval Flow
-
-**Not implemented**. `ToolMetadata.RequiresConfirmation` and `DangerLevel` are defined but there is no runtime approval mechanism. The metadata system is in place as infrastructure for a future approval flow.
-
-### 7.4 Audit Logging
-
-**slog-based** security event logging with structured fields:
-
-```go
-logger.Warn("security_event",
- "type", "path_traversal_attempt",
- "path", unsafePath,
- "allowed_dirs", allowedDirs,
-)
-```
-
-Events logged for: path traversal, command injection, sensitive env access, SSRF blocks, symlink traversal, prompt injection.
-
-**No database-backed audit log**. All events go to application logs only.
-
----
-
-## 8. Frontend
-
-### 8.1 Technology Stack
-
-| Aspect | Technology |
-|--------|-----------|
-| **Template Engine** | Templ v0.3.960 (Go SSR, type-safe) |
-| **CSS Framework** | Tailwind CSS (compiled, embedded in binary) |
-| **Interactivity** | HTMX + HTMX SSE Extension |
-| **Client-side** | Tailwind Plus Elements (dropdowns, modals) |
-| **Code Highlighting** | Prism.js |
-| **No SPA framework** | No Angular, React, or Vue |
-
-### 8.2 Pages and Components
-
-| Type | File | Purpose |
-|------|------|---------|
-| **Layout** | `layout/app.templ` | Base HTML (head, body, scripts) |
-| **Page** | `page/chat.templ` | Two-column chat layout (sidebar + feed) |
-| **Component** | `component/message_bubble.templ` | User/assistant message rendering |
-| | `component/sidebar.templ` | Session list sidebar |
-| | `component/chat_input.templ` | Message input form |
-| | `component/empty_state.templ` | Empty state view |
-| | `component/session_placeholders.templ` | Loading skeletons |
-
-### 8.3 Real-time Updates
-
-- **SSE** (Server-Sent Events) for streaming AI responses
-- **HTMX OOB swaps** for sidebar refresh, session field updates
-- **No WebSocket** implementation
-
-### 8.4 Dashboard
-
-**Not implemented**. Only a chat interface exists.
-
-### 8.5 Static Asset Management
-
-- **Production**: Assets embedded via `//go:embed` (single binary deployment)
-- **Development**: `assets_dev.go` loads from filesystem (build tag: `dev`)
-- **JS Libraries**: HTMX 2.x, htmx-sse extension, Tailwind Plus Elements, Prism.js — all vendored locally (no CDN)
-
----
-
-## 9. Deployment & Configuration
-
-### 9.1 Docker Compose
-
-```yaml
-services:
- postgres:
- image: pgvector/pgvector:pg17
- ports: ["5432:5432"]
- environment:
- POSTGRES_USER: koopa
- POSTGRES_PASSWORD: koopa_dev_password
- POSTGRES_DB: koopa
- volumes:
- - postgres_data:/var/lib/postgresql/data
- - ./db/migrations:/docker-entrypoint-initdb.d # Auto-init
- healthcheck: pg_isready -U koopa
-
- searxng:
- image: searxng/searxng:latest
- ports: ["8888:8080"]
- depends_on: [redis]
- healthcheck: wget --no-verbose --tries=1 http://localhost:8080/healthz
-
- redis:
- image: valkey/valkey:8-alpine
- healthcheck: valkey-cli ping
-```
-
-**No Dockerfile for Koopa itself** — built as a Go binary, not containerized.
-
-### 9.2 Environment Variables
-
-| Variable | Required | Default | Purpose |
-|----------|----------|---------|---------|
-| `GEMINI_API_KEY` | Yes | — | Google Gemini API key |
-| `DATABASE_URL` | Yes | — | PostgreSQL connection string |
-| `HMAC_SECRET` | Web only | — | CSRF token signing (min 32 chars) |
-| `KOOPA_MODEL_NAME` | No | `gemini-2.5-flash` | LLM model name |
-| `KOOPA_TEMPERATURE` | No | `0.7` | LLM temperature |
-| `KOOPA_MAX_TOKENS` | No | `2048` | Max response tokens |
-| `KOOPA_MAX_HISTORY_MESSAGES` | No | `100` | Conversation history limit |
-| `KOOPA_RAG_TOP_K` | No | `3` | RAG documents to retrieve |
-| `DEBUG` | No | `false` | Debug logging |
-| `DD_API_KEY` | No | — | Datadog APM |
-
-### 9.3 Configuration File
-
-**Location**: `~/.koopa/config.yaml`
-
-Used for MCP server definitions and advanced tool configuration. See `config.example.yaml` for template.
-
-### 9.4 Startup Flow
-
-```
-main.go
- → cmd.Execute()
- → Parse flags (--version, --help)
- → Dispatch to: runCLI() | RunServe() | RunMCP()
-
-runCLI():
- → config.Load()
- → signal.NotifyContext(SIGINT, SIGTERM)
- → app.NewRuntime(ctx, cfg)
- → Wire: InitializeApp() (10-step provider chain)
- → chat.New(Config{...}) # Create agent
- → chat.InitFlow(g, agent) # Register Genkit Flow (sync.Once)
- → Load/create session ID (~/.koopa/current_session)
- → tui.New(ctx, flow, sessionID)
- → tea.NewProgram(model).Run()
-
-RunServe():
- → Validate HMAC_SECRET
- → app.NewRuntime(ctx, cfg)
- → web.NewServer(ServerConfig{...})
- → http.Server{...} with timeouts
- → Graceful shutdown on SIGINT/SIGTERM
-
-RunMCP():
- → app.InitializeApp(ctx, cfg) # Wire directly
- → Create FileTools, SystemTools, NetworkTools
- → mcp.NewServer(Config{...})
- → server.Run(ctx, &StdioTransport{})
-```
-
----
-
-## 10. Testing
-
-### 10.1 Test Types
-
-| Type | Count | Build Tag | Runner |
-|------|-------|-----------|--------|
-| **Unit Tests** | ~70 files | none | `go test ./...` |
-| **Integration Tests** | ~10 files | `integration` | `go test -tags integration ./...` |
-| **Fuzz Tests** | 4 files | none | `go test -fuzz=FuzzXxx ./...` |
-| **E2E Tests** | ~5 files | `e2e` | Playwright browser automation |
-| **Race Tests** | 1 file | none | `go test -race ./...` |
-
-### 10.2 Testing Patterns
-
-**Unit Test Example** (handler testing):
-```go
-func TestChat_Send(t *testing.T) {
- fw := SetupTest(t) // Test framework with mock deps
- defer fw.Cleanup()
- req := httptest.NewRequest(...)
- w := httptest.NewRecorder()
- fw.Chat.Send(w, req)
- assert.Equal(t, 200, w.Code)
-}
-```
-
-**Integration Test Example** (real database):
-```go
-//go:build integration
-
-func TestPages_Chat_LoadsHistoryFromDatabase(t *testing.T) {
- // Uses testcontainers PostgreSQL
- db := testutil.SetupTestDB(t)
- store := session.New(db, logger)
- // ... test with real DB operations
-}
-```
-
-**Fuzz Test Example** (security):
-```go
-func FuzzMessageContent(f *testing.F) {
- f.Add("normal text")
- f.Add("")
- f.Add("'; DROP TABLE messages; --")
- f.Fuzz(func(t *testing.T, content string) {
- // Verify no panic, proper escaping
- })
-}
-```
-
-### 10.3 Test Infrastructure
-
-| Component | Location | Purpose |
-|-----------|----------|---------|
-| `testutil.SetupTestDB()` | `internal/testutil/db.go` | PostgreSQL testcontainer |
-| `testutil.DiscardLogger()` | `internal/testutil/logger.go` | No-op logger |
-| Deterministic embedder | `internal/testutil/embedder.go` | Predictable embeddings for tests |
-| Mock SSE Writer | `internal/web/handlers/chat_test.go` | Records SSE events |
-| Browser fixture | `internal/web/fixture_test.go` | Playwright browser context |
-| SSE event parser | `internal/testutil/` | Parse SSE format |
-
-### 10.4 Coverage
-
-Not explicitly measured (no CI coverage badge), but `coverage.out` exists at root. Test infrastructure suggests moderate-to-good coverage for core packages (agent, tools, security, handlers).
-
----
-
-## 11. Architecture Flow Diagram (Working)
-
-### Message Processing Flow (Web Chat)
-
-```
-User types message in browser
- │
- ▼
-POST /genui/send (HTMX form submission)
- │
- ├── Parse form: content, session_id, csrf_token
- ├── Validate CSRF token (HMAC-SHA256)
- ├── Lazy session creation (if first message)
- │
- ├── Render HTML response:
- │ ├── User message bubble (OOB swap into #message-feed)
- │ ├── Assistant skeleton with SSE connection:
- │ │
- │ └── OOB swaps: session_id field, csrf token refresh
- │
- ▼
-GET /genui/stream (SSE endpoint)
- │
- ├── Parse query params: msgId, session_id, query
- ├── Create 5-minute timeout context
- │
- ├── flow.Stream(ctx, Input{Query, SessionID})
- │ │
- │ ▼
- │ Chat.ExecuteStream(ctx, sessionID, query, callback)
- │ │
- │ ├── sessions.GetHistory(ctx, sessionID) # Load from PostgreSQL
- │ ├── deepCopyMessages(history) # Prevent Genkit data race
- │ ├── truncateHistory(messages, tokenBudget) # Context window management
- │ ├── retrieveRAGContext(ctx, query) # pgvector search (5s timeout)
- │ │
- │ ├── circuitBreaker.Allow() # Check circuit state
- │ ├── executeWithRetry(ctx, opts) # Exponential backoff
- │ │ │
- │ │ ├── rateLimiter.Wait(ctx) # Rate limiting
- │ │ └── prompt.Execute(ctx, opts...) # Genkit → Gemini API
- │ │ │
- │ │ ├── ai.WithTools(toolRefs...) # 13 tools available
- │ │ ├── ai.WithMaxTurns(5) # Agentic loop
- │ │ ├── ai.WithStreaming(callback) # Stream chunks
- │ │ └── ai.WithDocs(ragDocs...) # RAG context
- │ │ │
- │ │ ▼
- │ │ [Gemini processes, may call tools]
- │ │ │
- │ │ ├── Tool call → security validation → execute → Result
- │ │ ├── Tool result → back to Gemini → next turn
- │ │ └── Final text response
- │ │
- │ ├── circuitBreaker.Success()
- │ ├── sessions.AppendMessages(ctx, sessionID, [user, assistant])
- │ └── Return Response{FinalText, ToolRequests}
- │
- ├── For each StreamChunk:
- │ └── sseWriter.WriteChunkRaw(msgID, escapedHTML) → SSE "chunk" event
- │
- ├── maybeGenerateTitle(sessionID, content) # AI title gen
- ├── sseWriter.WriteSidebarRefresh(sessionID, title) # HX-Trigger
- └── sseWriter.WriteDone(ctx, msgID, finalComponent) # SSE "done" event
- │
- ▼
-Browser: HTMX processes SSE events
- ├── "chunk" events → swap into assistant bubble (innerHTML)
- ├── Sidebar refresh → re-render session list
- └── "done" event → close SSE connection
-```
-
-### Message Processing Flow (CLI/TUI)
-
-```
-User types in terminal textarea
- │
- ├── Enter key → submit
- │
- ▼
-TUI.handleKey() → StateInput → StateThinking
- │
- ├── startStream(query) → tea.Cmd
- │ │
- │ ├── Create buffered channel (100 items)
- │ ├── Spawn goroutine with 5-min timeout
- │ └── flow.Stream(ctx, Input{Query, SessionID})
- │ │
- │ └── [Same agent flow as Web Chat above]
- │
- ├── streamStartedMsg → StateStreaming
- │ ├── streamTextMsg → append to output.Builder
- │ ├── streamTextMsg → append to output.Builder
- │ └── ...
- │
- └── streamDoneMsg → StateInput
- ├── Append to messages slice (max 100)
- └── Clear output buffer
-```
-
----
-
-## 12. Gap Analysis
-
-| Component | Status | Details |
-|-----------|--------|---------|
-| **Agent Runtime (Genkit flow-based)** | ✅ Complete | `genkit.DefineStreamingFlow`, agentic loop with 5 max turns, retry + circuit breaker + rate limiting + token budget |
-| **Channel Adapter: Telegram** | ❌ Not started | No Telegram-related code exists |
-| **Channel Adapter: LINE** | ❌ Not started | No LINE-related code exists |
-| **Channel Adapter: Web Chat** | ✅ Complete | Templ SSR + HTMX + SSE streaming, session management, CSRF protection |
-| **Permission Engine (Always/RequireApproval/RoleOnly)** | 🟡 Partial | **Done**: `DangerLevel` enum (Safe/Warning/Dangerous/Critical), `ToolMetadata.RequiresConfirmation` flag, inline security validators (5 modules). **Missing**: Runtime approval flow (approve/reject UI), role-based access control, permission policy engine |
-| **MCP Bridge (client + permission layer)** | 🟡 Partial | **Done**: MCP Server (10 tools, stdio transport, official SDK v1.1.0). **Missing**: MCP Client (connecting to external MCP servers at runtime), permission layer for MCP tool calls |
-| **Event Bus (structured events for observability)** | 🟡 Partial | **Done**: `ToolEventEmitter` interface with `OnToolStart/Complete/Error`, context-based propagation, SSE tool status display. **Missing**: General-purpose event bus, structured event types beyond tools, event persistence, event subscribers/listeners pattern |
-| **Memory: Short-term (conversation context)** | ✅ Complete | `session.History` (thread-safe `[]*ai.Message`), token budget truncation (8K history limit), loaded from PostgreSQL per request |
-| **Memory: Long-term (PostgreSQL cross-conversation)** | ✅ Complete | `sessions` + `message` tables, JSONB content storage, sequence numbering, message status lifecycle (streaming/completed/failed), session listing with pagination |
-| **Memory: Knowledge Base (RAG with pgvector)** | ✅ Complete | `documents` table with vector(768), HNSW index, cosine similarity, 3 source types (conversation/file/system), 6 system knowledge docs, 3 knowledge tools for agent |
-| **Session Management** | ✅ Complete | PostgreSQL-backed `session.Store`, HTTP cookie sessions (30-day), CSRF tokens (HMAC-SHA256), lazy session creation, CLI session persistence (~/.koopa/current_session) |
-| **Audit Logging (PostgreSQL)** | 🟡 Partial | **Done**: slog-based security event logging (path traversal, command injection, SSRF, prompt injection). **Missing**: PostgreSQL audit table, queryable audit trail, audit log retention policy |
-| **Approval Flow (pending → approved/rejected)** | ❌ Not started | `ToolMetadata.RequiresConfirmation` exists as schema but no runtime implementation. No pending state, no approval UI, no notification system |
-| **Dashboard Frontend (Angular)** | ❌ Not started | No Angular code exists; only the Templ SSR chat interface is implemented |
-| **REST API for Dashboard** | 🟡 Partial | **Done**: `GET /genui` (chat page), `POST /genui/send`, `GET /genui/stream` (SSE), `POST /genui/sessions/*`, `GET /health`, `GET /ready`. **Missing**: RESTful CRUD API for sessions/messages (JSON), admin/dashboard-specific endpoints, API versioning |
-| **WebSocket for real-time events** | ❌ Not started | Real-time streaming currently uses SSE (Server-Sent Events); no WebSocket implementation exists |
-| **Docker Compose deployment** | 🟡 Partial | **Done**: PostgreSQL + pgvector, SearXNG + Redis, health checks, volume mounts. **Missing**: Dockerfile for the Koopa application itself, production docker-compose profile, nginx/reverse proxy |
-| **CLI / Configuration** | ✅ Complete | Three modes (cli/serve/mcp), Viper config (.env + YAML), signal handling + graceful shutdown, Taskfile build system |
-
-### Summary Table
-
-| Status | Count | Items |
-|--------|-------|-------|
-| ✅ Complete | 7 | Agent Runtime, Web Chat, Short-term Memory, Long-term Memory, Knowledge Base (RAG), Session Management, CLI/Config |
-| 🟡 Partial | 6 | Permission Engine, MCP Bridge, Event Bus, Audit Logging, REST API, Docker Compose |
-| ❌ Not started | 5 | Telegram Adapter, LINE Adapter, Approval Flow, Dashboard (Angular), WebSocket |
-
----
-
-## Appendix: TODO/FIXME/HACK Comments
-
-| File | Comment |
-|------|---------|
-| `internal/web/server.go:86-96` | TODO: Implement Settings and Search handlers |
-| `internal/web/server.go:112-114` | TODO: Settings and Search routes |
-| `internal/agent/chat/chat.go:412` | TODO: File Genkit GitHub issue (data race workaround) |
-| `internal/app/app_test.go:108` | TODO: Re-enable when Toolset migration complete |
-| `cmd/e2e_test.go:198-203` | FIXME: MCP server exits immediately (requires test harness) |
-| `cmd/e2e_test.go:255-260` | FIXME: MCP communication fails with EOF (needs proper test client) |
-
----
-
-## Appendix: Environment Configuration Template
-
-**`.env.example`**:
-```bash
-# Required (All Modes)
-GEMINI_API_KEY=your-api-key-here
-DATABASE_URL=postgres://koopa:koopa_dev_password@localhost:5432/koopa?sslmode=disable
-
-# Required (Web Mode Only)
-HMAC_SECRET= # openssl rand -base64 32
-
-# Optional Model Settings
-KOOPA_MODEL_NAME=gemini-2.5-flash
-KOOPA_TEMPERATURE=0.7
-KOOPA_MAX_TOKENS=2048
-KOOPA_MAX_HISTORY_MESSAGES=100
-KOOPA_RAG_TOP_K=3
-
-# Optional
-DEBUG=false
-DD_API_KEY= # Datadog APM
-```
-
----
-
-## Appendix: Build Commands
-
-```bash
-# Development
-task build:dev # Build with filesystem assets
-task css:watch # Watch CSS changes
-task generate # Generate templ files
-task sqlc # Generate SQL code
-
-# Testing
-task test # All tests
-task test:race # With race detector
-task test:unit # Unit only
-task test:integration # With testcontainers
-task test:fuzz # Fuzz tests (30s each)
-task test:e2e # Playwright browser tests
-
-# Quality
-task lint # golangci-lint
-task check # lint + test + build
-
-# Production
-task build # Embedded assets binary
-go build -o koopa .     # Manual build (main package at repo root)
-```
diff --git a/docs/proposals/001-phase-a-remove-templ-htmx-json-api-skeleton.md b/docs/proposals/001-phase-a-remove-templ-htmx-json-api-skeleton.md
deleted file mode 100644
index 3e4f91e..0000000
--- a/docs/proposals/001-phase-a-remove-templ-htmx-json-api-skeleton.md
+++ /dev/null
@@ -1,306 +0,0 @@
-# Proposal 001: Phase A — Remove Templ/HTMX, Build JSON API Skeleton
-
-> Status: PENDING REVIEW
-> Author: Claude Code
-> Date: 2026-02-04
-
----
-
-## Summary
-
-Remove all Templ/HTMX/SSE frontend code from `internal/web/`, migrate business logic to a new `internal/api/` package, and establish a JSON REST API skeleton. This is an atomic operation — `koopa serve` must work before and after.
-
----
-
-## 2.1 Proposed Directory Structure
-
-```
-internal/api/
-├── server.go # HTTP server, route registration, CORS, security headers
-├── response.go # JSON response helpers: WriteJSON, WriteError, standard envelope
-├── middleware/
-│ ├── recovery.go # Panic recovery (from web/middleware.go)
-│ ├── logging.go # Request logging (from web/middleware.go)
-│ ├── cors.go # CORS for Angular dev server (NEW)
-│ └── auth.go # Session cookie + CSRF validation (from web/handlers/sessions.go + web/middleware.go)
-└── v1/
- ├── chat.go # POST /api/v1/chat/send, GET /api/v1/chat/stream (JSON SSE)
- ├── chat_test.go # Chat handler tests (JSON response validation)
- ├── sessions.go # GET/POST/DELETE /api/v1/sessions, session auth logic
- ├── sessions_test.go # Session handler tests
- ├── health.go # GET /api/v1/health, GET /api/v1/ready
- └── health_test.go # Health handler tests
-```
-
-### Design Rationale
-
-**Why `internal/api/` not `internal/web/` rename?**
-- Clean break. No leftover references to templ/htmx imports.
-- `internal/web/` will be re-used later for Angular `dist/` embedding (Phase B+).
-
-**Why `v1/` versioning?**
-- API versioning from day one. Angular client will target `/api/v1/`.
-- Future breaking changes go to `v2/` without disrupting existing clients.
-
-**Why `middleware/` as a sub-package?**
-- Follows existing pattern separation (middleware was already logically separate in `web/middleware.go`).
-- Allows middleware to be tested independently.
-
-**Why `response.go` at package root?**
-- Shared by all `v1/` handlers. Avoids circular imports.
-- Single source of truth for JSON envelope format.
-
----
-
-## Response Envelope
-
-All JSON responses follow a consistent envelope:
-
-```go
-// internal/api/response.go
-
-// Envelope is the standard JSON response wrapper.
-type Envelope struct {
- Data any `json:"data,omitempty"`
- Error *Error `json:"error,omitempty"`
-}
-
-type Error struct {
- Code string `json:"code"`
- Message string `json:"message"`
-}
-
-func WriteJSON(w http.ResponseWriter, status int, data any) { ... }
-func WriteError(w http.ResponseWriter, status int, code, message string) { ... }
-```
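-
-A minimal implementation of the two helpers might look like this (status handling and headers are illustrative assumptions; the point is that every handler goes through the shared envelope):
-
-```go
-func WriteJSON(w http.ResponseWriter, status int, data any) {
-	w.Header().Set("Content-Type", "application/json")
-	w.WriteHeader(status)
-	_ = json.NewEncoder(w).Encode(Envelope{Data: data})
-}
-
-func WriteError(w http.ResponseWriter, status int, code, message string) {
-	w.Header().Set("Content-Type", "application/json")
-	w.WriteHeader(status)
-	_ = json.NewEncoder(w).Encode(Envelope{Error: &Error{Code: code, Message: message}})
-}
-```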
-
----
-
-## Route Mapping (Old → New)
-
-| Old Route | New Route | Method | Handler | Notes |
-|-----------|-----------|--------|---------|-------|
-| `GET /genui` | _(removed)_ | — | — | Page rendering, replaced by Angular |
-| `POST /genui/send` | `POST /api/v1/chat/send` | POST | `v1.Chat.Send` | Returns JSON instead of HTML |
-| `GET /genui/stream` | `GET /api/v1/chat/stream` | GET | `v1.Chat.Stream` | JSON SSE (not HTML SSE) |
-| `GET /genui/sessions` | `GET /api/v1/sessions` | GET | `v1.Sessions.List` | JSON array |
-| `POST /genui/sessions` | `POST /api/v1/sessions` | POST | `v1.Sessions.Create` | JSON response |
-| `DELETE /genui/sessions/{id}` | `DELETE /api/v1/sessions/{id}` | DELETE | `v1.Sessions.Delete` | JSON response |
-| `GET /health` | `GET /api/v1/health` | GET | `v1.Health.Health` | Same logic |
-| `GET /ready` | `GET /api/v1/ready` | GET | `v1.Health.Ready` | Same logic |
-| — | `OPTIONS /api/v1/*` | OPTIONS | CORS middleware | NEW: Preflight |
-
----
-
-## Streaming Strategy
-
-The SSE **transport protocol** is kept (it's standard HTTP, not HTMX-specific). What changes is the **content format**:
-
-**Before (HTML SSE)**:
-```
-event: chunk
-data: <escaped HTML fragment containing "Hello">
-
-event: done
-data: <rendered HTML for the final message component>
-```
-
-**After (JSON SSE)**:
-```
-event: chunk
-data: {"msgId":"abc-123","text":"Hello"}
-
-event: tool_start
-data: {"msgId":"abc-123","tool":"web_search","message":"Searching..."}
-
-event: tool_complete
-data: {"msgId":"abc-123","tool":"web_search","message":"Done"}
-
-event: done
-data: {"msgId":"abc-123","sessionId":"xyz","title":"Chat Title"}
-
-event: error
-data: {"msgId":"abc-123","code":"timeout","message":"Request timed out"}
-```
-
-This preserves the existing streaming architecture (Genkit Flow → SSE) while making the output consumable by the Angular client through the standard browser `EventSource` API. The SSE writer is rewritten as a simple JSON event writer in `internal/api/v1/chat.go` — no separate `sse/` package needed.
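-
-A sketch of that inline writer (the helper name and payload shapes are illustrative; the wire format matches the examples above):
-
-```go
-func writeSSE(w http.ResponseWriter, event string, payload any) error {
-	flusher, ok := w.(http.Flusher)
-	if !ok {
-		return fmt.Errorf("response writer does not support streaming")
-	}
-	data, err := json.Marshal(payload)
-	if err != nil {
-		return err
-	}
-	if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, data); err != nil {
-		return err
-	}
-	flusher.Flush()
-	return nil
-}
-
-// e.g. writeSSE(w, "chunk", map[string]string{"msgId": msgID, "text": text})
-```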
-
----
-
-## Middleware Stack
-
-```
-Request
- → CORS (new)
- → Recovery (from web/middleware.go)
- → Logging (from web/middleware.go)
- → Auth (session cookie + CSRF, from web/middleware.go + handlers/sessions.go)
- → Route handler
-```
-
-**CORS middleware** (`middleware/cors.go`):
-- Reads allowed origins from config (`KOOPA_CORS_ORIGINS`, default: `http://localhost:4200`)
-- Handles preflight `OPTIONS` requests (204 No Content)
-- Sets `Access-Control-Allow-Origin`, `Allow-Methods`, `Allow-Headers`, `Allow-Credentials`
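-
-A minimal sketch of that middleware (config plumbing and the exact header lists are illustrative):
-
-```go
-func CORS(allowedOrigins []string) func(http.Handler) http.Handler {
-	allowed := make(map[string]bool, len(allowedOrigins))
-	for _, o := range allowedOrigins {
-		allowed[o] = true
-	}
-	return func(next http.Handler) http.Handler {
-		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-			origin := r.Header.Get("Origin")
-			if allowed[origin] {
-				w.Header().Set("Access-Control-Allow-Origin", origin)
-				w.Header().Set("Access-Control-Allow-Credentials", "true")
-			}
-			if r.Method == http.MethodOptions {
-				w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
-				w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-CSRF-Token")
-				w.WriteHeader(http.StatusNoContent) // preflight: 204 No Content
-				return
-			}
-			next.ServeHTTP(w, r)
-		})
-	}
-}
-```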
-
-**Auth middleware** (`middleware/auth.go`):
-- Consolidates session cookie reading + CSRF validation from the old `RequireSession` + `RequireCSRF` + `Sessions` struct.
-- HMAC-SHA256 token logic preserved exactly as-is.
-- Pre-session CSRF pattern preserved for lazy session creation.
-- CSRF token read from `X-CSRF-Token` header (instead of form field) for JSON API compatibility.
-
-**MethodOverride middleware**: Removed. JSON API uses proper HTTP methods directly.
-
----
-
-## Business Logic Migration Map
-
-| Source | Destination | What Migrates |
-|--------|-------------|---------------|
-| `web/middleware.go` RecoveryMiddleware | `api/middleware/recovery.go` | Panic catch, error response (→ JSON) |
-| `web/middleware.go` LoggingMiddleware | `api/middleware/logging.go` | Request logging (unchanged) |
-| `web/middleware.go` RequireSession | `api/middleware/auth.go` | Lazy session creation logic |
-| `web/middleware.go` RequireCSRF | `api/middleware/auth.go` | CSRF validation logic |
-| `web/handlers/sessions.go` HMAC logic | `api/middleware/auth.go` | Token gen/validation (unchanged) |
-| `web/handlers/sessions.go` cookie logic | `api/middleware/auth.go` | Cookie set/read (unchanged) |
-| `web/handlers/sessions.go` List/Create/Delete | `api/v1/sessions.go` | CRUD → JSON responses |
-| `web/handlers/chat.go` Send | `api/v1/chat.go` | Content validation, session handling → JSON |
-| `web/handlers/chat.go` Stream | `api/v1/chat.go` | Flow streaming → JSON SSE |
-| `web/handlers/chat.go` title gen | `api/v1/chat.go` | AI title generation (unchanged) |
-| `web/handlers/chat.go` error classify | `api/v1/chat.go` | Error classification (unchanged) |
-| `web/handlers/health.go` | `api/v1/health.go` | Health probes (unchanged) |
-| `web/handlers/tool_display.go` | `api/v1/chat.go` (inline) | Tool display messages (simplified) |
-| `web/handlers/convert.go` extractTextContent | `api/v1/sessions.go` | Message text extraction (unchanged) |
-| `web/sse/writer.go` SSE format | `api/v1/chat.go` (inline) | SSE protocol (`event:`, `data:`, flush) |
-
----
-
-## What Gets Deleted (No Migration)
-
-| File/Dir | Reason |
-|----------|--------|
-| `internal/web/page/` | Templ page templates |
-| `internal/web/layout/` | Templ layout templates |
-| `internal/web/component/` | Templ components + 222 reference blocks |
-| `internal/web/static/` | CSS, JS, embedded assets |
-| `internal/web/sse/` | HTML SSE writer (replaced by inline JSON SSE) |
-| `internal/web/e2e/` | Playwright test assets |
-| `internal/web/handlers/htmx.go` | HTMX detection helper |
-| `internal/web/handlers/pages.go` | HTML page rendering |
-| `internal/web/handlers/tool_emitter.go` | SSE-specific tool emitter |
-| `internal/web/fixture_test.go` | Playwright fixtures |
-| `internal/web/e2e_*.go` | All E2E browser tests |
-| `internal/web/server_test.go` | Tests for HTML server |
-| `internal/web/static/assets_test.go` | Tests for embedded assets |
-| `build/frontend/` | Tailwind CSS build config |
-
----
-
-## Wire DI Impact
-
-**No changes to `wire.go` or `wire_gen.go`.**
-
-`web.NewServer` was never part of Wire — it's created manually in `cmd/serve.go`. The new `api.NewServer` will also be created manually in `cmd/serve.go`, using the same Runtime components.
-
----
-
-## cmd/serve.go Changes
-
-```go
-// Before
-webServer, err := web.NewServer(web.ServerConfig{
- Logger: logger,
- Genkit: runtime.App.Genkit,
- ChatFlow: runtime.Flow,
- SessionStore: runtime.App.SessionStore,
- CSRFSecret: []byte(cfg.HMACSecret),
- Config: cfg,
-})
-
-// After
-apiServer := api.NewServer(api.ServerConfig{
- Logger: logger,
- Genkit: runtime.App.Genkit,
- ChatFlow: runtime.Flow,
- SessionStore: runtime.App.SessionStore,
- CSRFSecret: []byte(cfg.HMACSecret),
- CORSOrigins: cfg.CORSOrigins,
- IsDev: cfg.Debug,
-})
-```
-
----
-
-## Config Changes
-
-**New fields in `config.Config`**:
-```go
-CORSOrigins []string // from KOOPA_CORS_ORIGINS, default: ["http://localhost:4200"]
-```
-
-**New env vars in `.env.example`**:
-```bash
-KOOPA_CORS_ORIGINS=http://localhost:4200 # Comma-separated allowed origins
-```
-
-**Removed env vars**: None. `HMAC_SECRET` stays (still used for CSRF).
-
----
-
-## Dependency Removal
-
-| Dependency | Action | Reason |
-|------------|--------|--------|
-| `github.com/a-h/templ` | Remove | No more templ templates |
-| `github.com/playwright-community/playwright-go` | Remove | Only used in `internal/web/` E2E tests |
-
-Verified: `cmd/e2e_test.go` does NOT use playwright (uses `os/exec` only).
-
----
-
-## Taskfile.yml Changes
-
-| Task | Action |
-|------|--------|
-| `css` | Remove |
-| `css:watch` | Remove |
-| `generate` | Change to `sqlc generate` only (remove `templ generate`) |
-| `install:templ` | Remove |
-| `install:npm` | Remove (no more build/frontend) |
-| `build` | Remove templ/css deps |
-| `build:dev` | Remove templ/css deps |
-| `test:e2e` | Remove |
-| `fmt` | Remove `templ fmt` |
-| `dev` | Remove templ/css generation |
-
----
-
-## Acceptance Criteria
-
-All criteria from the user's spec apply. Additionally:
-
-1. `go build ./...` — zero errors
-2. `golangci-lint run ./...` — no errors
-3. `go test ./...` — all pass (deleted web tests excluded)
-4. `go vet ./...` — clean
-5. `koopa serve` — starts, responds to `/api/v1/health`
-6. `koopa` (CLI) — unaffected
-7. `koopa mcp` — unaffected
-8. CORS preflight returns 204 with correct headers
-9. No `.templ` files remain
-10. No `htmx` references in Go code
-11. No `github.com/a-h/templ` in go.mod
-12. `internal/web/` directory does not exist
-
----
-
-## Risk Assessment
-
-| Risk | Mitigation |
-|------|-----------|
-| Wire DI breaks | Wire is not involved — server created in cmd/serve.go |
-| Import cycle | `api/` depends on `session/`, `agent/chat/`, `config/` — same as `web/` did |
-| Missing business logic | Comprehensive migration map above; each handler verified |
-| Streaming breaks | SSE transport preserved, only content format changes (HTML → JSON) |
-| Tests fail | Migrated tests adapted to JSON assertions |
-| go.mod stale deps | `go mod tidy` cleans up automatically |
diff --git a/docs/proposals/001-v010-release-plan.md b/docs/proposals/001-v010-release-plan.md
deleted file mode 100644
index cc4484a..0000000
--- a/docs/proposals/001-v010-release-plan.md
+++ /dev/null
@@ -1,315 +0,0 @@
-# Proposal 001: v0.1.0 Release Plan
-
-## Summary
-
-Prepare Koopa for its first public release by:
-1. Cleaning up code quality (testify removal, skipped tests)
-2. Adding multi-model support (Gemini + Ollama + OpenAI)
-3. Exposing RAG tools via MCP
-4. Writing README and release automation
-
-## Scope
-
-**In scope (v0.1.0):**
-- testify → stdlib + cmp.Diff migration (16 files)
-- Multi-model: Gemini (current) + Ollama + OpenAI via Genkit plugins
-- Config-driven model/provider selection
-- MCP server exposes Knowledge/RAG tools
-- README with install instructions and quick start
-- goreleaser for macOS (arm64/amd64) + Linux (arm64/amd64)
-- Complete skipped tests where feasible
-
-**Out of scope (deferred to v0.2.0):**
-- Claude/Anthropic support (Genkit plugin disables tool use)
-- MCP Client (connecting to external MCP servers)
-- Hybrid RAG (vector + keyword search)
-- Extensible user-defined tool system
-- Agent loop enhancement (multi-step task decomposition)
-
-## Detailed Design
-
-### 1. testify Migration (16 files)
-
-**What changes:**
-Replace all `assert.*` / `require.*` calls with stdlib patterns:
-
-```go
-// Before
-assert.Equal(t, want, got)
-require.NoError(t, err)
-
-// After
-if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("FuncName() mismatch (-want +got):\n%s", diff)
-}
-if err != nil {
- t.Fatalf("FuncName() unexpected error: %v", err)
-}
-```
-
-**Files affected (16):**
-- `internal/agent/chat/integration_rag_test.go`
-- `internal/agent/chat/integration_test.go`
-- `internal/agent/chat/integration_streaming_test.go`
-- `internal/agent/chat/flow_test.go`
-- `internal/agent/chat/chat_test.go`
-- `internal/tools/system_integration_test.go`
-- `internal/tools/register_test.go`
-- `internal/tools/file_integration_test.go`
-- `internal/tools/network_integration_test.go`
-- `internal/session/integration_test.go`
-- `internal/rag/system_test.go`
-- `internal/mcp/integration_test.go`
-- `internal/observability/datadog_test.go`
-- `internal/testutil/postgres.go`
-- `internal/testutil/sse.go`
-- `cmd/e2e_test.go`
-
-After migration, remove `github.com/stretchr/testify` from `go.mod`.
-
-**Risk:** Low. Mechanical replacement, no logic changes.
-
-### 2. Multi-Model Support
-
-**Current state:**
-- `wire.go:119` hardcodes `&googlegenai.GoogleAI{}` as the only AI plugin
-- `wire.go:134` hardcodes `googlegenai.GoogleAIEmbedder()` as embedder
-- `prompts/koopa.prompt:2` hardcodes `model: googleai/gemini-2.5-flash`
-
-**Design:**
-
-#### 2a. Config Changes (`internal/config/config.go`)
-
-Add provider field:
-
-```go
-type Config struct {
- // AI configuration
- Provider string `mapstructure:"provider" json:"provider"` // "gemini" (default), "ollama", "openai"
- ModelName string `mapstructure:"model_name" json:"model_name"` // e.g. "gemini-2.5-flash", "llama3.3", "gpt-4o"
- // ... existing fields ...
-
- // Ollama configuration
- OllamaHost string `mapstructure:"ollama_host" json:"ollama_host"` // default: "http://localhost:11434"
-
- // OpenAI configuration (env: OPENAI_API_KEY)
- // No config field needed - key read by Genkit plugin from env
-}
-```
-
-Defaults:
-```yaml
-provider: gemini
-model_name: gemini-2.5-flash
-ollama_host: http://localhost:11434
-```
-
-Env overrides:
-```
-KOOPA_PROVIDER=ollama
-KOOPA_MODEL_NAME=llama3.3
-KOOPA_OLLAMA_HOST=http://localhost:11434
-OPENAI_API_KEY=sk-... # for openai provider
-```
-
-#### 2b. Dynamic Plugin Loading (`internal/app/wire.go`)
-
-Replace hardcoded GoogleAI with provider switch:
-
-```go
-func provideGenkit(ctx context.Context, cfg *config.Config, _ OtelShutdown, postgres *postgresql.Postgres) (*genkit.Genkit, error) {
- plugins := []genkit.Plugin{postgres} // PostgreSQL always needed for RAG
-
- switch cfg.Provider {
- case "gemini", "":
- plugins = append(plugins, &googlegenai.GoogleAI{})
- case "ollama":
- plugins = append(plugins, &ollama.Ollama{ServerAddress: cfg.OllamaHost})
- case "openai":
- plugins = append(plugins, &openai.OpenAI{})
- default:
- return nil, fmt.Errorf("unsupported provider: %q", cfg.Provider)
- }
-
- g := genkit.Init(ctx,
- genkit.WithPlugins(plugins...),
- genkit.WithPromptDir(promptDir),
- )
- return g, nil
-}
-```
-
-#### 2c. Dynamic Model in Dotprompt
-
-The dotprompt `model:` field determines which model Genkit uses. Two options:
-
-**Option A: Override at runtime (preferred)**
-Keep the dotprompt as-is with a default model. Override the model at execution time via the `ai.WithModel()` option. This requires confirming that Genkit's prompt execution API supports model override.
-
-**Option B: Generate dotprompt at startup**
-Write the dotprompt file dynamically based on config. This is fragile and not recommended.
-
-**Option C: Multiple dotprompt files**
-Create `prompts/koopa-gemini.prompt`, `prompts/koopa-ollama.prompt`, etc. Select based on config. Duplicates prompt content.
-
-**Recommendation: Option A** — override at execution time. The `ai.WithModel()` option in `ai.PromptExecuteOption` allows this. The dotprompt file remains the default/fallback.
-
-Implementation in `chat.go`:
-```go
-// In generateResponse(), add model override to opts
-if modelOverride != "" {
- model := genkit.LookupModel(c.g, modelOverride)
- if model != nil {
- opts = append(opts, ai.WithModel(model))
- }
-}
-```
-
-The model name format follows Genkit convention:
-- Gemini: `googleai/gemini-2.5-flash`
-- Ollama: `ollama/llama3.3`
-- OpenAI: `openai/gpt-4o`
-
-Config `model_name` stores the short name (`gemini-2.5-flash`), and the provider prefix is derived from `cfg.Provider`.
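-
-A minimal sketch of that derivation (the helper name and prefix map are illustrative, not existing code):
-
-```go
-// providerPrefixes maps the config provider value to the Genkit model prefix.
-var providerPrefixes = map[string]string{
-	"gemini": "googleai",
-	"ollama": "ollama",
-	"openai": "openai",
-}
-
-// fullModelRef builds e.g. "googleai/gemini-2.5-flash" from provider + short name.
-func fullModelRef(provider, modelName string) string {
-	prefix, ok := providerPrefixes[provider]
-	if !ok {
-		prefix = "googleai" // gemini is the default provider
-	}
-	return prefix + "/" + modelName
-}
-```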
-
-#### 2d. Embedder Selection
-
-Embedder is needed for RAG (pgvector). Options per provider:
-
-| Provider | Embedder | Model |
-|----------|----------|-------|
-| gemini | `googlegenai.GoogleAIEmbedder` | text-embedding-004 |
-| ollama | `ollama.Embedder` | nomic-embed-text |
-| openai | `openai.Embedder` | text-embedding-3-small |
-
-```go
-func provideEmbedder(g *genkit.Genkit, cfg *config.Config) ai.Embedder {
- switch cfg.Provider {
- case "ollama":
- return ollama.Embedder(g, cfg.EmbedderModel)
- case "openai":
- return openai.Embedder(g, cfg.EmbedderModel)
- default:
- return googlegenai.GoogleAIEmbedder(g, cfg.EmbedderModel)
- }
-}
-```
-
-**Note:** Switching embedder provider changes vector dimensions. Existing pgvector data from one embedder is incompatible with another. This is acceptable for v0.1.0 (users start fresh). Document this limitation.
-
-#### 2e. Validation
-
-- `gemini`: Require `GEMINI_API_KEY` env var
-- `ollama`: Require `ollama_host` reachable (health check at startup)
-- `openai`: Require `OPENAI_API_KEY` env var
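-
-A startup-validation sketch covering these checks (the helper name and the Ollama health-check path are assumptions, not existing code):
-
-```go
-func validateProviderConfig(cfg *config.Config) error {
-	switch cfg.Provider {
-	case "gemini", "":
-		if os.Getenv("GEMINI_API_KEY") == "" {
-			return errors.New("GEMINI_API_KEY is required for the gemini provider")
-		}
-	case "openai":
-		if os.Getenv("OPENAI_API_KEY") == "" {
-			return errors.New("OPENAI_API_KEY is required for the openai provider")
-		}
-	case "ollama":
-		resp, err := http.Get(cfg.OllamaHost + "/api/version") // assumed health endpoint
-		if err != nil {
-			return fmt.Errorf("ollama host %s unreachable: %w", cfg.OllamaHost, err)
-		}
-		resp.Body.Close()
-	default:
-		return fmt.Errorf("unsupported provider: %q", cfg.Provider)
-	}
-	return nil
-}
-```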
-
-### 3. MCP RAG Tools
-
-**Current state:** MCP server exposes File, System, Network tools but NOT Knowledge tools.
-
-**Change:** Register knowledge tools in MCP server alongside existing tools.
-
-**File:** `internal/mcp/server.go` (or wherever MCP tools are registered)
-
-Add the 3 knowledge tools:
-- `search_history` — search past conversations
-- `search_documents` — search indexed documents
-- `search_system_knowledge` — search system knowledge base
-
-**Risk:** Low. Tools already exist and are tested. Just wire them into MCP registration.
-
-### 4. README
-
-Structure:
-```
-# Koopa
-One-line description.
-
-## Features
-- Terminal AI assistant (TUI)
-- JSON REST API with SSE streaming
-- MCP server for IDE integration
-- 13 built-in tools (file, system, network, knowledge)
-- RAG with pgvector
-- Multi-model (Gemini, Ollama, OpenAI)
-
-## Quick Start
-### Prerequisites
-- Go 1.23+
-- PostgreSQL 17 with pgvector
-- Docker (for SearXNG + Redis)
-
-### Install
-go install github.com/koopa0/koopa@latest
-# or download from releases
-
-### Setup
-docker compose up -d
-export GEMINI_API_KEY="your-key"
-koopa
-
-### Using Ollama (local models)
-ollama pull llama3.3
-export KOOPA_PROVIDER=ollama
-export KOOPA_MODEL_NAME=llama3.3
-koopa
-
-## Configuration
-## Architecture
-## License
-```
-
-### 5. Release Automation (goreleaser)
-
-Create `.goreleaser.yml`:
-- Platforms: macOS (arm64, amd64), Linux (arm64, amd64)
-- Binary name: `koopa`
-- GitHub release with checksums
-- Homebrew tap (optional, can add later)
-
-### 6. Skipped Tests
-
-7 skipped tests identified:
-- 5 in `internal/api/session_test.go` — need PostgreSQL integration
-- 2 in `cmd/e2e_test.go` — need MCP test harness
-
-**Action:**
-- `session_test.go`: Convert to testcontainers-based integration tests
-- `e2e_test.go`: Assess feasibility. If MCP test harness is complex, document as known limitation.
-
-## Implementation Order
-
-| Step | Task | Depends On | Estimated Files |
-|------|------|------------|-----------------|
-| 1 | testify migration | none | 16 test files |
-| 2 | Remove testify from go.mod | step 1 | go.mod, go.sum |
-| 3 | Config: add Provider, OllamaHost fields | none | config.go, validation.go, config_test.go |
-| 4 | wire.go: dynamic plugin loading | step 3 | wire.go, wire_gen.go |
-| 5 | chat.go: model override at execution | step 4 | chat.go, chat_test.go |
-| 6 | Embedder: provider-based selection | step 4 | wire.go |
-| 7 | MCP: register knowledge tools | none | mcp/server.go or tools registration |
-| 8 | go.mod: add ollama, openai plugins | step 4 | go.mod, go.sum |
-| 9 | Skipped tests | none | session_test.go, e2e_test.go |
-| 10 | README | after all features | README.md |
-| 11 | goreleaser config | none | .goreleaser.yml |
-| 12 | Final verification | all steps | - |
-
-## Risks
-
-| Risk | Impact | Mitigation |
-|------|--------|------------|
-| Genkit Ollama plugin bugs | Tool calling fails | Test with popular models (llama3.3, qwen2.5) early |
-| Embedder dimension mismatch | RAG breaks when switching provider | Document: switching provider requires fresh DB |
-| Wire regeneration | Build breaks | Run `wire ./internal/app/` after wire.go changes |
-| Dotprompt model override | May not work as expected | Verify `ai.WithModel()` in prompt execution; fallback to Option C if needed |
-
-## Open Questions
-
-1. Should `config.yaml` support per-command provider override? (e.g., TUI uses Ollama but API uses Gemini)
- - **Recommendation:** No. Single provider per instance for v0.1.0.
-
-2. Should we add a `koopa config` CLI command to set provider/model interactively?
- - **Recommendation:** Defer. Environment variables and config file are sufficient.
-
-3. Default embedder model names for Ollama/OpenAI — need to verify exact model IDs.
- - **Action:** Check during implementation.
diff --git a/docs/proposals/002-v020-comprehensive-refactoring.md b/docs/proposals/002-v020-comprehensive-refactoring.md
deleted file mode 100644
index 3b078ad..0000000
--- a/docs/proposals/002-v020-comprehensive-refactoring.md
+++ /dev/null
@@ -1,409 +0,0 @@
-# Proposal 002: v0.2.0 Comprehensive Refactoring
-
-## Overview
-
-Based on a full-codebase comprehension pass (5 parallel agents analyzed every layer), this proposal defines a phased refactoring plan to bring Koopa to production quality.
-
-**Scope**: From `main.go` to every internal package. All 3 modes (CLI, API, MCP) preserved.
-**Priority**: CLI + HTTP API first, MCP in parallel where non-conflicting.
-**Framework**: Genkit as core AI orchestration, Bubble Tea v2 for TUI.
-**Storage**: PostgreSQL + pgvector retained as required dependency.
-
----
-
-## Phase 1: Critical Bug Fixes (Low Risk, Immediate)
-
-No architecture changes. Pure bug fixes with existing patterns.
-
-### 1.1 Fix `listenForStream()` Stack Overflow
-
-**File**: `internal/tui/commands.go`
-**Problem**: Recursive call on empty events can overflow stack.
-**Fix**: Replace recursion with `for` loop.
-
-```go
-// Before (recursive):
-default:
- return listenForStream(eventCh)()
-
-// After (iterative):
-for {
- event, ok := <-eventCh
- if !ok { return streamErrorMsg{err: ...} }
- if event.err != nil { return streamErrorMsg{...} }
- if event.done { return streamDoneMsg{...} }
- if event.text != "" { return streamTextMsg{...} }
- // empty event: loop instead of recurse
-}
-```
-
-### 1.2 Fix `truncateHistory` Break Logic
-
-**File**: `internal/agent/chat/tokens.go`
-**Problem**: Uses `break` instead of `continue`, drops old small messages after first large one.
-**Fix**: Change `break` to `continue`.
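-
-The intended shape of the loop after the fix (names are illustrative, not the actual tokens.go code):
-
-```go
-// keepWithinBudget walks history from newest to oldest and keeps messages that
-// still fit the token budget. Using continue lets older, smaller messages be
-// retained after one oversized message; break would drop everything older.
-func keepWithinBudget(msgs []string, budget int, estimate func(string) int) []string {
-	var kept []string
-	used := 0
-	for i := len(msgs) - 1; i >= 0; i-- {
-		cost := estimate(msgs[i])
-		if used+cost > budget {
-			continue // skip this message, but keep considering older ones
-		}
-		used += cost
-		kept = append(kept, msgs[i])
-	}
-	return kept
-}
-```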
-
-### 1.3 Remove Ghost `knowledge_store` Reference
-
-**File**: `internal/api/chat.go:59`
-**Problem**: UI message references a tool that doesn't exist.
-**Fix**: Remove the `knowledge_store` entry from toolDisplayMessages map.
-
-### 1.4 Remove Dead `NormalizeMaxHistoryMessages`
-
-**File**: `internal/config/validation.go`
-**Problem**: Function defined but never called.
-**Fix**: Delete the function (validation already handles range checks in `Validate()`).
-
-### 1.5 Fix RAG Non-Atomic UPSERT
-
-**File**: `internal/rag/system.go`
-**Problem**: Delete + Insert without transaction; delete failure silently swallowed.
-**Fix**: Wrap in PostgreSQL transaction.
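-
-A minimal sketch of the transactional upsert, assuming pgx v5 and illustrative table/column names (the real statements live in internal/rag/system.go):
-
-```go
-func upsertDocument(ctx context.Context, pool *pgxpool.Pool, id, body string) error {
-	tx, err := pool.Begin(ctx)
-	if err != nil {
-		return fmt.Errorf("begin upsert tx: %w", err)
-	}
-	defer func() { _ = tx.Rollback(ctx) }() // no-op after a successful Commit
-
-	if _, err := tx.Exec(ctx, `DELETE FROM system_knowledge WHERE id = $1`, id); err != nil {
-		return fmt.Errorf("delete old document: %w", err)
-	}
-	if _, err := tx.Exec(ctx, `INSERT INTO system_knowledge (id, body) VALUES ($1, $2)`, id, body); err != nil {
-		return fmt.Errorf("insert document: %w", err)
-	}
-	return tx.Commit(ctx)
-}
-```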
-
----
-
-## Phase 2: Remove Wire, Simplify DI (Medium Risk)
-
-Wire adds indirection and `wire_gen.go` is hand-maintained despite claiming to be generated.
-
-### 2.1 Replace Wire with Manual DI
-
-**Delete**: `internal/app/wire.go`, `internal/app/wire_gen.go`
-**Modify**: `internal/app/app.go`
-
-Replace Wire-generated `InitializeApp()` with explicit construction:
-
-```go
-func InitializeApp(ctx context.Context, cfg *config.Config) (*App, func(), error) {
- // 1. OTel setup
- otelShutdown, err := observability.Setup(ctx, cfg.Datadog)
-
- // 2. DB pool + migrations
- pool, err := connectDB(ctx, cfg)
-
- // 3. Genkit init
- g := initGenkit(ctx, cfg)
-
- // 4. Embedder
- embedder := lookupEmbedder(g, cfg)
-
- // 5. RAG components
- docStore, retriever := initRAG(ctx, g, pool, embedder, cfg)
-
- // 6. Session store
- sessionStore := session.NewStore(pool)
-
- // 7. Security validators
- pathValidator := security.NewPath(...)
-
- // 8. Tools
- tools := registerTools(g, pathValidator, ...)
-
- // 9. Construct App
- app := &App{...}
-
- cleanup := func() {
- pool.Close()
- otelShutdown(ctx)
- }
-
- return app, cleanup, nil
-}
-```
-
-### 2.2 Unify Tool Registration
-
-**Problem**: Tools are created twice — once for Genkit (wire.go), once for MCP (cmd/mcp.go).
-**Fix**: Create tools once, register to both Genkit and MCP from same instances.
-
-```go
-// internal/tools/registry.go (new file)
-type Registry struct {
- File *FileTools
- System *SystemTools
- Network *NetworkTools
- Knowledge *KnowledgeTools
-}
-
-func NewRegistry(pathValidator *security.Path, ...) *Registry { ... }
-
-// Register to Genkit
-func (r *Registry) RegisterGenkit(g *genkit.Genkit) ([]ai.Tool, error) { ... }
-```
-
-MCP server accepts `*tools.Registry` instead of individual tool structs.
-
-### 2.3 Refactor App Struct
-
-Split the God Object into focused structs:
-
-```go
-type App struct {
- Config *config.Config
- AI *AIComponents // Genkit, Embedder
- Storage *StorageComponents // DBPool, DocStore, Retriever, SessionStore
- Tools *tools.Registry
-
- // Lifecycle
- ctx context.Context
- cancel context.CancelFunc
- eg *errgroup.Group
-}
-
-type AIComponents struct {
- Genkit *genkit.Genkit
- Embedder ai.Embedder
-}
-
-type StorageComponents struct {
- Pool *pgxpool.Pool
- DocStore *postgresql.DocStore
- Retriever ai.Retriever
- SessionStore *session.Store
-}
-```
-
----
-
-## Phase 3: TUI Upgrade to Bubble Tea v2 (High Value)
-
-### 3.1 Dependency Upgrade
-
-```
-charm.land/bubbletea/v2
-charm.land/bubbles/v2
-charm.land/lipgloss/v2
-github.com/charmbracelet/glamour (for markdown rendering)
-```
-
-### 3.2 Rewrite TUI with v2 Patterns
-
-**Key changes**:
-- `View()` returns `tea.View` struct (declarative altscreen, cursor)
-- `tea.KeyPressMsg` replaces `tea.KeyMsg`
-- `tea.PasteMsg` for multi-line paste
-- Synchronized output (Mode 2026) reduces flicker during streaming
-
-**Streaming pattern** (replaces recursive listenForStream):
-```go
-func waitForStream(ch <-chan streamEvent) tea.Cmd {
- return func() tea.Msg {
- for {
- event, ok := <-ch
- if !ok { return streamDoneMsg{} }
- if event.err != nil { return streamErrorMsg{event.err} }
- if event.done { return streamDoneMsg{event.output} }
- if event.text != "" { return streamChunkMsg{event.text} }
- }
- }
-}
-```
-
-### 3.3 Add Glamour Markdown Rendering
-
-LLM responses contain markdown, but they are currently rendered as plain text; a rendering sketch follows the list below.
-- Use `glamour` for completed messages
-- Raw text during active streaming (partial markdown may not parse)
-- Re-render with glamour when stream completes
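-
-A sketch of that rendering path (style name and fallback behavior are assumptions):
-
-```go
-// renderMessage renders a completed assistant message with glamour; while a
-// response is still streaming, the raw text is returned unchanged.
-func renderMessage(markdown string, streaming bool) string {
-	if streaming {
-		return markdown // partial markdown may not parse cleanly
-	}
-	out, err := glamour.Render(markdown, "dark")
-	if err != nil {
-		return markdown // fall back to plain text on render errors
-	}
-	return out
-}
-```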
-
-### 3.4 Improve TUI Components
-
-- `bubbles/viewport` for scrollable chat history
-- `bubbles/textarea` for multi-line input (shift+enter for newlines in v2)
-- `bubbles/spinner` for "thinking" indicator
-- `bubbles/help` for keyboard shortcuts
-- `lipgloss` for message bubble styling (user vs assistant)
-
----
-
-## Phase 4: API Improvements + Angular UI Spec (Medium Risk)
-
-### 4.1 Fix Session Ownership
-
-Add session-cookie based authorization:
-```go
-func (sm *sessionManager) authorizeSession(r *http.Request, sessionID uuid.UUID) error {
- cookieSessionID := sm.getSessionID(r)
- if cookieSessionID != sessionID {
- return ErrForbidden
- }
- return nil
-}
-```
-
-### 4.2 Security Headers
-
-Add missing HSTS header:
-```go
-w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
-```
-
-### 4.3 Add API Versioning
-
-Prefix all routes with `/api/v1/`:
-```
-/api/v1/csrf-token
-/api/v1/sessions
-/api/v1/sessions/{id}
-/api/v1/sessions/{id}/messages
-/api/v1/chat
-/api/v1/chat/stream
-```
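-
-Registration sketch using Go 1.22+ ServeMux patterns (HTTP methods and handler names are illustrative):
-
-```go
-mux := http.NewServeMux()
-mux.HandleFunc("GET /api/v1/csrf-token", h.csrfToken)
-mux.HandleFunc("GET /api/v1/sessions", h.listSessions)
-mux.HandleFunc("GET /api/v1/sessions/{id}", h.getSession)
-mux.HandleFunc("GET /api/v1/sessions/{id}/messages", h.getSessionMessages)
-mux.HandleFunc("POST /api/v1/chat", h.chat)
-mux.HandleFunc("POST /api/v1/chat/stream", h.chatStream)
-```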
-
-### 4.4 Standardize Response Envelope
-
-```json
-{
- "data": { ... },
- "error": null
-}
-```
-
-or on error:
-```json
-{
- "data": null,
- "error": { "code": "not_found", "message": "session not found" }
-}
-```
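-
-A minimal Go shape for the envelope (type and helper names are illustrative):
-
-```go
-type apiError struct {
-	Code    string `json:"code"`
-	Message string `json:"message"`
-}
-
-type envelope struct {
-	Data  any       `json:"data"`
-	Error *apiError `json:"error"`
-}
-
-// writeJSON writes the standard envelope; pass apiErr=nil on success.
-func writeJSON(w http.ResponseWriter, status int, data any, apiErr *apiError) {
-	w.Header().Set("Content-Type", "application/json")
-	w.WriteHeader(status)
-	_ = json.NewEncoder(w).Encode(envelope{Data: data, Error: apiErr})
-}
-```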
-
-### 4.5 Angular UI Spec Document
-
-Create `docs/api-spec.md` with:
-- Complete endpoint documentation (request/response schemas)
-- SSE event types and payloads
-- Authentication flow (CSRF tokens, session cookies)
-- WebSocket upgrade path (future consideration)
-- Error code reference
-
----
-
-## Phase 5: RAG & Knowledge Improvements
-
-### 5.1 Update System Knowledge Content
-
-The current 6 documents are outdated. Update them to match the tools that are actually registered.
-
-### 5.2 Implement Knowledge Store Tool
-
-The UI references `knowledge_store` but it doesn't exist. Either:
-- Option A: Implement it (user can save conversations/documents to knowledge base)
-- Option B: Remove UI reference (done in Phase 1.3)
-
-Decision deferred to user.
-
-### 5.3 Improve Token Estimation
-
-Replace `rune/2` with proper estimation:
-```go
-func estimateTokens(text string) int {
-	// Prefer tiktoken-go (or Genkit's built-in counter) for accurate counting.
-	// Illustrative fallback heuristic: ~4 runes per token for Latin-script text;
-	// per-language detection should adjust this ratio for CJK-heavy input.
-	return (utf8.RuneCountInString(text) + 3) / 4 // needs "unicode/utf8"
-}
-```
-
-Or use Genkit's built-in token counter if available.
-
----
-
-## Phase 6: Security Hardening
-
-### 6.1 Fix Prompt Injection Homoglyph Bypass
-
-Add Unicode NFKD normalization before pattern matching:
-```go
-import "golang.org/x/text/unicode/normalize"
-
-normalized := normalize.NFKD.String(input)
-// Then apply regex patterns
-```
-
-### 6.2 Enforce SafeTransport Usage
-
-`SafeTransport()` exists but is never used in production. Wire it into network tools.
-
-### 6.3 Command Whitelist Gaps
-
-Add missing blocked patterns:
-- `npm install` (postinstall hooks execute code)
-- `go test -run` (executes test code)
-
----
-
-## Phase 7: Testing Overhaul
-
-### 7.1 Remove Pointless Tests
-
-Tests that only call `t.Skip()` with no body add no value.
-(Most already removed in Phase A cleanup.)
-
-### 7.2 Add Missing Critical Tests
-
-| Test | Module | Why |
-|------|--------|-----|
-| Stream empty event handling | tui | Validates Phase 1.1 fix |
-| Concurrent stream start/cancel | tui | Race condition coverage |
-| deepCopyMessages mutation | chat | Direct unit test (currently indirect) |
-| MCP JSON-RPC protocol | mcp | No protocol-level tests exist |
-| Session ownership check | api | Validates Phase 4.1 fix |
-
-### 7.3 TUI Testing with teatest
-
-Use `charmbracelet/x/exp/teatest` for golden file testing of TUI output.
-
----
-
-## Execution Order
-
-Phases can be partially parallelized:
-
-```
-Phase 1 (bug fixes) ─────────────────────────→ immediate
-Phase 2 (wire removal, DI) ──────────────────→ after Phase 1
-Phase 3 (TUI v2) ──────┐ after Phase 2
-Phase 4 (API + spec) ──┤ parallel after Phase 2
-Phase 5 (RAG) ─────────┘ after Phase 2
-Phase 6 (security) ──────────────────────────→ after Phase 4
-Phase 7 (testing) ───────────────────────────→ continuous
-```
-
-**CLI and API work don't conflict** — Phase 3 (TUI) and Phase 4 (API) can run in parallel after Phase 2 establishes the new DI structure.
-
----
-
-## Out of Scope
-
-- Angular Web UI implementation (separate project)
-- New AI providers beyond gemini/ollama/openai
-- Multi-user authentication system
-- WebSocket support (future consideration)
-- Kubernetes deployment manifests
-
----
-
-## Risk Assessment
-
-| Phase | Risk | Mitigation |
-|-------|------|------------|
-| 1 (bugs) | Low | Pure fixes, existing tests validate |
-| 2 (wire) | Medium | Manual DI is well-understood; test suite validates |
-| 3 (TUI v2) | High | v2 is RC.2, API may change; pin specific version |
-| 4 (API) | Medium | Existing tests cover endpoints; add ownership tests |
-| 5 (RAG) | Low | Isolated module, minimal cross-dependencies |
-| 6 (security) | Medium | Security changes need careful review |
-| 7 (testing) | Low | Only adds tests, no production code risk |
-
----
-
-## Success Criteria
-
-- [ ] `go build ./...` passes
-- [ ] `go vet ./...` passes
-- [ ] `golangci-lint run ./...` — 0 issues
-- [ ] `go test -race ./...` — all pass
-- [ ] All 4 red issues resolved
-- [ ] Wire removed, manual DI working
-- [ ] TUI running on Bubble Tea v2
-- [ ] API spec document complete
-- [ ] No regressions in existing functionality
diff --git a/docs/proposals/003-naming-conventions-and-dead-code.md b/docs/proposals/003-naming-conventions-and-dead-code.md
deleted file mode 100644
index 455415a..0000000
--- a/docs/proposals/003-naming-conventions-and-dead-code.md
+++ /dev/null
@@ -1,324 +0,0 @@
-# Proposal 003: Naming Conventions & Dead Code Cleanup
-
-## Overview
-
-Based on full-codebase go-reviewer and security-reviewer findings (post-Proposal 002), this proposal addresses Go naming convention violations and dead code in the session layer.
-
-**Scope**: `session.Store` method renames, dead code removal, `clientIP()` security fix.
-**Risk**: Medium — public API rename across multiple consumers, but zero interface breakage.
-**Tier**: 2 (existing feature modification, no new packages or types).
-
-### Reviewer Status
-
-- go-reviewer: 0 BLOCKING, 3 WARNING, 3 SUGGESTION — all addressed below
-- security-reviewer: 0 CRITICAL, 1 HIGH, 3 MEDIUM — all addressed below
-
----
-
-## Phase 1: Remove Dead Code (Low Risk)
-
-Remove 2 exported methods on `session.Store` that have **zero callers and zero tests**, plus their associated sentinel error.
-
-### 1.1 Delete `GetUserMessageBefore`
-
-**File**: `internal/session/store.go:611-658`
-**Evidence**: Comprehend agent confirmed 0 production callers, 0 test callers.
-**SQL**: `db/queries/sessions.sql:104-113` — `-- name: GetUserMessageBefore :one`
-
-Delete the Store method and remove the SQL query.
-
-### 1.2 Delete `GetMessageByID`
-
-**File**: `internal/session/store.go:660-676`
-**Evidence**: Comprehend agent confirmed 0 production callers, 0 test callers.
-**SQL**: `db/queries/sessions.sql:115-119` — `-- name: GetMessageByID :one`
-
-Same approach — delete Store method and SQL query.
-
-### 1.3 Delete `ErrMessageNotFound`
-
-**File**: `internal/session/errors.go:42-43`
-
-After removing the two methods above, `ErrMessageNotFound` has zero remaining references. Remove the sentinel error to avoid dead code.
-(go-reviewer SUGGESTION #1 — confirmed via grep: all 5 usages are within the two deleted methods)
-
-### 1.4 Update doc.go
-
-Remove any references to deleted methods in `internal/session/doc.go`.
-
-### 1.5 Regenerate sqlc
-
-```bash
-cd build/sql && sqlc generate
-```
-
-This removes the generated code for the two deleted SQL queries.
-
-### Verification
-
-```bash
-go build ./... # Confirms no broken references
-go vet ./... # Static analysis
-go test ./... # No test regression (methods had 0 tests)
-```
-
----
-
-## Phase 2: Rename Store Methods (Medium Risk)
-
-Remove `Get` prefix from 3 `session.Store` getter methods per `naming.md`:
-> "No `Get` prefix for getters"
-
-### 2.1 Rename SQL Queries
-
-**File**: `db/queries/sessions.sql`
-
-| Before | After | Line |
-|--------|-------|------|
-| `-- name: GetSession :one` | `-- name: Session :one` | 9 |
-| `-- name: GetMessages :many` | `-- name: Messages :many` | 53 |
-
-**Note**: `GetHistory` has no SQL query (it's composed in Go from `Session` + `Messages`).
-
-### 2.2 Regenerate sqlc
-
-```bash
-cd build/sql && sqlc generate
-```
-
-This produces cascading changes in `internal/sqlc/sessions.sql.go`:
-- `func (q *Queries) GetSession(...)` -> `func (q *Queries) Session(...)`
-- `func (q *Queries) GetMessages(...)` -> `func (q *Queries) Messages(...)`
-- `type GetMessagesParams` -> `type MessagesParams`
-
-### 2.3 Rename `session.Store` Methods
-
-**File**: `internal/session/store.go`
-
-| Before | After | Line |
-|--------|-------|------|
-| `func (s *Store) GetSession(...)` | `func (s *Store) Session(...)` | 90 |
-| `func (s *Store) GetMessages(...)` | `func (s *Store) Messages(...)` | 323 |
-| `func (s *Store) GetHistory(...)` | `func (s *Store) History(...)` | 404 |
-
-Internal self-calls within `store.go`:
-- Line 406: `s.GetSession(ctx, sessionID)` -> `s.Session(ctx, sessionID)`
-- Line 417: `s.GetMessages(ctx, sessionID, limit, 0)` -> `s.Messages(ctx, sessionID, limit, 0)`
-
-Internal sqlc calls:
-- Line 91: `s.queries.GetSession(ctx, sessionID)` -> `s.queries.Session(ctx, sessionID)`
-- Line 324: `s.queries.GetMessages(ctx, sqlc.GetMessagesParams{...})` -> `s.queries.Messages(ctx, sqlc.MessagesParams{...})`
-
-### 2.4 Update Production Callers
-
-| File | Line | Before | After |
-|------|------|--------|-------|
-| `cmd/cli.go` | 64 | `store.GetSession(ctx, *currentID)` | `store.Session(ctx, *currentID)` |
-| `internal/api/session.go` | 57 | `sm.store.GetSession(r.Context(), sessionID)` | `sm.store.Session(r.Context(), sessionID)` |
-| `internal/api/session.go` | 268 | `sm.store.GetSession(r.Context(), sessionID)` | `sm.store.Session(r.Context(), sessionID)` |
-| `internal/api/session.go` | 313 | `sm.store.GetSession(r.Context(), id)` | `sm.store.Session(r.Context(), id)` |
-| `internal/api/chat.go` | 321 | `h.sessions.store.GetSession(ctx, sessionUUID)` | `h.sessions.store.Session(ctx, sessionUUID)` |
-| `internal/api/session.go` | 340 | `sm.store.GetMessages(r.Context(), id, 100, 0)` | `sm.store.Messages(r.Context(), id, 100, 0)` |
-| `internal/agent/chat/chat.go` | 248 | `c.sessions.GetHistory(ctx, sessionID)` | `c.sessions.History(ctx, sessionID)` |
-
-**7 production call sites total.**
-
-### 2.5 Update Test Files
-
-Approximately ~40 test call sites across:
-- `internal/session/integration_test.go` (~30 sites)
-- `internal/session/benchmark_test.go` (~10 sites)
-
-All are direct `store.GetSession(...)`, `store.GetMessages(...)`, `store.GetHistory(...)` calls -> rename to `store.Session(...)`, `store.Messages(...)`, `store.History(...)`.
-
-**Note on `internal/api/session_test.go`**: 4 test functions contain `GetSession`/`GetSessionMessages` in their names (`TestGetSession_InvalidUUID`, `TestGetSession_MissingID`, `TestGetSessionMessages_InvalidUUID`, `TestGetSessionMessages_OwnershipDenied`). These test the HTTP handler methods (`sm.getSession`, `sm.getSessionMessages`) which are *unexported* and *not* being renamed. The test function names are **intentionally left unchanged** as they reference the handler, not the Store method.
-(go-reviewer SUGGESTION #2 — addressed)
-
-### 2.6 Update Documentation
-
-- `internal/session/doc.go` — update method name references
-- `internal/session/errors.go:34` — update doc comment example `store.GetSession(ctx, id)` -> `store.Session(ctx, id)`
-- `internal/agent/chat/doc.go` — update `GetHistory` reference
-
-(go-reviewer WARNING #1 — `errors.go` doc comment added to list)
-
-### Verification
-
-```bash
-cd build/sql && sqlc generate # Regenerate
-go build ./... # Compile check
-go vet ./... # Static analysis
-golangci-lint run ./... # 0 issues
-go test -race ./... # Full test suite
-```
-
----
-
-## Phase 3: Fix clientIP() Proxy Header Trust (Low-Medium Risk)
-
-### Problem
-
-`internal/api/middleware.go:363-383` — `clientIP()` unconditionally trusts `X-Forwarded-For` header. Any client can spoof this header to bypass rate limiting.
-(security-reviewer H2 — current vulnerability being fixed)
-
-### Fix
-
-Add a `TrustProxy` configuration option. When `false` (default), ignore proxy headers entirely. When `true`, prefer `X-Real-IP` over `X-Forwarded-For` to prevent spoofing.
-(security-reviewer H1 — switched to X-Real-IP priority to avoid leftmost-XFF spoofing)
-
-#### 3.1 Add config chain
-
-**File**: `internal/config/config.go`
-- Add `TrustProxy bool` field to config struct
-- Add `viper.SetDefault("trust_proxy", false)` in `setDefaults`
-- Add `mustBind("trust_proxy", "KOOPA_TRUST_PROXY")` in `bindEnvVariables`
-
-(security-reviewer M2 — full config wiring)
-
-**File**: `config.example.yaml`
-- Add `trust_proxy: false` with comment explaining when to set `true`
-
-(security-reviewer M1 — migration note for existing proxy deployments)
-
-#### 3.2 Wire through ServerConfig
-
-**File**: `internal/api/server.go`
-
-```go
-type ServerConfig struct {
- // ... existing fields ...
- TrustProxy bool // Trust X-Real-IP/X-Forwarded-For headers (set true behind reverse proxy)
-}
-```
-
-**File**: `cmd/serve.go` — pass `cfg.TrustProxy` into `api.ServerConfig`
-
-#### 3.3 Update clientIP function
-
-**File**: `internal/api/middleware.go`
-
-```go
-// clientIP extracts the client IP from the request.
-// If trustProxy is true, checks X-Real-IP first (single-valued, set by proxy),
-// then X-Forwarded-For. Validates extracted IP with net.ParseIP.
-// Otherwise, uses RemoteAddr only (safe default for direct exposure).
-func clientIP(r *http.Request, trustProxy bool) string {
- if trustProxy {
- // Prefer X-Real-IP: single value set by the reverse proxy, not spoofable
- if xri := r.Header.Get("X-Real-IP"); xri != "" {
- ip := strings.TrimSpace(xri)
- if net.ParseIP(ip) != nil {
- return ip
- }
- }
- // Fallback: X-Forwarded-For (first entry, client-provided — less trustworthy)
- if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
- raw, _, _ := strings.Cut(xff, ",")
- ip := strings.TrimSpace(raw)
- if net.ParseIP(ip) != nil {
- return ip
- }
- }
- }
-
- ip, _, err := net.SplitHostPort(r.RemoteAddr)
- if err != nil {
- return r.RemoteAddr
- }
- return ip
-}
-```
-
-Key changes from original proposal (per security review):
-1. **X-Real-IP checked first** — single value set by trusted proxy, not appendable by client (H1 fix)
-2. **`net.ParseIP` validation** — rejects garbage values, falls back to RemoteAddr (M3 fix)
-3. **`trustProxy` parameter** — gated by config, default `false` (existing fix)
-
-(security-reviewer H1 + M3 — addressed)
-
-#### 3.4 Update rateLimitMiddleware
-
-**File**: `internal/api/middleware.go`
-
-Pass `trustProxy` through closure to `clientIP`:
-
-```go
-func rateLimitMiddleware(rl *rateLimiter, trustProxy bool, logger *slog.Logger) func(http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- ip := clientIP(r, trustProxy)
- // ... rest unchanged
- })
- }
-}
-```
-
-Update call site in `server.go:83` to pass `cfg.TrustProxy`.
-
-#### 3.5 Update tests
-
-**File**: `internal/api/middleware_test.go`
-
-- Update all 6 `clientIP(r)` calls to `clientIP(r, true)` for existing proxy-trust test cases
-- Add new test cases for `clientIP(r, false)` verifying proxy headers are ignored
-- Add test case for invalid IP in proxy header (falls back to RemoteAddr)
-
-(go-reviewer WARNING #2 + SUGGESTION #3 — middleware_test.go added to file list)
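-
-One of the new cases could look like this (test name illustrative):
-
-```go
-func TestClientIP_UntrustedProxyIgnoresHeaders(t *testing.T) {
-	t.Parallel()
-
-	r := httptest.NewRequest(http.MethodGet, "/", nil)
-	r.RemoteAddr = "10.0.0.7:52114"
-	r.Header.Set("X-Forwarded-For", "1.2.3.4")
-	r.Header.Set("X-Real-IP", "5.6.7.8")
-
-	if got, want := clientIP(r, false), "10.0.0.7"; got != want {
-		t.Errorf("clientIP(trustProxy=false) = %q, want %q", got, want)
-	}
-}
-```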
-
-### Verification
-
-```bash
-go build ./...
-go vet ./...
-golangci-lint run ./...
-go test -race ./...
-```
-
----
-
-## Out of Scope
-
-- RAG non-atomic upsert — documented, low-risk, runs only at startup
-- Secret masking improvements — cosmetic, no security impact
-- sqlc `emit_interface` change — unnecessary (no consumers need an interface)
-- `GetMaxSequenceNumber` rename — not a getter on Store, it's an internal sqlc query name
-- `TestGetSession_*` test function renames in `api/session_test.go` — test HTTP handlers, not Store
-
----
-
-## Implementation Order
-
-| Step | Phase | Files Modified | Risk |
-|------|-------|---------------|------|
-| 1 | Phase 1 | store.go, sessions.sql, errors.go, doc.go | Low |
-| 2 | sqlc regenerate | `cd build/sql && sqlc generate` | None |
-| 3 | Phase 2 | sessions.sql, store.go, cli.go, session.go, chat.go, doc.go, errors.go, tests | Medium |
-| 4 | sqlc regenerate | `cd build/sql && sqlc generate` | None |
-| 5 | Phase 3 | config.go, server.go, serve.go, middleware.go, middleware_test.go, config.example.yaml | Low-Medium |
-| 6 | Full verification | — | — |
-
-Each phase is independently verifiable. Phases 1+2 can be combined into a single sqlc regeneration if committed together.
-(go-reviewer SUGGESTION #4 — sqlc path made consistent, combination noted)
-
----
-
-## Risk Assessment
-
-| Risk | Mitigation |
-|------|-----------|
-| Missed call site breaks build | `go build ./...` catches immediately |
-| sqlc param type rename missed | Compiler error on `sqlc.GetMessagesParams` |
-| Test name references old method | Find-and-replace + `go test` catches |
-| `clientIP` signature change breaks callers | Only 1 production call site + 6 test calls; compiler catches |
-| TrustProxy default=false behind proxy | Documented migration: set `KOOPA_TRUST_PROXY=true` or `trust_proxy: true` |
-
----
-
-## Estimated Edit Count
-
-| Phase | Files | Edit Sites |
-|-------|-------|-----------|
-| Phase 1 (dead code) | 4 | ~7 |
-| Phase 2 (rename) | 9 | ~50 |
-| Phase 3 (clientIP) | 6 | ~15 |
-| **Total** | **~15** | **~72** |
diff --git a/go.mod b/go.mod
index 4e9ed08..554d324 100644
--- a/go.mod
+++ b/go.mod
@@ -24,8 +24,8 @@ require (
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0
go.opentelemetry.io/otel/sdk v1.38.0
go.uber.org/goleak v1.3.0
- golang.org/x/sync v0.18.0
golang.org/x/time v0.14.0
+ google.golang.org/genai v1.41.0
)
require (
@@ -170,12 +170,12 @@ require (
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/oauth2 v0.33.0 // indirect
+ golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/term v0.37.0 // indirect
golang.org/x/text v0.31.0 // indirect
google.golang.org/api v0.252.0 // indirect
google.golang.org/appengine v1.6.8 // indirect
- google.golang.org/genai v1.41.0 // indirect
google.golang.org/genproto v0.0.0-20251014184007-4626949a642f // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect
diff --git a/internal/agent/chat/chat_test.go b/internal/agent/chat/chat_test.go
deleted file mode 100644
index 8592aea..0000000
--- a/internal/agent/chat/chat_test.go
+++ /dev/null
@@ -1,565 +0,0 @@
-package chat
-
-import (
- "context"
- "errors"
- "log/slog"
- "strings"
- "testing"
-
- "github.com/firebase/genkit/go/ai"
-)
-
-// TestNew_ValidationErrors tests constructor validation
-func TestNew_ValidationErrors(t *testing.T) {
- t.Parallel()
-
- tests := []struct {
- name string
- cfg Config
- errContains string
- }{
- {
- name: "nil genkit",
- cfg: Config{},
- errContains: "genkit instance is required",
- },
- {
- name: "nil retriever",
- cfg: Config{
- Genkit: nil, // Still nil, so we'll get Genkit error first
- },
- errContains: "genkit instance is required",
- },
- {
- name: "nil logger - requires all previous deps",
- cfg: Config{
- // Missing Genkit
- },
- errContains: "genkit instance is required",
- },
- {
- name: "empty tools - requires all previous deps",
- cfg: Config{
- // Missing Genkit
- },
- errContains: "genkit instance is required",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- t.Parallel()
- _, err := New(tt.cfg)
- if err == nil {
- t.Fatal("New() expected error, got nil")
- }
- if !strings.Contains(err.Error(), tt.errContains) {
- t.Errorf("New() error = %q, want to contain %q", err.Error(), tt.errContains)
- }
- })
- }
-}
-
-// TestConstants tests package constants
-func TestConstants(t *testing.T) {
- t.Parallel()
-
- t.Run("Name constant", func(t *testing.T) {
- t.Parallel()
- if Name != "chat" {
- t.Errorf("Name = %q, want %q", Name, "chat")
- }
- })
-
- t.Run("Description is not empty", func(t *testing.T) {
- t.Parallel()
- if Description == "" {
- t.Error("Description is empty, want non-empty")
- }
- })
-
- t.Run("KoopaPromptName is set", func(t *testing.T) {
- t.Parallel()
- if KoopaPromptName != "koopa" {
- t.Errorf("KoopaPromptName = %q, want %q", KoopaPromptName, "koopa")
- }
- })
-}
-
-// TestStreamCallback_Type tests the StreamCallback type definition
-func TestStreamCallback_Type(t *testing.T) {
- t.Parallel()
-
- t.Run("nil callback is valid", func(t *testing.T) {
- t.Parallel()
- var callback StreamCallback
- if callback != nil {
- t.Errorf("nil callback = %v, want nil", callback)
- }
- })
-
- t.Run("callback can be assigned", func(t *testing.T) {
- t.Parallel()
- called := false
- callback := StreamCallback(func(_ context.Context, _ *ai.ModelResponseChunk) error {
- called = true
- return nil
- })
- if callback == nil {
- t.Fatal("callback is nil, want non-nil")
- }
- err := callback(context.Background(), nil)
- if err != nil {
- t.Errorf("callback() unexpected error: %v", err)
- }
- if !called {
- t.Error("callback was not called")
- }
- })
-
- t.Run("callback can return error", func(t *testing.T) {
- t.Parallel()
- expectedErr := errors.New("test error")
- callback := StreamCallback(func(_ context.Context, _ *ai.ModelResponseChunk) error {
- return expectedErr
- })
- err := callback(context.Background(), nil)
- if !errors.Is(err, expectedErr) {
- t.Errorf("callback() = %v, want %v", err, expectedErr)
- }
- })
-}
-
-// TestConfig_Structure tests the Config struct
-func TestConfig_Structure(t *testing.T) {
- t.Parallel()
-
- t.Run("zero value has nil fields", func(t *testing.T) {
- t.Parallel()
- var cfg Config
- if cfg.Genkit != nil {
- t.Errorf("cfg.Genkit = %v, want nil", cfg.Genkit)
- }
- if cfg.Retriever != nil {
- t.Errorf("cfg.Retriever = %v, want nil", cfg.Retriever)
- }
- if cfg.SessionStore != nil {
- t.Errorf("cfg.SessionStore = %v, want nil", cfg.SessionStore)
- }
- if cfg.Logger != nil {
- t.Errorf("cfg.Logger = %v, want nil", cfg.Logger)
- }
- if cfg.Tools != nil {
- t.Errorf("cfg.Tools = %v, want nil", cfg.Tools)
- }
- })
-}
-
-// TestChat_RetrieveRAGContext_SkipsWhenTopKZero tests RAG context retrieval
-func TestChat_RetrieveRAGContext_SkipsWhenTopKZero(t *testing.T) {
- t.Parallel()
-
- t.Run("returns nil when topK is zero", func(t *testing.T) {
- t.Parallel()
- c := &Chat{
- ragTopK: 0,
- logger: slog.Default(),
- }
- docs := c.retrieveRAGContext(context.Background(), "test query")
- if docs != nil {
- t.Errorf("retrieveRAGContext() = %v, want nil", docs)
- }
- })
-
- t.Run("returns nil when topK is negative", func(t *testing.T) {
- t.Parallel()
- c := &Chat{
- ragTopK: -1,
- logger: slog.Default(),
- }
- docs := c.retrieveRAGContext(context.Background(), "test query")
- if docs != nil {
- t.Errorf("retrieveRAGContext() = %v, want nil", docs)
- }
- })
-}
-
-// =============================================================================
-// Edge Case Tests for Real Scenarios
-// =============================================================================
-
-// TestChat_EmptyResponseHandling tests that empty model responses are handled gracefully.
-func TestChat_EmptyResponseHandling(t *testing.T) {
- t.Parallel()
-
- t.Run("empty string triggers fallback", func(t *testing.T) {
- t.Parallel()
- // Test the logic of empty response detection
- responseText := ""
- if strings.TrimSpace(responseText) == "" {
- responseText = FallbackResponseMessage
- }
- if !strings.Contains(responseText, "apologize") {
- t.Errorf("fallback response = %q, want to contain %q", responseText, "apologize")
- }
- if responseText == "" {
- t.Error("fallback response is empty, want non-empty")
- }
- })
-
- t.Run("whitespace-only triggers fallback", func(t *testing.T) {
- t.Parallel()
- responseText := " \n\t "
- if strings.TrimSpace(responseText) == "" {
- responseText = FallbackResponseMessage
- }
- if !strings.Contains(responseText, "apologize") {
- t.Errorf("fallback response = %q, want to contain %q", responseText, "apologize")
- }
- })
-
- t.Run("valid response is preserved", func(t *testing.T) {
- t.Parallel()
- responseText := "Hello, I'm here to help!"
- originalText := responseText
- if strings.TrimSpace(responseText) == "" {
- responseText = FallbackResponseMessage
- }
- if responseText != originalText {
- t.Errorf("responseText = %q, want %q", responseText, originalText)
- }
- })
-}
-
-// TestChat_ContextCancellation tests graceful handling of context cancellation.
-func TestChat_ContextCancellation(t *testing.T) {
- t.Parallel()
-
- t.Run("canceled context is detected", func(t *testing.T) {
- t.Parallel()
- ctx, cancel := context.WithCancel(context.Background())
- cancel() // Cancel immediately
-
- // Verify context is canceled
- if !errors.Is(ctx.Err(), context.Canceled) {
- t.Errorf("ctx.Err() = %v, want context.Canceled", ctx.Err())
- }
- })
-
- t.Run("deadline exceeded is different from canceled", func(t *testing.T) {
- t.Parallel()
- ctx, cancel := context.WithTimeout(context.Background(), 0)
- defer cancel()
-
- // Wait for timeout
- <-ctx.Done()
-
- // DeadlineExceeded is different from Canceled
- if !errors.Is(ctx.Err(), context.DeadlineExceeded) {
- t.Errorf("ctx.Err() = %v, want context.DeadlineExceeded", ctx.Err())
- }
- if errors.Is(ctx.Err(), context.Canceled) {
- t.Errorf("ctx.Err() = context.Canceled, want context.DeadlineExceeded")
- }
- })
-}
-
-// TestChat_MaxTurnsProtection tests that conversation doesn't loop infinitely.
-// Safety: Prevents runaway agent loops that could exhaust resources.
-func TestChat_MaxTurnsProtection(t *testing.T) {
- t.Parallel()
-
- t.Run("max turns concept validation", func(t *testing.T) {
- t.Parallel()
- // In a real agent loop, we would track turns
- maxTurns := 10
- currentTurn := 0
-
- // Simulate turn counting
- for i := 0; i < 100; i++ {
- currentTurn++
- if currentTurn >= maxTurns {
- break
- }
- }
-
- if currentTurn != maxTurns {
- t.Errorf("currentTurn = %d, want %d (should stop at max turns)", currentTurn, maxTurns)
- }
- })
-}
-
-// TestChat_ToolFailureRecovery tests that the agent can continue after tool failures.
-// Resilience: Agent should gracefully handle tool execution errors.
-func TestChat_ToolFailureRecovery(t *testing.T) {
- t.Parallel()
-
- t.Run("tool error is wrapped", func(t *testing.T) {
- t.Parallel()
- toolErr := errors.New("tool failed: file not found")
- wrappedErr := errors.New("tool execution failed: " + toolErr.Error())
- if !strings.Contains(wrappedErr.Error(), "tool execution failed") {
- t.Errorf("wrappedErr = %q, want to contain %q", wrappedErr.Error(), "tool execution failed")
- }
- if !strings.Contains(wrappedErr.Error(), "file not found") {
- t.Errorf("wrappedErr = %q, want to contain %q", wrappedErr.Error(), "file not found")
- }
- })
-
- t.Run("tool error does not crash agent", func(t *testing.T) {
- t.Parallel()
- // Simulate error handling that doesn't propagate
- var lastErr error
- handleToolError := func(err error) {
- lastErr = err // Log but don't crash
- }
-
- handleToolError(errors.New("tool failed"))
- if lastErr == nil {
- t.Error("lastErr is nil, want non-nil")
- }
- // Agent continues running
- })
-}
-
-// =============================================================================
-// deepCopyMessages / deepCopyPart / shallowCopyMap Tests
-// =============================================================================
-
-func TestDeepCopyMessages_NilInput(t *testing.T) {
- t.Parallel()
- got := deepCopyMessages(nil)
- if got != nil {
- t.Errorf("deepCopyMessages(nil) = %v, want nil", got)
- }
-}
-
-func TestDeepCopyMessages_EmptySlice(t *testing.T) {
- t.Parallel()
- got := deepCopyMessages([]*ai.Message{})
- if got == nil {
- t.Fatal("deepCopyMessages(empty) = nil, want non-nil empty slice")
- }
- if len(got) != 0 {
- t.Errorf("deepCopyMessages(empty) len = %d, want 0", len(got))
- }
-}
-
-func TestDeepCopyMessages_MutateOriginalText(t *testing.T) {
- t.Parallel()
-
- original := []*ai.Message{
- ai.NewUserMessage(ai.NewTextPart("hello world")),
- }
-
- copied := deepCopyMessages(original)
-
- // Mutate the original message's content slice
- original[0].Content[0].Text = "MUTATED"
-
- if copied[0].Content[0].Text != "hello world" {
- t.Errorf("deepCopyMessages() copy was affected by original mutation: got %q, want %q",
- copied[0].Content[0].Text, "hello world")
- }
-}
-
-func TestDeepCopyMessages_MutateOriginalContentSlice(t *testing.T) {
- t.Parallel()
-
- original := []*ai.Message{
- ai.NewUserMessage(ai.NewTextPart("first"), ai.NewTextPart("second")),
- }
-
- copied := deepCopyMessages(original)
-
- // Append to original's content slice — should not affect copy
- original[0].Content = append(original[0].Content, ai.NewTextPart("third"))
-
- if len(copied[0].Content) != 2 {
- t.Errorf("deepCopyMessages() copy content len = %d, want 2", len(copied[0].Content))
- }
-}
-
-func TestDeepCopyMessages_PreservesRole(t *testing.T) {
- t.Parallel()
-
- original := []*ai.Message{
- ai.NewUserMessage(ai.NewTextPart("q")),
- ai.NewModelMessage(ai.NewTextPart("a")),
- }
-
- copied := deepCopyMessages(original)
-
- if copied[0].Role != ai.RoleUser {
- t.Errorf("deepCopyMessages()[0].Role = %q, want %q", copied[0].Role, ai.RoleUser)
- }
- if copied[1].Role != ai.RoleModel {
- t.Errorf("deepCopyMessages()[1].Role = %q, want %q", copied[1].Role, ai.RoleModel)
- }
-}
-
-func TestDeepCopyMessages_Metadata(t *testing.T) {
- t.Parallel()
-
- original := []*ai.Message{{
- Role: ai.RoleUser,
- Content: []*ai.Part{ai.NewTextPart("test")},
- Metadata: map[string]any{"key": "value"},
- }}
-
- copied := deepCopyMessages(original)
-
- // Mutate original metadata
- original[0].Metadata["key"] = "MUTATED"
-
- if copied[0].Metadata["key"] != "value" {
- t.Errorf("deepCopyMessages() metadata was affected by mutation: got %q, want %q",
- copied[0].Metadata["key"], "value")
- }
-}
-
-func TestDeepCopyPart_NilInput(t *testing.T) {
- t.Parallel()
- got := deepCopyPart(nil)
- if got != nil {
- t.Errorf("deepCopyPart(nil) = %v, want nil", got)
- }
-}
-
-func TestDeepCopyPart_TextPart(t *testing.T) {
- t.Parallel()
-
- original := ai.NewTextPart("hello")
- copied := deepCopyPart(original)
-
- original.Text = "MUTATED"
-
- if copied.Text != "hello" {
- t.Errorf("deepCopyPart() text affected by mutation: got %q, want %q", copied.Text, "hello")
- }
-}
-
-func TestDeepCopyPart_ToolRequest(t *testing.T) {
- t.Parallel()
-
- original := &ai.Part{
- Kind: ai.PartToolRequest,
- ToolRequest: &ai.ToolRequest{
- Name: "read_file",
- Input: map[string]any{"path": "/tmp/test"},
- },
- }
-
- copied := deepCopyPart(original)
-
- // Mutate original ToolRequest name
- original.ToolRequest.Name = "MUTATED"
-
- if copied.ToolRequest.Name != "read_file" {
- t.Errorf("deepCopyPart() ToolRequest.Name affected by mutation: got %q, want %q",
- copied.ToolRequest.Name, "read_file")
- }
-}
-
-func TestDeepCopyPart_ToolResponse(t *testing.T) {
- t.Parallel()
-
- original := &ai.Part{
- Kind: ai.PartToolResponse,
- ToolResponse: &ai.ToolResponse{
- Name: "read_file",
- Output: "file contents",
- },
- }
-
- copied := deepCopyPart(original)
-
- original.ToolResponse.Name = "MUTATED"
-
- if copied.ToolResponse.Name != "read_file" {
- t.Errorf("deepCopyPart() ToolResponse.Name affected by mutation: got %q, want %q",
- copied.ToolResponse.Name, "read_file")
- }
-}
-
-func TestDeepCopyPart_Resource(t *testing.T) {
- t.Parallel()
-
- original := &ai.Part{
- Kind: ai.PartMedia,
- Resource: &ai.ResourcePart{Uri: "https://example.com/image.png"},
- }
-
- copied := deepCopyPart(original)
-
- original.Resource.Uri = "MUTATED"
-
- if copied.Resource.Uri != "https://example.com/image.png" {
- t.Errorf("deepCopyPart() Resource.Uri affected by mutation: got %q, want %q",
- copied.Resource.Uri, "https://example.com/image.png")
- }
-}
-
-func TestDeepCopyPart_PartMetadata(t *testing.T) {
- t.Parallel()
-
- original := &ai.Part{
- Kind: ai.PartText,
- Text: "test",
- Custom: map[string]any{"c": "custom"},
- Metadata: map[string]any{"m": "meta"},
- }
-
- copied := deepCopyPart(original)
-
- original.Custom["c"] = "MUTATED"
- original.Metadata["m"] = "MUTATED"
-
- if copied.Custom["c"] != "custom" {
- t.Errorf("deepCopyPart() Custom map affected: got %q, want %q", copied.Custom["c"], "custom")
- }
- if copied.Metadata["m"] != "meta" {
- t.Errorf("deepCopyPart() Metadata map affected: got %q, want %q", copied.Metadata["m"], "meta")
- }
-}
-
-func TestShallowCopyMap_NilInput(t *testing.T) {
- t.Parallel()
- got := shallowCopyMap(nil)
- if got != nil {
- t.Errorf("shallowCopyMap(nil) = %v, want nil", got)
- }
-}
-
-func TestShallowCopyMap_IndependentKeys(t *testing.T) {
- t.Parallel()
-
- original := map[string]any{"a": "1", "b": "2"}
- copied := shallowCopyMap(original)
-
- // Add new key to original
- original["c"] = "3"
-
- if _, ok := copied["c"]; ok {
- t.Error("shallowCopyMap() new key in original appeared in copy")
- }
- if len(copied) != 2 {
- t.Errorf("shallowCopyMap() copy len = %d, want 2", len(copied))
- }
-}
-
-func TestShallowCopyMap_MutateValue(t *testing.T) {
- t.Parallel()
-
- original := map[string]any{"key": "value"}
- copied := shallowCopyMap(original)
-
- // Overwrite original value
- original["key"] = "MUTATED"
-
- if copied["key"] != "value" {
- t.Errorf("shallowCopyMap() value affected by mutation: got %q, want %q",
- copied["key"], "value")
- }
-}
diff --git a/internal/agent/chat/flow_test.go b/internal/agent/chat/flow_test.go
deleted file mode 100644
index 644288e..0000000
--- a/internal/agent/chat/flow_test.go
+++ /dev/null
@@ -1,154 +0,0 @@
-package chat
-
-import (
- "errors"
- "testing"
-
- "github.com/koopa0/koopa/internal/agent"
-)
-
-// TestFlowName tests the FlowName constant
-func TestFlowName(t *testing.T) {
- t.Parallel()
- if FlowName != "koopa/chat" {
- t.Errorf("FlowName = %q, want %q", FlowName, "koopa/chat")
- }
- if FlowName == "" {
- t.Error("FlowName is empty, want non-empty")
- }
-}
-
-// TestStreamChunk_Structure tests the StreamChunk type
-func TestStreamChunk_Structure(t *testing.T) {
- t.Parallel()
-
- t.Run("zero value has empty text", func(t *testing.T) {
- t.Parallel()
- var chunk StreamChunk
- if chunk.Text != "" {
- t.Errorf("chunk.Text = %q, want %q", chunk.Text, "")
- }
- })
-
- t.Run("can set text", func(t *testing.T) {
- t.Parallel()
- chunk := StreamChunk{Text: "Hello, World!"}
- if chunk.Text != "Hello, World!" {
- t.Errorf("chunk.Text = %q, want %q", chunk.Text, "Hello, World!")
- }
- })
-
- t.Run("can hold unicode text", func(t *testing.T) {
- t.Parallel()
- chunk := StreamChunk{Text: "你好世界 🌍"}
- if chunk.Text != "你好世界 🌍" {
- t.Errorf("chunk.Text = %q, want %q", chunk.Text, "你好世界 🌍")
- }
- })
-}
-
-// TestInput_Structure tests the Input type
-func TestInput_Structure(t *testing.T) {
- t.Parallel()
-
- t.Run("zero value has empty fields", func(t *testing.T) {
- t.Parallel()
- var input Input
- if input.Query != "" {
- t.Errorf("input.Query = %q, want %q", input.Query, "")
- }
- if input.SessionID != "" {
- t.Errorf("input.SessionID = %q, want %q", input.SessionID, "")
- }
- })
-
- t.Run("can set all fields", func(t *testing.T) {
- t.Parallel()
- input := Input{
- Query: "What is the weather?",
- SessionID: "test-session-123",
- }
- if input.Query != "What is the weather?" {
- t.Errorf("input.Query = %q, want %q", input.Query, "What is the weather?")
- }
- if input.SessionID != "test-session-123" {
- t.Errorf("input.SessionID = %q, want %q", input.SessionID, "test-session-123")
- }
- })
-}
-
-// TestOutput_Structure tests the Output type
-func TestOutput_Structure(t *testing.T) {
- t.Parallel()
-
- t.Run("zero value has empty fields", func(t *testing.T) {
- t.Parallel()
- var output Output
- if output.Response != "" {
- t.Errorf("output.Response = %q, want %q", output.Response, "")
- }
- if output.SessionID != "" {
- t.Errorf("output.SessionID = %q, want %q", output.SessionID, "")
- }
- })
-
- t.Run("can set response and session", func(t *testing.T) {
- t.Parallel()
- output := Output{
- Response: "The weather is sunny.",
- SessionID: "test-session-123",
- }
- if output.Response != "The weather is sunny." {
- t.Errorf("output.Response = %q, want %q", output.Response, "The weather is sunny.")
- }
- if output.SessionID != "test-session-123" {
- t.Errorf("output.SessionID = %q, want %q", output.SessionID, "test-session-123")
- }
- })
-}
-
-// TestSentinelErrors_CanBeChecked tests that sentinel errors work correctly with errors.Is
-func TestSentinelErrors_CanBeChecked(t *testing.T) {
- t.Parallel()
-
- tests := []struct {
- name string
- err error
- sentinel error
- }{
- {"ErrInvalidSession", agent.ErrInvalidSession, agent.ErrInvalidSession},
- {"ErrExecutionFailed", agent.ErrExecutionFailed, agent.ErrExecutionFailed},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- t.Parallel()
- if !errors.Is(tt.err, tt.sentinel) {
- t.Errorf("errors.Is(%v, %v) = false, want true", tt.err, tt.sentinel)
- }
- })
- }
-}
-
-// TestWrappedErrors_PreserveSentinel tests that wrapped errors preserve sentinel checking
-func TestWrappedErrors_PreserveSentinel(t *testing.T) {
- t.Parallel()
-
- t.Run("wrapped invalid session error", func(t *testing.T) {
- t.Parallel()
- err := errors.New("original error")
- wrapped := errors.Join(agent.ErrInvalidSession, err)
- if !errors.Is(wrapped, agent.ErrInvalidSession) {
- t.Errorf("errors.Is(wrapped, ErrInvalidSession) = false, want true")
- }
- })
-
- t.Run("wrapped execution failed error", func(t *testing.T) {
- t.Parallel()
- err := errors.New("LLM timeout")
- wrapped := errors.Join(agent.ErrExecutionFailed, err)
- if !errors.Is(wrapped, agent.ErrExecutionFailed) {
- t.Errorf("errors.Is(wrapped, ErrExecutionFailed) = false, want true")
- }
- })
-}
diff --git a/internal/agent/chat/retry_test.go b/internal/agent/chat/retry_test.go
deleted file mode 100644
index 94a7036..0000000
--- a/internal/agent/chat/retry_test.go
+++ /dev/null
@@ -1,196 +0,0 @@
-package chat
-
-import (
- "errors"
- "testing"
-)
-
-func TestDefaultRetryConfig(t *testing.T) {
- t.Parallel()
-
- cfg := DefaultRetryConfig()
-
- if cfg.MaxRetries <= 0 {
- t.Errorf("MaxRetries should be positive, got %d", cfg.MaxRetries)
- }
- if cfg.InitialInterval <= 0 {
- t.Errorf("InitialInterval should be positive, got %v", cfg.InitialInterval)
- }
- if cfg.MaxInterval <= 0 {
- t.Errorf("MaxInterval should be positive, got %v", cfg.MaxInterval)
- }
- if cfg.MaxInterval < cfg.InitialInterval {
- t.Error("MaxInterval should be >= InitialInterval")
- }
-}
-
-func TestRetryableError(t *testing.T) {
- t.Parallel()
-
- tests := []struct {
- name string
- err error
- expected bool
- }{
- {
- name: "nil error",
- err: nil,
- expected: false,
- },
- {
- name: "rate limit error",
- err: errors.New("rate limit exceeded"),
- expected: true,
- },
- {
- name: "quota exceeded error",
- err: errors.New("quota exceeded for project"),
- expected: true,
- },
- {
- name: "429 status code",
- err: errors.New("HTTP 429: Too Many Requests"),
- expected: true,
- },
- {
- name: "500 server error",
- err: errors.New("HTTP 500 Internal Server Error"),
- expected: true,
- },
- {
- name: "502 bad gateway",
- err: errors.New("502 Bad Gateway"),
- expected: true,
- },
- {
- name: "503 unavailable",
- err: errors.New("503 Service Unavailable"),
- expected: true,
- },
- {
- name: "504 gateway timeout",
- err: errors.New("504 Gateway Timeout"),
- expected: true,
- },
- {
- name: "unavailable keyword",
- err: errors.New("service unavailable"),
- expected: true,
- },
- {
- name: "connection reset",
- err: errors.New("connection reset by peer"),
- expected: true,
- },
- {
- name: "timeout error",
- err: errors.New("request timeout"),
- expected: true,
- },
- {
- name: "temporary error",
- err: errors.New("temporary failure"),
- expected: true,
- },
- {
- name: "non-retryable error",
- err: errors.New("invalid API key"),
- expected: false,
- },
- {
- name: "non-retryable 400 error",
- err: errors.New("HTTP 400 Bad Request"),
- expected: false,
- },
- {
- name: "non-retryable 401 error",
- err: errors.New("HTTP 401 Unauthorized"),
- expected: false,
- },
- {
- name: "non-retryable 403 error",
- err: errors.New("HTTP 403 Forbidden"),
- expected: false,
- },
- {
- name: "case insensitive rate limit",
- err: errors.New("RATE LIMIT reached"),
- expected: true,
- },
- {
- name: "case insensitive timeout",
- err: errors.New("TIMEOUT occurred"),
- expected: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- t.Parallel()
-
- got := retryableError(tt.err)
- if got != tt.expected {
- t.Errorf("retryableError(%v) = %v, want %v", tt.err, got, tt.expected)
- }
- })
- }
-}
-
-func TestContainsAny(t *testing.T) {
- t.Parallel()
-
- tests := []struct {
- name string
- s string
- substrs []string
- expected bool
- }{
- {
- name: "empty string",
- s: "",
- substrs: []string{"foo"},
- expected: false,
- },
- {
- name: "empty substrs",
- s: "foo bar",
- substrs: []string{},
- expected: false,
- },
- {
- name: "contains first substr",
- s: "foo bar baz",
- substrs: []string{"foo", "qux"},
- expected: true,
- },
- {
- name: "contains last substr",
- s: "foo bar baz",
- substrs: []string{"qux", "baz"},
- expected: true,
- },
- {
- name: "case insensitive match",
- s: "FOO BAR BAZ",
- substrs: []string{"foo"},
- expected: true,
- },
- {
- name: "no match",
- s: "foo bar baz",
- substrs: []string{"qux", "quux"},
- expected: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- t.Parallel()
-
- got := containsAny(tt.s, tt.substrs...)
- if got != tt.expected {
- t.Errorf("containsAny(%q, %v) = %v, want %v", tt.s, tt.substrs, got, tt.expected)
- }
- })
- }
-}
diff --git a/internal/agent/doc.go b/internal/agent/doc.go
deleted file mode 100644
index eeca0bd..0000000
--- a/internal/agent/doc.go
+++ /dev/null
@@ -1,33 +0,0 @@
-// Package agent provides sentinel errors for the chat agent.
-//
-// # Overview
-//
-// This package provides shared error types for building conversational AI agents.
-// The main implementation is in the chat subpackage.
-//
-// # Errors
-//
-// The package provides sentinel errors for consistent error handling:
-//
-// agent.ErrInvalidSession // Invalid session ID format
-// agent.ErrExecutionFailed // LLM or tool execution failed
-//
-// # Usage
-//
-// The chat subpackage provides the Chat agent implementation:
-//
-// import "github.com/koopa0/koopa/internal/agent/chat"
-//
-// chatAgent, err := chat.New(chat.Config{
-// Genkit: g,
-// Retriever: retriever,
-// SessionStore: sessionStore,
-// Logger: logger,
-// Tools: tools,
-// MaxTurns: 10,
-// RAGTopK: 5,
-// Language: "auto",
-// })
-//
-// See the chat subpackage for the complete implementation.
-package agent
diff --git a/internal/agent/errors.go b/internal/agent/errors.go
deleted file mode 100644
index 6b132f1..0000000
--- a/internal/agent/errors.go
+++ /dev/null
@@ -1,16 +0,0 @@
-// Package agent provides the agent abstraction layer for AI chat functionality.
-package agent
-
-import "errors"
-
-// Sentinel errors for agent operations.
-// Only errors that are checked with errors.Is() are defined here.
-var (
- // ErrInvalidSession indicates the session ID is invalid or malformed.
- // Used by: web/handlers/chat.go for HTTP status mapping
- ErrInvalidSession = errors.New("invalid session")
-
- // ErrExecutionFailed indicates agent execution failed.
- // Used by: web/handlers/chat.go for HTTP status mapping
- ErrExecutionFailed = errors.New("execution failed")
-)
diff --git a/internal/api/chat.go b/internal/api/chat.go
index 6bbe80c..1035986 100644
--- a/internal/api/chat.go
+++ b/internal/api/chat.go
@@ -11,31 +11,16 @@ import (
"strings"
"time"
- "github.com/firebase/genkit/go/ai"
- "github.com/firebase/genkit/go/genkit"
"github.com/google/uuid"
- "github.com/koopa0/koopa/internal/agent"
- "github.com/koopa0/koopa/internal/agent/chat"
+ "github.com/koopa0/koopa/internal/chat"
"github.com/koopa0/koopa/internal/tools"
)
// SSE timeout for streaming connections.
const sseTimeout = 5 * time.Minute
-// Title generation constants.
-const (
- titleMaxLength = 50
- titleGenerationTimeout = 5 * time.Second
- titleInputMaxRunes = 500
-)
-
-const titlePrompt = `Generate a concise title (max 50 characters) for a chat session based on this first message.
-The title should capture the main topic or intent.
-Return ONLY the title text, no quotes, no explanations, no punctuation at the end.
-
-Message: %s
-
-Title:`
+// titleMaxLength is the maximum rune length for a fallback session title.
+const titleMaxLength = 50
// Tool display info for JSON SSE events.
type toolDisplayInfo struct {
@@ -45,17 +30,20 @@ type toolDisplayInfo struct {
}
var toolDisplay = map[string]toolDisplayInfo{
- "web_search": {StartMsg: "搜尋網路中...", CompleteMsg: "搜尋完成", ErrorMsg: "搜尋服務暫時無法使用,請稍後再試"},
- "web_fetch": {StartMsg: "讀取網頁中...", CompleteMsg: "已讀取內容", ErrorMsg: "無法讀取網頁內容"},
- "read_file": {StartMsg: "讀取檔案中...", CompleteMsg: "已讀取檔案", ErrorMsg: "無法讀取檔案"},
- "write_file": {StartMsg: "寫入檔案中...", CompleteMsg: "已寫入檔案", ErrorMsg: "寫入檔案失敗"},
- "list_files": {StartMsg: "瀏覽目錄中...", CompleteMsg: "目錄瀏覽完成", ErrorMsg: "無法瀏覽目錄"},
- "delete_file": {StartMsg: "刪除檔案中...", CompleteMsg: "已刪除檔案", ErrorMsg: "刪除檔案失敗"},
- "get_file_info": {StartMsg: "取得檔案資訊中...", CompleteMsg: "已取得檔案資訊", ErrorMsg: "無法取得檔案資訊"},
- "execute_command": {StartMsg: "執行命令中...", CompleteMsg: "命令執行完成", ErrorMsg: "命令執行失敗"},
- "current_time": {StartMsg: "取得時間中...", CompleteMsg: "時間已取得", ErrorMsg: "無法取得時間"},
- "get_env": {StartMsg: "取得環境變數中...", CompleteMsg: "環境變數已取得", ErrorMsg: "無法取得環境變數"},
- "knowledge_search": {StartMsg: "搜尋知識庫中...", CompleteMsg: "知識庫搜尋完成", ErrorMsg: "無法搜尋知識庫"},
+ "web_search": {StartMsg: "搜尋網路中...", CompleteMsg: "搜尋完成", ErrorMsg: "搜尋服務暫時無法使用,請稍後再試"},
+ "web_fetch": {StartMsg: "讀取網頁中...", CompleteMsg: "已讀取內容", ErrorMsg: "無法讀取網頁內容"},
+ "read_file": {StartMsg: "讀取檔案中...", CompleteMsg: "已讀取檔案", ErrorMsg: "無法讀取檔案"},
+ "write_file": {StartMsg: "寫入檔案中...", CompleteMsg: "已寫入檔案", ErrorMsg: "寫入檔案失敗"},
+ "list_files": {StartMsg: "瀏覽目錄中...", CompleteMsg: "目錄瀏覽完成", ErrorMsg: "無法瀏覽目錄"},
+ "delete_file": {StartMsg: "刪除檔案中...", CompleteMsg: "已刪除檔案", ErrorMsg: "刪除檔案失敗"},
+ "get_file_info": {StartMsg: "取得檔案資訊中...", CompleteMsg: "已取得檔案資訊", ErrorMsg: "無法取得檔案資訊"},
+ "execute_command": {StartMsg: "執行命令中...", CompleteMsg: "命令執行完成", ErrorMsg: "命令執行失敗"},
+ "current_time": {StartMsg: "取得時間中...", CompleteMsg: "時間已取得", ErrorMsg: "無法取得時間"},
+ "get_env": {StartMsg: "取得環境變數中...", CompleteMsg: "環境變數已取得", ErrorMsg: "無法取得環境變數"},
+ "search_history": {StartMsg: "搜尋對話記錄中...", CompleteMsg: "對話記錄搜尋完成", ErrorMsg: "無法搜尋對話記錄"},
+ "search_documents": {StartMsg: "搜尋知識庫中...", CompleteMsg: "知識庫搜尋完成", ErrorMsg: "無法搜尋知識庫"},
+ "search_system_knowledge": {StartMsg: "搜尋系統知識中...", CompleteMsg: "系統知識搜尋完成", ErrorMsg: "無法搜尋系統知識"},
+ "knowledge_store": {StartMsg: "儲存知識中...", CompleteMsg: "知識已儲存", ErrorMsg: "儲存知識失敗"},
}
var defaultToolDisplay = toolDisplayInfo{
@@ -73,48 +61,47 @@ func getToolDisplay(name string) toolDisplayInfo {
// chatHandler handles chat-related API requests.
type chatHandler struct {
- logger *slog.Logger
- genkit *genkit.Genkit
- modelName string // Provider-qualified model name for title generation
- flow *chat.Flow
- sessions *sessionManager
+ logger *slog.Logger
+ agent *chat.Agent // Optional: nil disables AI title generation
+ flow *chat.Flow
+ sessions *sessionManager
}
// send handles POST /api/v1/chat — accepts JSON, sends message to chat flow.
-//
-//nolint:revive // unused-receiver: method bound to chatHandler for consistent route registration
func (h *chatHandler) send(w http.ResponseWriter, r *http.Request) {
var req struct {
Content string `json:"content"`
SessionID string `json:"sessionId"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- WriteError(w, http.StatusBadRequest, "invalid_json", "invalid request body")
+ WriteError(w, http.StatusBadRequest, "invalid_json", "invalid request body", h.logger)
return
}
content := strings.TrimSpace(req.Content)
if content == "" {
- WriteError(w, http.StatusBadRequest, "content_required", "content is required")
+ WriteError(w, http.StatusBadRequest, "content_required", "content is required", h.logger)
return
}
- // Resolve session ID from request body or context
- var sessionID uuid.UUID
+ // Resolve session from context (set by session middleware from cookie)
+ sessionID, ok := sessionIDFromContext(r.Context())
+ if !ok {
+ WriteError(w, http.StatusBadRequest, "session_required", "session ID required", h.logger)
+ return
+ }
+
+ // If body also specifies a session, verify it matches (defense-in-depth)
if req.SessionID != "" {
parsed, err := uuid.Parse(req.SessionID)
if err != nil {
- WriteError(w, http.StatusBadRequest, "invalid_session", "invalid session ID")
+ WriteError(w, http.StatusBadRequest, "invalid_session", "invalid session ID", h.logger)
return
}
- sessionID = parsed
- } else {
- ctxID, ok := SessionIDFromContext(r.Context())
- if !ok {
- WriteError(w, http.StatusBadRequest, "session_required", "session ID required")
+ if parsed != sessionID {
+ WriteError(w, http.StatusForbidden, "forbidden", "session access denied", h.logger)
return
}
- sessionID = ctxID
}
msgID := uuid.New().String()
@@ -128,7 +115,7 @@ func (h *chatHandler) send(w http.ResponseWriter, r *http.Request) {
"msgId": msgID,
"sessionId": sessionID.String(),
"streamUrl": "/api/v1/chat/stream?" + params.Encode(),
- })
+ }, h.logger)
}
// stream handles GET /api/v1/chat/stream — SSE endpoint with JSON events.
@@ -137,12 +124,21 @@ func (h *chatHandler) stream(w http.ResponseWriter, r *http.Request) {
sessionID := r.URL.Query().Get("session_id")
query := r.URL.Query().Get("query")
- if msgID == "" || sessionID == "" {
- WriteError(w, http.StatusBadRequest, "missing_params", "msgId and session_id required")
+ if msgID == "" || sessionID == "" || query == "" {
+ WriteError(w, http.StatusBadRequest, "missing_params", "msgId, session_id, and query required", h.logger)
return
}
- if query == "" {
- query = "Hello"
+
+ // Verify session ownership
+ parsedID, err := uuid.Parse(sessionID)
+ if err != nil {
+ WriteError(w, http.StatusBadRequest, "invalid_session", "invalid session ID", h.logger)
+ return
+ }
+ ctxID, ok := sessionIDFromContext(r.Context())
+ if !ok || ctxID != parsedID {
+ WriteError(w, http.StatusForbidden, "forbidden", "session access denied", h.logger)
+ return
}
// Set SSE headers
@@ -151,24 +147,23 @@ func (h *chatHandler) stream(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
- flusher, ok := w.(http.Flusher)
- if !ok {
- WriteError(w, http.StatusInternalServerError, "sse_unsupported", "streaming not supported")
+ if _, ok := w.(http.Flusher); !ok {
+ WriteError(w, http.StatusInternalServerError, "sse_unsupported", "streaming not supported", h.logger)
return
}
ctx, cancel := context.WithTimeout(r.Context(), sseTimeout)
defer cancel()
- if h.flow != nil {
- h.streamWithFlow(ctx, w, flusher, msgID, sessionID, query)
- } else {
- h.simulateStreaming(ctx, w, flusher, msgID, sessionID, query)
+ if h.flow == nil {
+ _ = sseEvent(w, "error", map[string]string{"error": "chat flow not initialized"})
+ return
}
+ h.streamWithFlow(ctx, w, msgID, sessionID, query)
}
// sseEvent writes a single SSE event.
-func sseEvent(w http.ResponseWriter, f http.Flusher, event string, data any) error {
+func sseEvent(w http.ResponseWriter, event string, data any) error {
jsonData, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("marshal SSE data: %w", err)
@@ -177,19 +172,21 @@ func sseEvent(w http.ResponseWriter, f http.Flusher, event string, data any) err
if err != nil {
return fmt.Errorf("write SSE event: %w", err)
}
- f.Flush()
+ if f, ok := w.(http.Flusher); ok {
+ f.Flush()
+ }
return nil
}
// streamWithFlow uses the real chat.Flow for AI responses.
-func (h *chatHandler) streamWithFlow(ctx context.Context, w http.ResponseWriter, f http.Flusher, msgID, sessionID, query string) {
+func (h *chatHandler) streamWithFlow(ctx context.Context, w http.ResponseWriter, msgID, sessionID, query string) {
input := chat.Input{
Query: query,
SessionID: sessionID,
}
// Create JSON tool emitter and inject into context
- emitter := &jsonToolEmitter{w: w, f: f, msgID: msgID}
+ emitter := &jsonToolEmitter{w: w, msgID: msgID}
ctx = tools.ContextWithEmitter(ctx, emitter)
h.logger.Debug("starting stream", "sessionId", sessionID)
@@ -197,7 +194,6 @@ func (h *chatHandler) streamWithFlow(ctx context.Context, w http.ResponseWriter,
var (
finalOutput chat.Output
streamErr error
- buf strings.Builder
)
for streamValue, err := range h.flow.Stream(ctx, input) {
@@ -219,11 +215,8 @@ func (h *chatHandler) streamWithFlow(ctx context.Context, w http.ResponseWriter,
}
if streamValue.Stream.Text != "" {
- buf.WriteString(streamValue.Stream.Text)
- content := buf.String()
- buf.Reset()
- if err := sseEvent(w, f, "chunk", map[string]string{"msgId": msgID, "text": content}); err != nil {
- h.logger.Error("failed to write chunk", "error", err)
+ if err := sseEvent(w, "chunk", map[string]string{"msgId": msgID, "text": streamValue.Stream.Text}); err != nil {
+ h.logger.Error("writing chunk", "error", err)
return
}
}
@@ -231,16 +224,11 @@ func (h *chatHandler) streamWithFlow(ctx context.Context, w http.ResponseWriter,
if streamErr != nil {
code, message := classifyError(streamErr)
- h.logger.Error("flow execution failed", "error", streamErr, "sessionId", sessionID)
- _ = sseEvent(w, f, "error", map[string]string{"msgId": msgID, "code": code, "message": message})
+ h.logger.Error("executing flow", "error", streamErr, "sessionId", sessionID)
+ _ = sseEvent(w, "error", map[string]string{"msgId": msgID, "code": code, "message": message}) // best-effort: client may have disconnected
return
}
- // Flush remaining buffer
- if buf.Len() > 0 {
- _ = sseEvent(w, f, "chunk", map[string]string{"msgId": msgID, "text": buf.String()})
- }
-
// Generate title before sending done event
title := h.maybeGenerateTitle(ctx, sessionID, query)
@@ -253,55 +241,15 @@ func (h *chatHandler) streamWithFlow(ctx context.Context, w http.ResponseWriter,
if title != "" {
doneData["title"] = title
}
- _ = sseEvent(w, f, "done", doneData)
-}
-
-// simulateStreaming is a placeholder for testing without real Flow.
-func (h *chatHandler) simulateStreaming(ctx context.Context, w http.ResponseWriter, f http.Flusher, msgID, sessionID, query string) {
- response := fmt.Sprintf("I received your message: %q. This is a simulated response.", query)
- words := strings.Fields(response)
-
- var full strings.Builder
- for i, word := range words {
- select {
- case <-ctx.Done():
- h.logContextDone(ctx, msgID)
- return
- default:
- }
-
- if i > 0 {
- full.WriteString(" ")
- }
- full.WriteString(word)
-
- if err := sseEvent(w, f, "chunk", map[string]string{"msgId": msgID, "text": full.String()}); err != nil {
- h.logger.Error("failed to send chunk", "error", err)
- return
- }
-
- time.Sleep(50 * time.Millisecond)
- }
-
- title := h.maybeGenerateTitle(ctx, sessionID, query)
-
- doneData := map[string]string{
- "msgId": msgID,
- "sessionId": sessionID,
- "response": full.String(),
- }
- if title != "" {
- doneData["title"] = title
- }
- _ = sseEvent(w, f, "done", doneData)
+ _ = sseEvent(w, "done", doneData) // best-effort: client may have disconnected
}
// classifyError returns error code and user message based on error type.
func classifyError(err error) (code, message string) {
switch {
- case errors.Is(err, agent.ErrInvalidSession):
+ case errors.Is(err, chat.ErrInvalidSession):
return "invalid_session", "Invalid session. Please refresh the page."
- case errors.Is(err, agent.ErrExecutionFailed):
+ case errors.Is(err, chat.ErrExecutionFailed):
return "execution_failed", err.Error()
case errors.Is(err, context.DeadlineExceeded):
return "timeout", "Request timed out. Please try again."
@@ -311,8 +259,13 @@ func classifyError(err error) (code, message string) {
}
// maybeGenerateTitle generates a session title if one doesn't exist.
+// Uses Agent.GenerateTitle for AI-powered titles, falling back to truncation
+// when no agent is configured or generation returns an empty title.
// Returns the generated title or empty string.
func (h *chatHandler) maybeGenerateTitle(ctx context.Context, sessionID, userMessage string) string {
+ if h.sessions == nil || h.sessions.store == nil {
+ return ""
+ }
+
sessionUUID, err := uuid.Parse(sessionID)
if err != nil {
return ""
@@ -327,13 +280,16 @@ func (h *chatHandler) maybeGenerateTitle(ctx context.Context, sessionID, userMes
return ""
}
- title := h.generateTitleWithAI(ctx, userMessage)
+ var title string
+ if h.agent != nil {
+ title = h.agent.GenerateTitle(ctx, userMessage)
+ }
if title == "" {
title = truncateForTitle(userMessage)
}
if err := h.sessions.store.UpdateSessionTitle(ctx, sessionUUID, title); err != nil {
- h.logger.Error("failed to update session title", "error", err, "session_id", sessionID)
+ h.logger.Error("updating session title", "error", err, "session_id", sessionID)
return ""
}
@@ -341,42 +297,6 @@ func (h *chatHandler) maybeGenerateTitle(ctx context.Context, sessionID, userMes
return title
}
-// generateTitleWithAI uses Genkit to generate a session title.
-func (h *chatHandler) generateTitleWithAI(ctx context.Context, userMessage string) string {
- if h.genkit == nil {
- return ""
- }
-
- ctx, cancel := context.WithTimeout(ctx, titleGenerationTimeout)
- defer cancel()
-
- inputRunes := []rune(userMessage)
- if len(inputRunes) > titleInputMaxRunes {
- userMessage = string(inputRunes[:titleInputMaxRunes]) + "..."
- }
-
- response, err := genkit.Generate(ctx, h.genkit,
- ai.WithModelName(h.modelName),
- ai.WithPrompt(titlePrompt, userMessage),
- )
- if err != nil {
- h.logger.Debug("AI title generation failed", "error", err)
- return ""
- }
-
- title := strings.TrimSpace(response.Text())
- if title == "" {
- return ""
- }
-
- titleRunes := []rune(title)
- if len(titleRunes) > titleMaxLength {
- title = string(titleRunes[:titleMaxLength-3]) + "..."
- }
-
- return title
-}
-
// truncateForTitle truncates a message to create a fallback session title.
func truncateForTitle(message string) string {
message = strings.TrimSpace(message)
@@ -402,16 +322,15 @@ func (h *chatHandler) logContextDone(ctx context.Context, msgID string) {
}
}
-// jsonToolEmitter implements tools.ToolEventEmitter for JSON SSE events.
+// jsonToolEmitter implements tools.Emitter for JSON SSE events.
type jsonToolEmitter struct {
w http.ResponseWriter
- f http.Flusher
msgID string
}
func (e *jsonToolEmitter) OnToolStart(name string) {
display := getToolDisplay(name)
- _ = sseEvent(e.w, e.f, "tool_start", map[string]string{
+ _ = sseEvent(e.w, "tool_start", map[string]string{ // best-effort
"msgId": e.msgID,
"tool": name,
"message": display.StartMsg,
@@ -420,7 +339,7 @@ func (e *jsonToolEmitter) OnToolStart(name string) {
func (e *jsonToolEmitter) OnToolComplete(name string) {
display := getToolDisplay(name)
- _ = sseEvent(e.w, e.f, "tool_complete", map[string]string{
+ _ = sseEvent(e.w, "tool_complete", map[string]string{ // best-effort
"msgId": e.msgID,
"tool": name,
"message": display.CompleteMsg,
@@ -429,7 +348,7 @@ func (e *jsonToolEmitter) OnToolComplete(name string) {
func (e *jsonToolEmitter) OnToolError(name string) {
display := getToolDisplay(name)
- _ = sseEvent(e.w, e.f, "tool_error", map[string]string{
+ _ = sseEvent(e.w, "tool_error", map[string]string{ // best-effort
"msgId": e.msgID,
"tool": name,
"message": display.ErrorMsg,
@@ -437,4 +356,4 @@ func (e *jsonToolEmitter) OnToolError(name string) {
}
// Compile-time interface verification.
-var _ tools.ToolEventEmitter = (*jsonToolEmitter)(nil)
+var _ tools.Emitter = (*jsonToolEmitter)(nil)
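The compile-time assertion above pins jsonToolEmitter to the three-method tools.Emitter contract (OnToolStart, OnToolComplete, OnToolError). As a rough sketch of what another implementation could look like — the logEmitter type below is illustrative and not part of this change — an emitter that reports tool progress to a logger instead of an SSE stream only needs the same three methods:

	package emitters // illustrative only

	import (
		"log/slog"

		"github.com/koopa0/koopa/internal/tools"
	)

	// logEmitter reports tool progress to a slog.Logger instead of an SSE stream.
	type logEmitter struct {
		logger *slog.Logger
	}

	func (e *logEmitter) OnToolStart(name string)    { e.logger.Info("tool start", "tool", name) }
	func (e *logEmitter) OnToolComplete(name string) { e.logger.Info("tool complete", "tool", name) }
	func (e *logEmitter) OnToolError(name string)    { e.logger.Warn("tool error", "tool", name) }

	// Compile-time interface verification, mirroring jsonToolEmitter.
	var _ tools.Emitter = (*logEmitter)(nil)

Any such emitter can be injected the same way the handler does above, via tools.ContextWithEmitter(ctx, emitter).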
diff --git a/internal/api/chat_test.go b/internal/api/chat_test.go
index 4b708db..e08af8a 100644
--- a/internal/api/chat_test.go
+++ b/internal/api/chat_test.go
@@ -9,10 +9,13 @@ import (
"net/http"
"net/http/httptest"
"net/url"
+ "strings"
"testing"
+ "github.com/firebase/genkit/go/genkit"
"github.com/google/uuid"
- "github.com/koopa0/koopa/internal/agent"
+
+ "github.com/koopa0/koopa/internal/chat"
)
func newTestChatHandler() *chatHandler {
@@ -22,8 +25,6 @@ func newTestChatHandler() *chatHandler {
}
func TestChatSend_URLEncoding(t *testing.T) {
- ch := newTestChatHandler()
-
sessionID := uuid.New()
content := "你好 world & foo=bar#hash?query"
@@ -34,8 +35,10 @@ func TestChatSend_URLEncoding(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body))
+ ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID)
+ r = r.WithContext(ctx)
- ch.send(w, r)
+ newTestChatHandler().send(w, r)
if w.Code != http.StatusOK {
t.Fatalf("send() status = %d, want %d", w.Code, http.StatusOK)
@@ -67,8 +70,6 @@ func TestChatSend_URLEncoding(t *testing.T) {
}
func TestChatSend_SessionIDFromBody(t *testing.T) {
- ch := newTestChatHandler()
-
sessionID := uuid.New()
body, _ := json.Marshal(map[string]string{
"content": "hello",
@@ -77,8 +78,10 @@ func TestChatSend_SessionIDFromBody(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body))
+ ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID)
+ r = r.WithContext(ctx)
- ch.send(w, r)
+ newTestChatHandler().send(w, r)
if w.Code != http.StatusOK {
t.Fatalf("send() status = %d, want %d", w.Code, http.StatusOK)
@@ -97,8 +100,6 @@ func TestChatSend_SessionIDFromBody(t *testing.T) {
}
func TestChatSend_SessionIDFromContext(t *testing.T) {
- ch := newTestChatHandler()
-
sessionID := uuid.New()
body, _ := json.Marshal(map[string]string{
"content": "hello",
@@ -112,7 +113,7 @@ func TestChatSend_SessionIDFromContext(t *testing.T) {
ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID)
r = r.WithContext(ctx)
- ch.send(w, r)
+ newTestChatHandler().send(w, r)
if w.Code != http.StatusOK {
t.Fatalf("send() status = %d, want %d", w.Code, http.StatusOK)
@@ -127,8 +128,6 @@ func TestChatSend_SessionIDFromContext(t *testing.T) {
}
func TestChatSend_MissingContent(t *testing.T) {
- ch := newTestChatHandler()
-
body, _ := json.Marshal(map[string]string{
"sessionId": uuid.New().String(),
})
@@ -136,7 +135,7 @@ func TestChatSend_MissingContent(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body))
- ch.send(w, r)
+ newTestChatHandler().send(w, r)
if w.Code != http.StatusBadRequest {
t.Fatalf("send(no content) status = %d, want %d", w.Code, http.StatusBadRequest)
@@ -150,8 +149,6 @@ func TestChatSend_MissingContent(t *testing.T) {
}
func TestChatSend_EmptyContent(t *testing.T) {
- ch := newTestChatHandler()
-
body, _ := json.Marshal(map[string]string{
"content": " ",
"sessionId": uuid.New().String(),
@@ -160,7 +157,7 @@ func TestChatSend_EmptyContent(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body))
- ch.send(w, r)
+ newTestChatHandler().send(w, r)
if w.Code != http.StatusBadRequest {
t.Fatalf("send(whitespace) status = %d, want %d", w.Code, http.StatusBadRequest)
@@ -168,8 +165,6 @@ func TestChatSend_EmptyContent(t *testing.T) {
}
func TestChatSend_InvalidSessionID(t *testing.T) {
- ch := newTestChatHandler()
-
body, _ := json.Marshal(map[string]string{
"content": "hello",
"sessionId": "not-a-uuid",
@@ -177,8 +172,11 @@ func TestChatSend_InvalidSessionID(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body))
+ // Inject context session so we reach the body parse check
+ ctx := context.WithValue(r.Context(), ctxKeySessionID, uuid.New())
+ r = r.WithContext(ctx)
- ch.send(w, r)
+ newTestChatHandler().send(w, r)
if w.Code != http.StatusBadRequest {
t.Fatalf("send(bad uuid) status = %d, want %d", w.Code, http.StatusBadRequest)
@@ -192,8 +190,6 @@ func TestChatSend_InvalidSessionID(t *testing.T) {
}
func TestChatSend_NoSession(t *testing.T) {
- ch := newTestChatHandler()
-
body, _ := json.Marshal(map[string]string{
"content": "hello",
// No sessionId in body, no session in context
@@ -202,7 +198,7 @@ func TestChatSend_NoSession(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body))
- ch.send(w, r)
+ newTestChatHandler().send(w, r)
if w.Code != http.StatusBadRequest {
t.Fatalf("send(no session) status = %d, want %d", w.Code, http.StatusBadRequest)
@@ -216,12 +212,10 @@ func TestChatSend_NoSession(t *testing.T) {
}
func TestChatSend_InvalidJSON(t *testing.T) {
- ch := newTestChatHandler()
-
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader([]byte("not json")))
- ch.send(w, r)
+ newTestChatHandler().send(w, r)
if w.Code != http.StatusBadRequest {
t.Fatalf("send(invalid json) status = %d, want %d", w.Code, http.StatusBadRequest)
@@ -235,11 +229,11 @@ func TestTruncateForTitle(t *testing.T) {
wantMax int
wantDots bool
}{
- {"short", "Hello world", 50, false},
- {"exact_50", "12345678901234567890123456789012345678901234567890", 50, false},
- {"long", "This is a very long message that exceeds the maximum allowed title length of fifty characters", 53, true},
- {"empty", "", 0, false},
- {"whitespace", " hello ", 50, false},
+ {name: "short", input: "Hello world", wantMax: 50, wantDots: false},
+ {name: "exact_50", input: "12345678901234567890123456789012345678901234567890", wantMax: 50, wantDots: false},
+ {name: "long", input: "This is a very long message that exceeds the maximum allowed title length of fifty characters", wantMax: 53, wantDots: true},
+ {name: "empty", input: "", wantMax: 0, wantDots: false},
+ {name: "whitespace", input: " hello ", wantMax: 50, wantDots: false},
}
for _, tt := range tests {
@@ -280,16 +274,175 @@ func TestGetToolDisplay(t *testing.T) {
}
}
+func TestStream_MissingParams(t *testing.T) {
+ ch := newTestChatHandler()
+
+ tests := []struct {
+ name string
+ query string
+ }{
+ {name: "missing all", query: ""},
+ {name: "missing session_id and query", query: "?msgId=abc"},
+ {name: "missing msgId and query", query: "?session_id=abc"},
+ {name: "missing query", query: "?msgId=abc&session_id=def"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream"+tt.query, nil)
+
+ ch.stream(w, r)
+
+ if w.Code != http.StatusBadRequest {
+ t.Fatalf("stream(%s) status = %d, want %d", tt.name, w.Code, http.StatusBadRequest)
+ }
+
+ errResp := decodeErrorEnvelope(t, w)
+ if errResp.Code != "missing_params" {
+ t.Errorf("stream(%s) code = %q, want %q", tt.name, errResp.Code, "missing_params")
+ }
+ })
+ }
+}
+
+func TestStream_SSEHeaders(t *testing.T) {
+ ch := newTestChatHandler() // flow is nil → error event (headers still set)
+ sessionID := uuid.New()
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hi", nil)
+ ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID)
+ r = r.WithContext(ctx)
+
+ ch.stream(w, r)
+
+ wantHeaders := map[string]string{
+ "Content-Type": "text/event-stream",
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ "X-Accel-Buffering": "no",
+ }
+
+ for header, want := range wantHeaders {
+ if got := w.Header().Get(header); got != want {
+ t.Errorf("stream() header %q = %q, want %q", header, got, want)
+ }
+ }
+}
+
+func TestStream_NilFlow(t *testing.T) {
+ ch := newTestChatHandler() // flow is nil → error SSE event
+ sessionID := uuid.New()
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hello", nil)
+ ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID)
+ r = r.WithContext(ctx)
+
+ ch.stream(w, r)
+
+ body := w.Body.String()
+
+ // Should emit an error event, not chunk or done events
+ if !strings.Contains(body, "event: error\n") {
+ t.Error("stream(nil flow) expected error event in SSE output")
+ }
+ if !strings.Contains(body, "chat flow not initialized") {
+ t.Error("stream(nil flow) expected 'chat flow not initialized' in error event")
+ }
+ if strings.Contains(body, "event: chunk\n") {
+ t.Error("stream(nil flow) should not emit chunk events")
+ }
+ if strings.Contains(body, "event: done\n") {
+ t.Error("stream(nil flow) should not emit done events")
+ }
+}
+
+func TestSSEEvent_Format(t *testing.T) {
+ w := httptest.NewRecorder()
+
+ data := map[string]string{"msgId": "abc", "text": "hello"}
+ err := sseEvent(w, "chunk", data)
+
+ if err != nil {
+ t.Fatalf("sseEvent() error: %v", err)
+ }
+
+ body := w.Body.String()
+
+ // Verify SSE format: "event: <name>\ndata: <json>\n\n"
+ if !strings.HasPrefix(body, "event: chunk\ndata: ") {
+ t.Errorf("sseEvent() format = %q, want prefix %q", body, "event: chunk\ndata: ")
+ }
+
+ if !strings.HasSuffix(body, "\n\n") {
+ t.Errorf("sseEvent() should end with double newline, got %q", body)
+ }
+
+ // Verify JSON payload is valid
+ dataLine := strings.TrimPrefix(body, "event: chunk\ndata: ")
+ dataLine = strings.TrimSuffix(dataLine, "\n\n")
+
+ var decoded map[string]string
+ if err := json.Unmarshal([]byte(dataLine), &decoded); err != nil {
+ t.Fatalf("sseEvent() data is not valid JSON: %v", err)
+ }
+
+ if decoded["msgId"] != "abc" {
+ t.Errorf("sseEvent() data.msgId = %q, want %q", decoded["msgId"], "abc")
+ }
+ if decoded["text"] != "hello" {
+ t.Errorf("sseEvent() data.text = %q, want %q", decoded["text"], "hello")
+ }
+}
+
+func TestSSEEvent_MarshalError(t *testing.T) {
+ w := httptest.NewRecorder()
+
+ // Channels cannot be marshaled to JSON
+ err := sseEvent(w, "chunk", make(chan int))
+
+ if err == nil {
+ t.Fatal("sseEvent(unmarshalable) expected error, got nil")
+ }
+}
+
+func TestStream_NilFlow_ContextCanceled(t *testing.T) {
+ ch := newTestChatHandler() // flow is nil
+ sessionID := uuid.New()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel() // Cancel immediately
+ ctx = context.WithValue(ctx, ctxKeySessionID, sessionID)
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hi", nil)
+ r = r.WithContext(ctx)
+
+ ch.stream(w, r)
+
+ body := w.Body.String()
+
+ // Flow is nil, so error event is emitted regardless of context state.
+ if !strings.Contains(body, "event: error\n") {
+ t.Error("stream(nil flow, canceled) expected error event")
+ }
+ if strings.Contains(body, "event: done\n") {
+ t.Error("stream(nil flow, canceled) should not emit done event")
+ }
+}
+
func TestClassifyError(t *testing.T) {
tests := []struct {
name string
err error
wantCode string
}{
- {"invalid_session", agent.ErrInvalidSession, "invalid_session"},
- {"execution_failed", agent.ErrExecutionFailed, "execution_failed"},
- {"deadline_exceeded", context.DeadlineExceeded, "timeout"},
- {"generic_error", errors.New("something went wrong"), "flow_error"},
+ {name: "invalid_session", err: chat.ErrInvalidSession, wantCode: "invalid_session"},
+ {name: "execution_failed", err: chat.ErrExecutionFailed, wantCode: "execution_failed"},
+ {name: "deadline_exceeded", err: context.DeadlineExceeded, wantCode: "timeout"},
+ {name: "generic_error", err: errors.New("something went wrong"), wantCode: "flow_error"},
}
for _, tt := range tests {
@@ -304,3 +457,231 @@ func TestClassifyError(t *testing.T) {
})
}
}
+
+func TestChatSend_SessionMismatch(t *testing.T) {
+ body, _ := json.Marshal(map[string]string{
+ "content": "hello",
+ "sessionId": uuid.New().String(), // Different from context
+ })
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body))
+ ctx := context.WithValue(r.Context(), ctxKeySessionID, uuid.New())
+ r = r.WithContext(ctx)
+
+ newTestChatHandler().send(w, r)
+
+ if w.Code != http.StatusForbidden {
+ t.Fatalf("send(mismatched session) status = %d, want %d", w.Code, http.StatusForbidden)
+ }
+
+ errResp := decodeErrorEnvelope(t, w)
+ if errResp.Code != "forbidden" {
+ t.Errorf("send(mismatched session) code = %q, want %q", errResp.Code, "forbidden")
+ }
+}
+
+func TestStream_OwnershipDenied(t *testing.T) {
+ ch := newTestChatHandler()
+ sessionID := uuid.New()
+ otherID := uuid.New()
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hi", nil)
+ ctx := context.WithValue(r.Context(), ctxKeySessionID, otherID)
+ r = r.WithContext(ctx)
+
+ ch.stream(w, r)
+
+ if w.Code != http.StatusForbidden {
+ t.Fatalf("stream(wrong session) status = %d, want %d", w.Code, http.StatusForbidden)
+ }
+}
+
+func TestStream_NoSession(t *testing.T) {
+ ch := newTestChatHandler()
+ sessionID := uuid.New()
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hi", nil)
+ // No session in context
+
+ ch.stream(w, r)
+
+ if w.Code != http.StatusForbidden {
+ t.Fatalf("stream(no session) status = %d, want %d", w.Code, http.StatusForbidden)
+ }
+}
+
+// sseTestEvent represents a parsed SSE event for test assertions.
+type sseTestEvent struct {
+ Type string
+ Data map[string]string
+}
+
+// parseSSEEvents parses an SSE response body into structured events.
+func parseSSEEvents(t *testing.T, body string) []sseTestEvent {
+ t.Helper()
+ var events []sseTestEvent
+ for _, block := range strings.Split(body, "\n\n") {
+ block = strings.TrimSpace(block)
+ if block == "" {
+ continue
+ }
+ var ev sseTestEvent
+ for _, line := range strings.Split(block, "\n") {
+ switch {
+ case strings.HasPrefix(line, "event: "):
+ ev.Type = strings.TrimPrefix(line, "event: ")
+ case strings.HasPrefix(line, "data: "):
+ raw := strings.TrimPrefix(line, "data: ")
+ ev.Data = make(map[string]string)
+ if err := json.Unmarshal([]byte(raw), &ev.Data); err != nil {
+ t.Fatalf("parseSSEEvents: invalid JSON in data line %q: %v", raw, err)
+ }
+ }
+ }
+ if ev.Type != "" {
+ events = append(events, ev)
+ }
+ }
+ return events
+}
+
+// filterSSEEvents returns events matching the given type.
+func filterSSEEvents(events []sseTestEvent, eventType string) []sseTestEvent {
+ var filtered []sseTestEvent
+ for _, e := range events {
+ if e.Type == eventType {
+ filtered = append(filtered, e)
+ }
+ }
+ return filtered
+}
+
+func TestStreamWithFlow(t *testing.T) {
+ sessionID := uuid.New()
+ sessionIDStr := sessionID.String()
+
+ tests := []struct {
+ name string
+ flowFn func(context.Context, chat.Input, func(context.Context, chat.StreamChunk) error) (chat.Output, error)
+ wantChunks []string // expected chunk text values in order
+ wantDone map[string]string // expected fields in done event (nil = no done expected)
+ wantError map[string]string // expected fields in error event (nil = no error expected)
+ }{
+ {
+ name: "success with chunks",
+ flowFn: func(ctx context.Context, input chat.Input, stream func(context.Context, chat.StreamChunk) error) (chat.Output, error) {
+ if stream != nil {
+ if err := stream(ctx, chat.StreamChunk{Text: "Hello "}); err != nil {
+ return chat.Output{}, err
+ }
+ if err := stream(ctx, chat.StreamChunk{Text: "World"}); err != nil {
+ return chat.Output{}, err
+ }
+ }
+ return chat.Output{Response: "Hello World", SessionID: input.SessionID}, nil
+ },
+ wantChunks: []string{"Hello ", "World"},
+ wantDone: map[string]string{"response": "Hello World", "sessionId": sessionIDStr},
+ },
+ {
+ name: "flow error without chunks",
+ flowFn: func(_ context.Context, _ chat.Input, _ func(context.Context, chat.StreamChunk) error) (chat.Output, error) {
+ return chat.Output{}, chat.ErrInvalidSession
+ },
+ wantError: map[string]string{"code": "invalid_session"},
+ },
+ {
+ name: "partial chunks then error",
+ flowFn: func(ctx context.Context, _ chat.Input, stream func(context.Context, chat.StreamChunk) error) (chat.Output, error) {
+ if stream != nil {
+ if err := stream(ctx, chat.StreamChunk{Text: "partial"}); err != nil {
+ return chat.Output{}, err
+ }
+ }
+ return chat.Output{}, chat.ErrExecutionFailed
+ },
+ wantChunks: []string{"partial"},
+ wantError: map[string]string{"code": "execution_failed"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ t.Cleanup(cancel)
+
+ g := genkit.Init(ctx)
+ flowName := "test/" + strings.ReplaceAll(tt.name, " ", "_")
+ testFlow := genkit.DefineStreamingFlow(g, flowName, tt.flowFn)
+
+ ch := &chatHandler{
+ logger: slog.New(slog.DiscardHandler),
+ flow: testFlow,
+ }
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet,
+ "/api/v1/chat/stream?msgId=m1&session_id="+sessionIDStr+"&query=test", nil)
+ rctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID)
+ r = r.WithContext(rctx)
+
+ ch.stream(w, r)
+
+ if ct := w.Header().Get("Content-Type"); ct != "text/event-stream" {
+ t.Fatalf("stream(%s) Content-Type = %q, want %q", tt.name, ct, "text/event-stream")
+ }
+
+ events := parseSSEEvents(t, w.Body.String())
+
+ // Verify chunk events
+ chunks := filterSSEEvents(events, "chunk")
+ if len(chunks) != len(tt.wantChunks) {
+ t.Fatalf("stream(%s) got %d chunk events, want %d", tt.name, len(chunks), len(tt.wantChunks))
+ }
+ for i, wantText := range tt.wantChunks {
+ if got := chunks[i].Data["text"]; got != wantText {
+ t.Errorf("stream(%s) chunk[%d].text = %q, want %q", tt.name, i, got, wantText)
+ }
+ if got := chunks[i].Data["msgId"]; got != "m1" {
+ t.Errorf("stream(%s) chunk[%d].msgId = %q, want %q", tt.name, i, got, "m1")
+ }
+ }
+
+ // Verify done event
+ doneEvents := filterSSEEvents(events, "done")
+ if tt.wantDone != nil {
+ if len(doneEvents) != 1 {
+ t.Fatalf("stream(%s) got %d done events, want 1", tt.name, len(doneEvents))
+ }
+ for k, want := range tt.wantDone {
+ if got := doneEvents[0].Data[k]; got != want {
+ t.Errorf("stream(%s) done.%s = %q, want %q", tt.name, k, got, want)
+ }
+ }
+ if got := doneEvents[0].Data["msgId"]; got != "m1" {
+ t.Errorf("stream(%s) done.msgId = %q, want %q", tt.name, got, "m1")
+ }
+ } else if len(doneEvents) != 0 {
+ t.Errorf("stream(%s) got %d done events, want 0", tt.name, len(doneEvents))
+ }
+
+ // Verify error event
+ errorEvents := filterSSEEvents(events, "error")
+ if tt.wantError != nil {
+ if len(errorEvents) != 1 {
+ t.Fatalf("stream(%s) got %d error events, want 1", tt.name, len(errorEvents))
+ }
+ for k, want := range tt.wantError {
+ if got := errorEvents[0].Data[k]; got != want {
+ t.Errorf("stream(%s) error.%s = %q, want %q", tt.name, k, got, want)
+ }
+ }
+ } else if len(errorEvents) != 0 {
+ t.Errorf("stream(%s) got %d error events, want 0", tt.name, len(errorEvents))
+ }
+ })
+ }
+}
diff --git a/internal/api/doc.go b/internal/api/doc.go
new file mode 100644
index 0000000..b6ae844
--- /dev/null
+++ b/internal/api/doc.go
@@ -0,0 +1,79 @@
+// Package api provides the JSON REST API server for Koopa.
+//
+// # Architecture
+//
+// The API server uses Go 1.22+ routing with a layered middleware stack:
+//
+// Recovery → Logging → RateLimit → CORS → Session → CSRF → Routes
+//
+// Health probes (/health, /ready) bypass the middleware stack via a
+// top-level mux, ensuring they remain fast and unauthenticated.
+//
+// # Endpoints
+//
+// Health probes (no middleware):
+// - GET /health — returns {"status":"ok"}
+// - GET /ready — returns {"status":"ok"}
+//
+// CSRF provisioning:
+// - GET /api/v1/csrf-token — returns pre-session or session-bound token
+//
+// Session CRUD (ownership-enforced):
+// - POST /api/v1/sessions — create new session
+// - GET /api/v1/sessions — list caller's sessions
+// - GET /api/v1/sessions/{id} — get session by ID
+// - GET /api/v1/sessions/{id}/messages — get session messages
+// - DELETE /api/v1/sessions/{id} — delete session
+//
+// Chat (ownership-enforced):
+// - POST /api/v1/chat — initiate chat, returns stream URL
+// - GET /api/v1/chat/stream — SSE endpoint for streaming responses
+//
+// # CSRF Token Model
+//
+// Two token types prevent cross-site request forgery:
+//
+// - Pre-session tokens ("pre:nonce:timestamp:signature"): issued before
+// a session exists, valid for the first POST /sessions call.
+//
+// - Session-bound tokens ("timestamp:signature"): bound to a specific
+// session via HMAC-SHA256, verified with constant-time comparison.
+//
+// Both expire after 24 hours with 5 minutes of clock skew tolerance.
+//
+// # Session Ownership
+//
+// All session-accessing endpoints verify that the requested resource
+// matches the caller's session cookie. This prevents session enumeration
+// and cross-session data access.
+//
+// # Error Handling
+//
+// All responses use an envelope format:
+//
+// Success: {"data": }
+// Error: {"error": {"code": "...", "message": "..."}}
+//
+// Tool and flow errors during chat are sent as SSE events (tool_error and
+// error respectively), not HTTP error responses, since the SSE headers are
+// already committed by the time they occur.
+//
+// # SSE Streaming
+//
+// Chat responses stream via Server-Sent Events with typed events:
+//
+// - chunk: incremental text content
+// - tool_start: tool execution began
+// - tool_complete: tool execution succeeded
+// - tool_error: tool execution failed
+// - done: final response with session metadata
+// - error: flow-level error
+//
+// # Security
+//
+// The middleware stack enforces:
+// - CSRF protection for state-changing requests
+// - Per-IP rate limiting (token bucket, 60 req/min refill, burst 60)
+// - CORS with explicit origin allowlist
+// - Security headers (CSP, HSTS, X-Frame-Options, etc.)
+// - HttpOnly, Secure, SameSite=Lax session cookies
+package api
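To make the token model above concrete, here is a minimal sketch of minting and checking a session-bound token in the "timestamp:signature" form, assuming the signature is HMAC-SHA256 over "sessionID:timestamp"; the helper names and exact signing layout are illustrative and may differ from the real implementation in this package:

	package csrfsketch // illustrative; the real helpers live in internal/api

	import (
		"crypto/hmac"
		"crypto/sha256"
		"encoding/hex"
		"fmt"
		"strconv"
		"strings"
		"time"
	)

	// mintToken returns a session-bound token in "timestamp:signature" form.
	func mintToken(secret []byte, sessionID string, now time.Time) string {
		ts := strconv.FormatInt(now.Unix(), 10)
		mac := hmac.New(sha256.New, secret)
		fmt.Fprintf(mac, "%s:%s", sessionID, ts)
		return ts + ":" + hex.EncodeToString(mac.Sum(nil))
	}

	// checkToken recomputes the signature and compares it in constant time,
	// rejecting tokens older than maxAge (24 hours in the model described above).
	func checkToken(secret []byte, sessionID, token string, now time.Time, maxAge time.Duration) bool {
		ts, sig, ok := strings.Cut(token, ":")
		if !ok {
			return false
		}
		unix, err := strconv.ParseInt(ts, 10, 64)
		if err != nil || now.Sub(time.Unix(unix, 0)) > maxAge {
			return false
		}
		mac := hmac.New(sha256.New, secret)
		fmt.Fprintf(mac, "%s:%s", sessionID, ts)
		want := hex.EncodeToString(mac.Sum(nil))
		return hmac.Equal([]byte(sig), []byte(want))
	}

Pre-session tokens add a nonce and the "pre:" prefix but follow the same HMAC-and-expiry pattern.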
diff --git a/internal/api/e2e_test.go b/internal/api/e2e_test.go
new file mode 100644
index 0000000..db2473a
--- /dev/null
+++ b/internal/api/e2e_test.go
@@ -0,0 +1,470 @@
+//go:build integration
+
+package api
+
+import (
+ "log/slog"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/koopa0/koopa/internal/session"
+ "github.com/koopa0/koopa/internal/sqlc"
+ "github.com/koopa0/koopa/internal/testutil"
+)
+
+// e2eServer creates a full Server with all middleware backed by a real PostgreSQL database.
+// Returns the configured server handler; test database setup and cleanup are handled by testutil.SetupTestDB.
+func e2eServer(t *testing.T) http.Handler {
+ t.Helper()
+
+ db := testutil.SetupTestDB(t)
+
+ store := session.New(sqlc.New(db.Pool), db.Pool, slog.New(slog.DiscardHandler))
+
+ srv, err := NewServer(ServerConfig{
+ Logger: slog.New(slog.DiscardHandler),
+ SessionStore: store,
+ CSRFSecret: []byte("e2e-test-secret-at-least-32-characters!!"),
+ CORSOrigins: []string{"http://localhost:4200"},
+ IsDev: true,
+ })
+ if err != nil {
+ t.Fatalf("NewServer() error: %v", err)
+ }
+
+ return srv.Handler()
+}
+
+// e2eCookies extracts cookies from a response for use in subsequent requests.
+// When multiple Set-Cookie headers share the same name, only the last one is kept
+// (matching browser behavior where later cookies overwrite earlier ones).
+func e2eCookies(t *testing.T, w *httptest.ResponseRecorder) []*http.Cookie {
+ t.Helper()
+ all := w.Result().Cookies()
+ seen := make(map[string]int, len(all))
+ var deduped []*http.Cookie
+ for _, c := range all {
+ if idx, ok := seen[c.Name]; ok {
+ deduped[idx] = c // overwrite with later cookie
+ } else {
+ seen[c.Name] = len(deduped)
+ deduped = append(deduped, c)
+ }
+ }
+ return deduped
+}
+
+// e2eAddCookies adds cookies to a request.
+func e2eAddCookies(r *http.Request, cookies []*http.Cookie) {
+ for _, c := range cookies {
+ r.AddCookie(c)
+ }
+}
+
+func TestE2E_HealthBypassesMiddleware(t *testing.T) {
+ handler := e2eServer(t)
+
+ for _, path := range []string{"/health", "/ready"} {
+ t.Run(path, func(t *testing.T) {
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, path, nil)
+ r.RemoteAddr = "10.0.0.1:12345"
+
+ handler.ServeHTTP(w, r)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("GET %s status = %d, want %d", path, w.Code, http.StatusOK)
+ }
+
+ // Health probes are on the top-level mux and bypass the middleware stack
+ // (including security headers). Verify they return valid JSON.
+ var body map[string]string
+ decodeData(t, w, &body)
+ if body["status"] != "ok" {
+ t.Errorf("GET %s status = %q, want %q", path, body["status"], "ok")
+ }
+ })
+ }
+}
+
+func TestE2E_FullSessionLifecycle(t *testing.T) {
+ handler := e2eServer(t)
+
+ // --- Step 1: GET /api/v1/csrf-token → pre-session token ---
+ w1 := httptest.NewRecorder()
+ r1 := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil)
+ r1.RemoteAddr = "10.0.0.1:12345"
+
+ handler.ServeHTTP(w1, r1)
+
+ if w1.Code != http.StatusOK {
+ t.Fatalf("step 1: GET /csrf-token status = %d, want %d", w1.Code, http.StatusOK)
+ }
+
+ var csrfResp map[string]string
+ decodeData(t, w1, &csrfResp)
+
+ preSessionToken := csrfResp["csrfToken"]
+ if preSessionToken == "" {
+ t.Fatal("step 1: expected csrfToken in response")
+ }
+ if !strings.HasPrefix(preSessionToken, "pre:") {
+ t.Fatalf("step 1: token = %q, want pre: prefix", preSessionToken)
+ }
+
+ // --- Step 2: POST /api/v1/sessions with pre-session CSRF → 201 ---
+ w2 := httptest.NewRecorder()
+ r2 := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil)
+ r2.RemoteAddr = "10.0.0.1:12345"
+ r2.Header.Set("X-CSRF-Token", preSessionToken)
+
+ handler.ServeHTTP(w2, r2)
+
+ if w2.Code != http.StatusCreated {
+ t.Fatalf("step 2: POST /sessions status = %d, want %d\nbody: %s", w2.Code, http.StatusCreated, w2.Body.String())
+ }
+
+ var createResp map[string]string
+ decodeData(t, w2, &createResp)
+
+ sessionID := createResp["id"]
+ sessionCSRF := createResp["csrfToken"]
+
+ if sessionID == "" {
+ t.Fatal("step 2: expected id in response")
+ }
+ if sessionCSRF == "" {
+ t.Fatal("step 2: expected csrfToken in response")
+ }
+ if strings.HasPrefix(sessionCSRF, "pre:") {
+ t.Fatal("step 2: session-bound token should not have pre: prefix")
+ }
+
+ // Extract cookies (should have sid cookie)
+ cookies := e2eCookies(t, w2)
+ var sidCookie *http.Cookie
+ for _, c := range cookies {
+ if c.Name == "sid" {
+ sidCookie = c
+ }
+ }
+ if sidCookie == nil {
+ t.Fatal("step 2: expected sid cookie")
+ }
+
+ // --- Step 3: GET /api/v1/sessions/{id} with cookie → 200 ---
+ w3 := httptest.NewRecorder()
+ r3 := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sessionID, nil)
+ r3.RemoteAddr = "10.0.0.1:12345"
+ e2eAddCookies(r3, cookies)
+
+ handler.ServeHTTP(w3, r3)
+
+ if w3.Code != http.StatusOK {
+ t.Fatalf("step 3: GET /sessions/%s status = %d, want %d\nbody: %s", sessionID, w3.Code, http.StatusOK, w3.Body.String())
+ }
+
+ var getResp map[string]string
+ decodeData(t, w3, &getResp)
+
+ if getResp["id"] != sessionID {
+ t.Errorf("step 3: id = %q, want %q", getResp["id"], sessionID)
+ }
+
+ // --- Step 4: GET /api/v1/sessions/{id}/messages → 200 (empty) ---
+ w4 := httptest.NewRecorder()
+ r4 := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sessionID+"/messages", nil)
+ r4.RemoteAddr = "10.0.0.1:12345"
+ e2eAddCookies(r4, cookies)
+
+ handler.ServeHTTP(w4, r4)
+
+ if w4.Code != http.StatusOK {
+ t.Fatalf("step 4: GET /sessions/%s/messages status = %d, want %d\nbody: %s", sessionID, w4.Code, http.StatusOK, w4.Body.String())
+ }
+
+ // --- Step 5: GET /api/v1/sessions with cookie → 200 (list contains session) ---
+ w5 := httptest.NewRecorder()
+ r5 := httptest.NewRequest(http.MethodGet, "/api/v1/sessions", nil)
+ r5.RemoteAddr = "10.0.0.1:12345"
+ e2eAddCookies(r5, cookies)
+
+ handler.ServeHTTP(w5, r5)
+
+ if w5.Code != http.StatusOK {
+ t.Fatalf("step 5: GET /sessions status = %d, want %d", w5.Code, http.StatusOK)
+ }
+
+ // --- Step 6: DELETE /api/v1/sessions/{id} with cookie + CSRF → 200 ---
+ w6 := httptest.NewRecorder()
+ r6 := httptest.NewRequest(http.MethodDelete, "/api/v1/sessions/"+sessionID, nil)
+ r6.RemoteAddr = "10.0.0.1:12345"
+ r6.Header.Set("X-CSRF-Token", sessionCSRF)
+ e2eAddCookies(r6, cookies)
+
+ handler.ServeHTTP(w6, r6)
+
+ if w6.Code != http.StatusOK {
+ t.Fatalf("step 6: DELETE /sessions/%s status = %d, want %d\nbody: %s", sessionID, w6.Code, http.StatusOK, w6.Body.String())
+ }
+
+ // --- Step 7: GET deleted session → 404 ---
+ w7 := httptest.NewRecorder()
+ r7 := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sessionID, nil)
+ r7.RemoteAddr = "10.0.0.1:12345"
+ e2eAddCookies(r7, cookies)
+
+ handler.ServeHTTP(w7, r7)
+
+ if w7.Code != http.StatusNotFound {
+ t.Fatalf("step 7: GET deleted session status = %d, want %d", w7.Code, http.StatusNotFound)
+ }
+}
+
+func TestE2E_MissingCSRF_Rejected(t *testing.T) {
+ handler := e2eServer(t)
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil)
+ r.RemoteAddr = "10.0.0.1:12345"
+ // No X-CSRF-Token header
+
+ handler.ServeHTTP(w, r)
+
+ if w.Code != http.StatusForbidden {
+ t.Fatalf("POST /sessions (no CSRF) status = %d, want %d\nbody: %s", w.Code, http.StatusForbidden, w.Body.String())
+ }
+
+ errResp := decodeErrorEnvelope(t, w)
+ if errResp.Code != "csrf_invalid" && errResp.Code != "session_required" {
+ t.Errorf("POST /sessions (no CSRF) code = %q, want csrf_invalid or session_required", errResp.Code)
+ }
+}
+
+func TestE2E_InvalidCSRF_Rejected(t *testing.T) {
+ handler := e2eServer(t)
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil)
+ r.RemoteAddr = "10.0.0.1:12345"
+ r.Header.Set("X-CSRF-Token", "totally-fake-token")
+
+ handler.ServeHTTP(w, r)
+
+ if w.Code != http.StatusForbidden {
+ t.Fatalf("POST /sessions (bad CSRF) status = %d, want %d", w.Code, http.StatusForbidden)
+ }
+}
+
+func TestE2E_CrossSessionAccess_Denied(t *testing.T) {
+ handler := e2eServer(t)
+
+ // Create session A
+ w1 := httptest.NewRecorder()
+ r1 := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil)
+ r1.RemoteAddr = "10.0.0.1:12345"
+ handler.ServeHTTP(w1, r1)
+
+ var csrf1 map[string]string
+ decodeData(t, w1, &csrf1)
+
+ w2 := httptest.NewRecorder()
+ r2 := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil)
+ r2.RemoteAddr = "10.0.0.1:12345"
+ r2.Header.Set("X-CSRF-Token", csrf1["csrfToken"])
+ handler.ServeHTTP(w2, r2)
+
+ var sessA map[string]string
+ decodeData(t, w2, &sessA)
+ cookiesA := e2eCookies(t, w2)
+
+ // Create session B (different "client")
+ w3 := httptest.NewRecorder()
+ r3 := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil)
+ r3.RemoteAddr = "10.0.0.2:12345"
+ handler.ServeHTTP(w3, r3)
+
+ var csrf2 map[string]string
+ decodeData(t, w3, &csrf2)
+
+ w4 := httptest.NewRecorder()
+ r4 := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil)
+ r4.RemoteAddr = "10.0.0.2:12345"
+ r4.Header.Set("X-CSRF-Token", csrf2["csrfToken"])
+ handler.ServeHTTP(w4, r4)
+
+ var sessB map[string]string
+ decodeData(t, w4, &sessB)
+
+ // Client A tries to access session B → 403
+ w5 := httptest.NewRecorder()
+ r5 := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sessB["id"], nil)
+ r5.RemoteAddr = "10.0.0.1:12345"
+ e2eAddCookies(r5, cookiesA) // Cookie has session A
+
+ handler.ServeHTTP(w5, r5)
+
+ if w5.Code != http.StatusForbidden {
+ t.Fatalf("cross-session GET status = %d, want %d", w5.Code, http.StatusForbidden)
+ }
+
+ errResp := decodeErrorEnvelope(t, w5)
+ if errResp.Code != "forbidden" {
+ t.Errorf("cross-session GET code = %q, want %q", errResp.Code, "forbidden")
+ }
+}
+
+func TestE2E_RateLimiting(t *testing.T) {
+ handler := e2eServer(t)
+
+ // Rate limiter is configured at 1 token/sec, burst 60.
+ // /health bypasses the middleware stack, so use an API path that goes through rate limiting.
+ var lastCode int
+ for i := range 65 {
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil)
+ r.RemoteAddr = "10.99.99.99:12345" // Unique IP so no interference
+
+ handler.ServeHTTP(w, r)
+ lastCode = w.Code
+
+ if w.Code == http.StatusTooManyRequests {
+ // Verify Retry-After header
+ if got := w.Header().Get("Retry-After"); got != "1" {
+ t.Errorf("rate limited Retry-After = %q, want %q", got, "1")
+ }
+ t.Logf("rate limited at request %d", i+1)
+ return
+ }
+ }
+
+ t.Fatalf("rate limiter: no 429 within 65 requests, last status = %d", lastCode)
+}
+
+func TestE2E_SecurityHeaders(t *testing.T) {
+ handler := e2eServer(t)
+
+ // Only check API paths — /health and /ready are on the top-level mux
+ // and intentionally bypass the middleware stack (including security headers).
+ paths := []struct {
+ method string
+ path string
+ }{
+ {http.MethodGet, "/api/v1/csrf-token"},
+ {http.MethodGet, "/api/v1/sessions"},
+ }
+
+ for _, p := range paths {
+ t.Run(p.method+" "+p.path, func(t *testing.T) {
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(p.method, p.path, nil)
+ r.RemoteAddr = "10.0.0.1:12345"
+
+ handler.ServeHTTP(w, r)
+
+ wantHeaders := map[string]string{
+ "X-Content-Type-Options": "nosniff",
+ "X-Frame-Options": "DENY",
+ "Referrer-Policy": "strict-origin-when-cross-origin",
+ "Content-Security-Policy": "default-src 'none'",
+ }
+
+ for header, want := range wantHeaders {
+ if got := w.Header().Get(header); got != want {
+ t.Errorf("%s %s header %q = %q, want %q", p.method, p.path, header, got, want)
+ }
+ }
+ })
+ }
+}
+
+func TestE2E_CORSPreflight(t *testing.T) {
+ handler := e2eServer(t)
+
+ t.Run("allowed origin", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodOptions, "/api/v1/sessions", nil)
+ r.RemoteAddr = "10.0.0.1:12345"
+ r.Header.Set("Origin", "http://localhost:4200")
+
+ handler.ServeHTTP(w, r)
+
+ if w.Code != http.StatusNoContent {
+ t.Fatalf("CORS preflight status = %d, want %d", w.Code, http.StatusNoContent)
+ }
+ if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://localhost:4200" {
+ t.Errorf("Allow-Origin = %q, want %q", got, "http://localhost:4200")
+ }
+ if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
+ t.Errorf("Allow-Credentials = %q, want %q", got, "true")
+ }
+ })
+
+ t.Run("disallowed origin", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodOptions, "/api/v1/sessions", nil)
+ r.RemoteAddr = "10.0.0.1:12345"
+ r.Header.Set("Origin", "http://evil.com")
+
+ handler.ServeHTTP(w, r)
+
+ if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" {
+ t.Errorf("disallowed origin Allow-Origin = %q, want empty", got)
+ }
+ })
+}
+
+func TestE2E_SSEStream(t *testing.T) {
+ handler := e2eServer(t)
+
+ // --- Create a session first (ownership check requires valid cookie) ---
+ w1 := httptest.NewRecorder()
+ r1 := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil)
+ r1.RemoteAddr = "10.0.0.1:12345"
+ handler.ServeHTTP(w1, r1)
+
+ var csrfResp map[string]string
+ decodeData(t, w1, &csrfResp)
+
+ w2 := httptest.NewRecorder()
+ r2 := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil)
+ r2.RemoteAddr = "10.0.0.1:12345"
+ r2.Header.Set("X-CSRF-Token", csrfResp["csrfToken"])
+ handler.ServeHTTP(w2, r2)
+
+ if w2.Code != http.StatusCreated {
+ t.Fatalf("create session status = %d, want %d\nbody: %s", w2.Code, http.StatusCreated, w2.Body.String())
+ }
+
+ var sessResp map[string]string
+ decodeData(t, w2, &sessResp)
+ sessionID := sessResp["id"]
+ cookies := e2eCookies(t, w2)
+
+ // --- SSE stream with valid session cookie ---
+ // e2eServer has no ChatFlow configured (nil), so the handler returns
+ // an error event instead of chunk/done events.
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=e2e-1&session_id="+sessionID+"&query=hello", nil)
+ r.RemoteAddr = "10.0.0.1:12345"
+ e2eAddCookies(r, cookies)
+
+ handler.ServeHTTP(w, r)
+
+ // Should get SSE response (200 implicit from streaming)
+ if ct := w.Header().Get("Content-Type"); ct != "text/event-stream" {
+ t.Fatalf("SSE Content-Type = %q, want %q", ct, "text/event-stream")
+ }
+
+ body := w.Body.String()
+
+ // With nil ChatFlow, handler sends error event
+ if !strings.Contains(body, "event: error\n") {
+ t.Errorf("SSE response missing error event, body:\n%s", body)
+ }
+ if !strings.Contains(body, "chat flow not initialized") {
+ t.Errorf("SSE error missing expected message, body:\n%s", body)
+ }
+}
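The test above inspects the recorded body after the fact; a real client would read the stream incrementally. A rough client-side sketch for the "event: <type>\ndata: <json>\n\n" blocks written by sseEvent follows — the readEvents helper is illustrative, not part of this API:

	package ssesketch // illustrative client-side reader

	import (
		"bufio"
		"encoding/json"
		"io"
		"strings"
	)

	// event is one parsed SSE event (chunk, tool_start, tool_complete, tool_error, done, error).
	type event struct {
		Type string
		Data map[string]string
	}

	// readEvents consumes "event: <type>\ndata: <json>" blocks, invoking handle
	// for each complete event, until the stream ends.
	func readEvents(r io.Reader, handle func(event)) error {
		scanner := bufio.NewScanner(r)
		var cur event
		for scanner.Scan() {
			line := scanner.Text()
			switch {
			case strings.HasPrefix(line, "event: "):
				cur.Type = strings.TrimPrefix(line, "event: ")
			case strings.HasPrefix(line, "data: "):
				cur.Data = map[string]string{}
				if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &cur.Data); err != nil {
					return err
				}
			case line == "": // blank line terminates an event
				if cur.Type != "" {
					handle(cur)
				}
				cur = event{}
			}
		}
		return scanner.Err()
	}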
diff --git a/internal/api/health.go b/internal/api/health.go
index dd46070..b9a093f 100644
--- a/internal/api/health.go
+++ b/internal/api/health.go
@@ -5,5 +5,5 @@ import "net/http"
// health is a simple health check endpoint for Docker/Kubernetes probes.
// Returns 200 OK with {"status":"ok"}.
func health(w http.ResponseWriter, _ *http.Request) {
- WriteJSON(w, http.StatusOK, map[string]string{"status": "ok"})
+ WriteJSON(w, http.StatusOK, map[string]string{"status": "ok"}, nil)
}
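The only change to the health handler is the extra logger argument: WriteJSON (and WriteError throughout the handlers) now takes a *slog.Logger as its final parameter, with nil allowed where no logging is wanted. The helpers themselves are not shown in this diff; a hedged sketch consistent with the call sites and the envelope format described in doc.go might look like:

	package respsketch // illustrative; the real WriteJSON/WriteError live elsewhere in internal/api

	import (
		"encoding/json"
		"log/slog"
		"net/http"
	)

	// WriteJSON writes {"data": v} with the given status; a nil logger skips error logging.
	func WriteJSON(w http.ResponseWriter, status int, v any, logger *slog.Logger) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		if err := json.NewEncoder(w).Encode(map[string]any{"data": v}); err != nil && logger != nil {
			logger.Error("encoding response", "error", err)
		}
	}

	// WriteError writes {"error": {"code": code, "message": message}}.
	func WriteError(w http.ResponseWriter, status int, code, message string, logger *slog.Logger) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		payload := map[string]any{"error": map[string]string{"code": code, "message": message}}
		if err := json.NewEncoder(w).Encode(payload); err != nil && logger != nil {
			logger.Error("encoding error response", "error", err)
		}
	}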
diff --git a/internal/api/integration_test.go b/internal/api/integration_test.go
new file mode 100644
index 0000000..94dedc1
--- /dev/null
+++ b/internal/api/integration_test.go
@@ -0,0 +1,382 @@
+//go:build integration
+
+package api
+
+import (
+ "context"
+ "encoding/json"
+ "log/slog"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/firebase/genkit/go/ai"
+ "github.com/google/uuid"
+
+ "github.com/koopa0/koopa/internal/session"
+ "github.com/koopa0/koopa/internal/sqlc"
+ "github.com/koopa0/koopa/internal/testutil"
+)
+
+// setupIntegrationSessionManager creates a sessionManager backed by a real PostgreSQL database.
+func setupIntegrationSessionManager(t *testing.T) *sessionManager {
+ t.Helper()
+
+ db := testutil.SetupTestDB(t)
+
+ store := session.New(sqlc.New(db.Pool), db.Pool, slog.New(slog.DiscardHandler))
+
+ return &sessionManager{
+ store: store,
+ hmacSecret: []byte("test-secret-at-least-32-characters!!"),
+ isDev: true,
+ logger: slog.New(slog.DiscardHandler),
+ }
+}
+
+func TestCreateSession_Success(t *testing.T) {
+ sm := setupIntegrationSessionManager(t)
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil)
+
+ sm.createSession(w, r)
+
+ if w.Code != http.StatusCreated {
+ t.Fatalf("createSession() status = %d, want %d\nbody: %s", w.Code, http.StatusCreated, w.Body.String())
+ }
+
+ var resp map[string]string
+ decodeData(t, w, &resp)
+
+ // Should return a valid UUID
+ if _, err := uuid.Parse(resp["id"]); err != nil {
+ t.Errorf("createSession() id = %q, want valid UUID", resp["id"])
+ }
+
+ // Should return a CSRF token bound to the new session
+ if resp["csrfToken"] == "" {
+ t.Error("createSession() expected csrfToken in response")
+ }
+
+ // Should set a session cookie
+ cookies := w.Result().Cookies()
+ var found bool
+ for _, c := range cookies {
+ if c.Name == "sid" {
+ found = true
+ if c.Value != resp["id"] {
+ t.Errorf("createSession() cookie sid = %q, want %q", c.Value, resp["id"])
+ }
+ if !c.HttpOnly {
+ t.Error("createSession() cookie should be HttpOnly")
+ }
+ }
+ }
+ if !found {
+ t.Error("createSession() expected sid cookie to be set")
+ }
+}
+
+func TestGetSession_Success(t *testing.T) {
+ sm := setupIntegrationSessionManager(t)
+ ctx := context.Background()
+
+ // Create a session first
+ sess, err := sm.store.CreateSession(ctx, "Test Session")
+ if err != nil {
+ t.Fatalf("setup: CreateSession() error: %v", err)
+ }
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sess.ID.String(), nil)
+ r.SetPathValue("id", sess.ID.String())
+
+ // Inject session ownership (same session ID in context)
+ rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID)
+ r = r.WithContext(rctx)
+
+ sm.getSession(w, r)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("getSession(%s) status = %d, want %d\nbody: %s", sess.ID, w.Code, http.StatusOK, w.Body.String())
+ }
+
+ var resp map[string]string
+ decodeData(t, w, &resp)
+
+ if resp["id"] != sess.ID.String() {
+ t.Errorf("getSession() id = %q, want %q", resp["id"], sess.ID.String())
+ }
+ if resp["title"] != "Test Session" {
+ t.Errorf("getSession() title = %q, want %q", resp["title"], "Test Session")
+ }
+ if resp["createdAt"] == "" {
+ t.Error("getSession() expected createdAt in response")
+ }
+ if resp["updatedAt"] == "" {
+ t.Error("getSession() expected updatedAt in response")
+ }
+}
+
+func TestGetSession_NotFound(t *testing.T) {
+ sm := setupIntegrationSessionManager(t)
+
+ missingID := uuid.New()
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+missingID.String(), nil)
+ r.SetPathValue("id", missingID.String())
+
+ // Set ownership to match (bypasses ownership check, tests store-level 404)
+ rctx := context.WithValue(r.Context(), ctxKeySessionID, missingID)
+ r = r.WithContext(rctx)
+
+ sm.getSession(w, r)
+
+ if w.Code != http.StatusNotFound {
+ t.Fatalf("getSession(missing) status = %d, want %d", w.Code, http.StatusNotFound)
+ }
+
+ errResp := decodeErrorEnvelope(t, w)
+ if errResp.Code != "not_found" {
+ t.Errorf("getSession(missing) code = %q, want %q", errResp.Code, "not_found")
+ }
+}
+
+func TestListSessions_WithSession(t *testing.T) {
+ sm := setupIntegrationSessionManager(t)
+ ctx := context.Background()
+
+ // Create a session
+ sess, err := sm.store.CreateSession(ctx, "My Chat")
+ if err != nil {
+ t.Fatalf("setup: CreateSession() error: %v", err)
+ }
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions", nil)
+
+ // Inject session ownership
+ rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID)
+ r = r.WithContext(rctx)
+
+ sm.listSessions(w, r)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("listSessions() status = %d, want %d\nbody: %s", w.Code, http.StatusOK, w.Body.String())
+ }
+
+ type sessionItem struct {
+ ID string `json:"id"`
+ Title string `json:"title"`
+ UpdatedAt string `json:"updatedAt"`
+ }
+ var items []sessionItem
+ decodeData(t, w, &items)
+
+ if len(items) != 1 {
+ t.Fatalf("listSessions() returned %d items, want 1", len(items))
+ }
+ if items[0].ID != sess.ID.String() {
+ t.Errorf("listSessions() items[0].id = %q, want %q", items[0].ID, sess.ID.String())
+ }
+ if items[0].Title != "My Chat" {
+ t.Errorf("listSessions() items[0].title = %q, want %q", items[0].Title, "My Chat")
+ }
+ if items[0].UpdatedAt == "" {
+ t.Error("listSessions() expected updatedAt in item")
+ }
+}
+
+func TestGetSessionMessages_Empty(t *testing.T) {
+ sm := setupIntegrationSessionManager(t)
+ ctx := context.Background()
+
+ // Create a session with no messages
+ sess, err := sm.store.CreateSession(ctx, "")
+ if err != nil {
+ t.Fatalf("setup: CreateSession() error: %v", err)
+ }
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sess.ID.String()+"/messages", nil)
+ r.SetPathValue("id", sess.ID.String())
+
+ rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID)
+ r = r.WithContext(rctx)
+
+ sm.getSessionMessages(w, r)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("getSessionMessages(empty) status = %d, want %d\nbody: %s", w.Code, http.StatusOK, w.Body.String())
+ }
+
+ // Decode the raw envelope to check the data type
+ var env struct {
+ Data json.RawMessage `json:"data"`
+ }
+ if err := json.NewDecoder(w.Body).Decode(&env); err != nil {
+ t.Fatalf("decoding envelope: %v", err)
+ }
+
+ var items []json.RawMessage
+ if err := json.Unmarshal(env.Data, &items); err != nil {
+ t.Fatalf("decoding data as array: %v", err)
+ }
+
+ if len(items) != 0 {
+ t.Errorf("getSessionMessages(empty) returned %d items, want 0", len(items))
+ }
+}
+
+func TestDeleteSession_Success(t *testing.T) {
+ sm := setupIntegrationSessionManager(t)
+ ctx := context.Background()
+
+ // Create a session
+ sess, err := sm.store.CreateSession(ctx, "To Delete")
+ if err != nil {
+ t.Fatalf("setup: CreateSession() error: %v", err)
+ }
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodDelete, "/api/v1/sessions/"+sess.ID.String(), nil)
+ r.SetPathValue("id", sess.ID.String())
+
+ // Inject ownership
+ rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID)
+ r = r.WithContext(rctx)
+
+ sm.deleteSession(w, r)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("deleteSession() status = %d, want %d\nbody: %s", w.Code, http.StatusOK, w.Body.String())
+ }
+
+ var resp map[string]string
+ decodeData(t, w, &resp)
+
+ if resp["status"] != "deleted" {
+ t.Errorf("deleteSession() status = %q, want %q", resp["status"], "deleted")
+ }
+
+ // Verify session is actually gone
+ _, err = sm.store.Session(ctx, sess.ID)
+ if err == nil {
+ t.Error("deleteSession() session still exists after deletion")
+ }
+}
+
+func TestCSRFTokenEndpoint_WithSession(t *testing.T) {
+ sm := setupIntegrationSessionManager(t)
+ ctx := context.Background()
+
+ // Create a real session
+ sess, err := sm.store.CreateSession(ctx, "")
+ if err != nil {
+ t.Fatalf("setup: CreateSession() error: %v", err)
+ }
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil)
+ r.AddCookie(&http.Cookie{
+ Name: "sid",
+ Value: sess.ID.String(),
+ })
+
+ sm.csrfToken(w, r)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("csrfToken(with session) status = %d, want %d", w.Code, http.StatusOK)
+ }
+
+ var body map[string]string
+ decodeData(t, w, &body)
+
+ token := body["csrfToken"]
+ if token == "" {
+ t.Fatal("csrfToken(with session) expected csrfToken in response")
+ }
+
+ // Session-bound tokens should NOT have pre: prefix
+ if isPreSessionToken(token) {
+ t.Error("csrfToken(with session) should return session-bound token, not pre-session")
+ }
+
+ // Token should validate against the session
+ if err := sm.CheckCSRF(sess.ID, token); err != nil {
+ t.Fatalf("csrfToken(with session) returned invalid token: %v", err)
+ }
+}
+
+func TestGetSessionMessages_WithMessages(t *testing.T) {
+ sm := setupIntegrationSessionManager(t)
+ ctx := context.Background()
+
+ // Create a session with messages
+ sess, err := sm.store.CreateSession(ctx, "Test Chat")
+ if err != nil {
+ t.Fatalf("setup: CreateSession() error: %v", err)
+ }
+
+ // Add user and model messages
+ msgs := []*ai.Message{
+ ai.NewUserMessage(ai.NewTextPart("What is Go?")),
+ ai.NewModelMessage(ai.NewTextPart("Go is a programming language.")),
+ }
+ if err := sm.store.AppendMessages(ctx, sess.ID, msgs); err != nil {
+ t.Fatalf("setup: AppendMessages() error: %v", err)
+ }
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sess.ID.String()+"/messages", nil)
+ r.SetPathValue("id", sess.ID.String())
+
+ rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID)
+ r = r.WithContext(rctx)
+
+ sm.getSessionMessages(w, r)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("getSessionMessages() status = %d, want %d\nbody: %s", w.Code, http.StatusOK, w.Body.String())
+ }
+
+ type messageItem struct {
+ ID string `json:"id"`
+ Role string `json:"role"`
+ Content string `json:"content"`
+ CreatedAt string `json:"createdAt"`
+ }
+ var items []messageItem
+ decodeData(t, w, &items)
+
+ if len(items) != 2 {
+ t.Fatalf("getSessionMessages() returned %d items, want 2", len(items))
+ }
+
+ // First message: user
+ if items[0].Role != "user" {
+ t.Errorf("getSessionMessages() items[0].role = %q, want %q", items[0].Role, "user")
+ }
+ if items[0].Content != "What is Go?" {
+ t.Errorf("getSessionMessages() items[0].content = %q, want %q", items[0].Content, "What is Go?")
+ }
+ if items[0].ID == "" {
+ t.Error("getSessionMessages() items[0].id is empty")
+ }
+ if items[0].CreatedAt == "" {
+ t.Error("getSessionMessages() items[0].createdAt is empty")
+ }
+
+ // Second message: model (normalizeRole converts "model" → "assistant" in DB)
+ if items[1].Role != "assistant" {
+ t.Errorf("getSessionMessages() items[1].role = %q, want %q", items[1].Role, "assistant")
+ }
+ if items[1].Content != "Go is a programming language." {
+ t.Errorf("getSessionMessages() items[1].content = %q, want %q", items[1].Content, "Go is a programming language.")
+ }
+ if items[1].ID == "" {
+ t.Error("getSessionMessages() items[1].id is empty")
+ }
+}
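These integration tests are gated behind the `integration` build tag (the `//go:build integration` line above), so a plain `go test ./internal/api/` skips them; something like `go test -tags integration ./internal/api/` should run them, presumably against the PostgreSQL instance that `testutil.SetupTestDB` provisions.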
diff --git a/internal/api/middleware.go b/internal/api/middleware.go
index a9a02d0..821eaef 100644
--- a/internal/api/middleware.go
+++ b/internal/api/middleware.go
@@ -3,10 +3,8 @@ package api
import (
"context"
"log/slog"
- "net"
"net/http"
"strings"
- "sync"
"time"
"github.com/google/uuid"
@@ -17,9 +15,9 @@ type sessionIDKey struct{}
var ctxKeySessionID = sessionIDKey{}
-// SessionIDFromContext retrieves the session ID from the request context.
+// sessionIDFromContext retrieves the session ID from the request context.
// Returns uuid.Nil and false if not found.
-func SessionIDFromContext(ctx context.Context) (uuid.UUID, bool) {
+func sessionIDFromContext(ctx context.Context) (uuid.UUID, bool) {
sessionID, ok := ctx.Value(ctxKeySessionID).(uuid.UUID)
return sessionID, ok
}
@@ -27,46 +25,47 @@ func SessionIDFromContext(ctx context.Context) (uuid.UUID, bool) {
// loggingWriter wraps http.ResponseWriter to capture metrics.
// Implements Flusher for SSE streaming and Unwrap for ResponseController.
type loggingWriter struct {
- http.ResponseWriter
+ w http.ResponseWriter
statusCode int
bytesWritten int64
}
-func (w *loggingWriter) WriteHeader(code int) {
- w.statusCode = code
- w.ResponseWriter.WriteHeader(code)
+func (lw *loggingWriter) Header() http.Header {
+ return lw.w.Header()
+}
+
+func (lw *loggingWriter) WriteHeader(code int) {
+ lw.statusCode = code
+ lw.w.WriteHeader(code)
}
//nolint:wrapcheck // http.ResponseWriter wrapper must return unwrapped errors
-func (w *loggingWriter) Write(b []byte) (int, error) {
- if w.statusCode == 0 {
- w.statusCode = http.StatusOK
+func (lw *loggingWriter) Write(b []byte) (int, error) {
+ if lw.statusCode == 0 {
+ lw.statusCode = http.StatusOK
}
- n, err := w.ResponseWriter.Write(b)
- w.bytesWritten += int64(n)
+ n, err := lw.w.Write(b)
+ lw.bytesWritten += int64(n)
return n, err
}
// Flush implements http.Flusher for SSE streaming support.
-func (w *loggingWriter) Flush() {
- if f, ok := w.ResponseWriter.(http.Flusher); ok {
+func (lw *loggingWriter) Flush() {
+ if f, ok := lw.w.(http.Flusher); ok {
f.Flush()
}
}
// Unwrap returns the underlying ResponseWriter for http.ResponseController.
-func (w *loggingWriter) Unwrap() http.ResponseWriter {
- return w.ResponseWriter
+func (lw *loggingWriter) Unwrap() http.ResponseWriter {
+ return lw.w
}
// recoveryMiddleware recovers from panics to prevent server crashes.
func recoveryMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- wrapper := &loggingWriter{
- ResponseWriter: w,
- statusCode: 0,
- }
+ wrapper := &loggingWriter{w: w}
defer func() {
if err := recover(); err != nil {
@@ -77,7 +76,7 @@ func recoveryMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
)
if wrapper.statusCode == 0 {
- WriteError(w, http.StatusInternalServerError, "internal_error", "internal server error")
+ WriteError(w, http.StatusInternalServerError, "internal_error", "internal server error", logger)
} else {
logger.Warn("cannot send error response, headers already sent",
"path", r.URL.Path,
@@ -92,22 +91,29 @@ func recoveryMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
}
// loggingMiddleware logs request details including latency, status, and response size.
+// Reuses an existing *loggingWriter from outer middleware (e.g., recoveryMiddleware)
+// to avoid double-wrapping the ResponseWriter.
func loggingMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
- wrapper := &loggingWriter{
- ResponseWriter: w,
- statusCode: http.StatusOK,
+ wrapper, ok := w.(*loggingWriter)
+ if !ok {
+ wrapper = &loggingWriter{w: w}
}
next.ServeHTTP(wrapper, r)
+ status := wrapper.statusCode
+ if status == 0 {
+ status = http.StatusOK
+ }
+
logger.Debug("http request",
"method", r.Method,
"path", r.URL.Path,
- "status", wrapper.statusCode,
+ "status", status,
"bytes", wrapper.bytesWritten,
"duration", time.Since(start),
"ip", r.RemoteAddr,
@@ -146,39 +152,21 @@ func corsMiddleware(allowedOrigins []string) func(http.Handler) http.Handler {
}
}
-// sessionMiddleware ensures a valid session exists before processing the request.
-// GET/HEAD/OPTIONS: read-only, don't create session.
-// POST/PUT/DELETE: create session if needed.
-func sessionMiddleware(sm *sessionManager, logger *slog.Logger) func(http.Handler) http.Handler {
+// sessionMiddleware extracts the session ID from the cookie and adds it to the
+// request context. If no valid session cookie is present, the request continues
+// without a session ID in context. Individual handlers are responsible for
+// creating sessions when needed (e.g., createSession).
+func sessionMiddleware(sm *sessionManager) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
- sessionID, err := sm.ID(r)
- if err == nil {
- ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID)
- next.ServeHTTP(w, r.WithContext(ctx))
- return
- }
- // No session — pre-session state, continue without session ID
- next.ServeHTTP(w, r)
- return
- }
-
- // State-changing request: create session if needed
- sessionID, err := sm.GetOrCreate(w, r)
- if err != nil {
- logger.Error("session creation failed",
- "error", err,
- "path", r.URL.Path,
- "method", r.Method,
- "remote_addr", r.RemoteAddr,
- )
- WriteError(w, http.StatusInternalServerError, "session_error", "session creation failed")
+ sessionID, err := sm.ID(r)
+ if err == nil {
+ ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID)
+ next.ServeHTTP(w, r.WithContext(ctx))
return
}
-
- ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID)
- next.ServeHTTP(w, r.WithContext(ctx))
+ // No valid session cookie — continue without session in context
+ next.ServeHTTP(w, r)
})
}
}
@@ -205,7 +193,7 @@ func csrfMiddleware(sm *sessionManager, logger *slog.Logger) func(http.Handler)
"path", r.URL.Path,
"method", r.Method,
)
- WriteError(w, http.StatusForbidden, "csrf_invalid", "CSRF validation failed")
+ WriteError(w, http.StatusForbidden, "csrf_invalid", "CSRF validation failed", logger)
return
}
next.ServeHTTP(w, r)
@@ -213,24 +201,24 @@ func csrfMiddleware(sm *sessionManager, logger *slog.Logger) func(http.Handler)
}
// Session-bound token
- sessionID, ok := SessionIDFromContext(r.Context())
+ sessionID, ok := sessionIDFromContext(r.Context())
if !ok {
- logger.Error("CSRF validation failed: session ID not in context",
+ logger.Error("validating CSRF: session ID not in context",
"path", r.URL.Path,
"method", r.Method,
)
- WriteError(w, http.StatusForbidden, "session_required", "session required")
+ WriteError(w, http.StatusForbidden, "session_required", "session required", logger)
return
}
if err := sm.CheckCSRF(sessionID, csrfToken); err != nil {
- logger.Warn("CSRF validation failed",
+ logger.Warn("validating CSRF",
"error", err,
"session", sessionID,
"path", r.URL.Path,
"method", r.Method,
)
- WriteError(w, http.StatusForbidden, "csrf_invalid", "CSRF validation failed")
+ WriteError(w, http.StatusForbidden, "csrf_invalid", "CSRF validation failed", logger)
return
}
@@ -255,140 +243,3 @@ func setSecurityHeaders(w http.ResponseWriter, isDev bool) {
func isPreSessionToken(token string) bool {
return strings.HasPrefix(token, preSessionPrefix)
}
-
-// ============================================================================
-// Rate Limiting
-// ============================================================================
-
-const (
- rateLimiterCleanupInterval = 5 * time.Minute
- rateLimiterStaleThreshold = 10 * time.Minute
-)
-
-// rateLimiter implements per-IP token bucket rate limiting.
-// Cleanup of stale entries happens inline during allow() calls.
-type rateLimiter struct {
- mu sync.Mutex
- visitors map[string]*visitor
- rate float64 // tokens per second
- burst int // max tokens (also initial tokens)
- lastCleanup time.Time
-}
-
-// visitor tracks the token bucket state for a single IP.
-type visitor struct {
- tokens float64
- lastSeen time.Time
-}
-
-// newRateLimiter creates a rate limiter.
-// rate: tokens refilled per second. burst: maximum tokens (and initial allowance).
-func newRateLimiter(rate float64, burst int) *rateLimiter {
- return &rateLimiter{
- visitors: make(map[string]*visitor),
- rate: rate,
- burst: burst,
- lastCleanup: time.Now(),
- }
-}
-
-// allow checks if a request from the given IP is allowed.
-// Returns false if the IP has exhausted its tokens.
-func (rl *rateLimiter) allow(ip string) bool {
- rl.mu.Lock()
- defer rl.mu.Unlock()
-
- now := time.Now()
-
- // Periodic cleanup of stale entries
- if now.Sub(rl.lastCleanup) > rateLimiterCleanupInterval {
- for k, v := range rl.visitors {
- if now.Sub(v.lastSeen) > rateLimiterStaleThreshold {
- delete(rl.visitors, k)
- }
- }
- rl.lastCleanup = now
- }
-
- v, exists := rl.visitors[ip]
- if !exists {
- rl.visitors[ip] = &visitor{
- tokens: float64(rl.burst) - 1,
- lastSeen: now,
- }
- return true
- }
-
- // Refill tokens based on elapsed time
- elapsed := now.Sub(v.lastSeen).Seconds()
- v.tokens += elapsed * rl.rate
- if v.tokens > float64(rl.burst) {
- v.tokens = float64(rl.burst)
- }
- v.lastSeen = now
-
- if v.tokens < 1 {
- return false
- }
-
- v.tokens--
- return true
-}
-
-// rateLimitMiddleware returns middleware that limits requests per IP.
-// Uses token bucket algorithm: each IP gets `burst` initial tokens,
-// refilling at `rate` tokens per second.
-func rateLimitMiddleware(rl *rateLimiter, trustProxy bool, logger *slog.Logger) func(http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- ip := clientIP(r, trustProxy)
- if !rl.allow(ip) {
- logger.Warn("rate limit exceeded",
- "ip", ip,
- "path", r.URL.Path,
- "method", r.Method,
- )
- w.Header().Set("Retry-After", "1")
- WriteError(w, http.StatusTooManyRequests, "rate_limited", "too many requests")
- return
- }
- next.ServeHTTP(w, r)
- })
- }
-}
-
-// clientIP extracts the client IP from the request.
-//
-// When trustProxy is true, checks X-Real-IP first (set by nginx/HAProxy),
-// then X-Forwarded-For (first IP). Header values are validated with net.ParseIP
-// to prevent injection of non-IP strings into rate limiter keys.
-//
-// When trustProxy is false, only uses RemoteAddr (safe default for direct exposure).
-func clientIP(r *http.Request, trustProxy bool) string {
- if trustProxy {
- // Prefer X-Real-IP (single value, set by reverse proxy)
- if xri := r.Header.Get("X-Real-IP"); xri != "" {
- if ip := net.ParseIP(strings.TrimSpace(xri)); ip != nil {
- return ip.String()
- }
- }
-
- // Fall back to X-Forwarded-For (first IP is the client)
- if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
- raw := xff
- if first, _, ok := strings.Cut(xff, ","); ok {
- raw = first
- }
- if ip := net.ParseIP(strings.TrimSpace(raw)); ip != nil {
- return ip.String()
- }
- }
- }
-
- // Fall back to RemoteAddr (strip port)
- ip, _, err := net.SplitHostPort(r.RemoteAddr)
- if err != nil {
- return r.RemoteAddr
- }
- return ip
-}
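For context, here is a minimal sketch of the new sessionMiddleware contract from a handler's point of view inside package api. The handler name and its payloads are hypothetical; only sessionIDFromContext, WriteJSON, and the cookie behavior come from the diff above.

```go
package api

import "net/http"

// exampleReadOnlyHandler is hypothetical; it shows that handlers now decide
// what to do when sessionMiddleware found no valid "sid" cookie, instead of
// the middleware creating a session for state-changing requests.
func exampleReadOnlyHandler(w http.ResponseWriter, r *http.Request) {
	sessionID, ok := sessionIDFromContext(r.Context())
	if !ok {
		// No (or invalid) session cookie: respond with an empty payload
		// rather than implicitly creating a session.
		WriteJSON(w, http.StatusOK, []string{}, nil)
		return
	}
	WriteJSON(w, http.StatusOK, map[string]string{"session": sessionID.String()}, nil)
}
```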
diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go
index c108e77..c27e60f 100644
--- a/internal/api/middleware_test.go
+++ b/internal/api/middleware_test.go
@@ -6,7 +6,6 @@ import (
"net/http"
"net/http/httptest"
"testing"
- "time"
"github.com/google/uuid"
)
@@ -44,7 +43,7 @@ func TestRecoveryMiddleware_NoPanic(t *testing.T) {
logger := discardLogger()
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
- WriteJSON(w, http.StatusOK, map[string]string{"ok": "true"})
+ WriteJSON(w, http.StatusOK, map[string]string{"ok": "true"}, nil)
})
handler := recoveryMiddleware(logger)(okHandler)
@@ -266,7 +265,7 @@ func TestSecurityHeaders(t *testing.T) {
w := httptest.NewRecorder()
setSecurityHeaders(w, false)
- expected := map[string]string{
+ wantHeaders := map[string]string{
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"Referrer-Policy": "strict-origin-when-cross-origin",
@@ -274,7 +273,7 @@ func TestSecurityHeaders(t *testing.T) {
"Strict-Transport-Security": "max-age=63072000; includeSubDomains",
}
- for header, want := range expected {
+ for header, want := range wantHeaders {
if got := w.Header().Get(header); got != want {
t.Errorf("setSecurityHeaders(isDev=false) %q = %q, want %q", header, got, want)
}
@@ -296,208 +295,133 @@ func TestSecurityHeaders(t *testing.T) {
})
}
-// ============================================================================
-// Rate Limiting Tests
-// ============================================================================
+func Test_sessionIDFromContext(t *testing.T) {
+ t.Run("present", func(t *testing.T) {
+ id := uuid.New()
+ ctx := context.WithValue(context.Background(), ctxKeySessionID, id)
-func TestRateLimiter_AllowsWithinBurst(t *testing.T) {
- rl := newRateLimiter(1.0, 5)
+ got, ok := sessionIDFromContext(ctx)
+ if !ok {
+ t.Fatal("sessionIDFromContext() ok = false, want true")
+ }
+ if got != id {
+ t.Errorf("sessionIDFromContext() = %s, want %s", got, id)
+ }
+ })
- for i := range 5 {
- if !rl.allow("1.2.3.4") {
- t.Fatalf("allow() returned false on request %d (within burst of 5)", i+1)
+ t.Run("absent", func(t *testing.T) {
+ _, ok := sessionIDFromContext(context.Background())
+ if ok {
+ t.Error("sessionIDFromContext(empty) ok = true, want false")
}
- }
+ })
}
-func TestRateLimiter_BlocksAfterBurst(t *testing.T) {
- rl := newRateLimiter(1.0, 3)
-
- // Exhaust the burst
- for range 3 {
- rl.allow("1.2.3.4")
- }
-
- if rl.allow("1.2.3.4") {
- t.Error("allow() should return false after burst exhausted")
+func TestSessionMiddleware_GET_WithCookie(t *testing.T) {
+ logger := discardLogger()
+ sm := &sessionManager{
+ hmacSecret: []byte("test-secret-at-least-32-characters!!"),
+ logger: logger,
}
-}
-
-func TestRateLimiter_SeparateIPs(t *testing.T) {
- rl := newRateLimiter(1.0, 2)
- // Exhaust IP 1
- rl.allow("1.1.1.1")
- rl.allow("1.1.1.1")
+ sessionID := uuid.New()
- // IP 2 should still be allowed
- if !rl.allow("2.2.2.2") {
- t.Error("allow() should allow a different IP")
- }
-}
+ var gotID uuid.UUID
+ var gotOK bool
+ handler := sessionMiddleware(sm)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
+ gotID, gotOK = sessionIDFromContext(r.Context())
+ }))
-func TestRateLimiter_RefillsOverTime(t *testing.T) {
- rl := newRateLimiter(100.0, 1) // 100 tokens/sec so we can test quickly
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions", nil)
+ r.AddCookie(&http.Cookie{Name: "sid", Value: sessionID.String()})
- // Use the single token
- rl.allow("1.2.3.4")
+ handler.ServeHTTP(w, r)
- if rl.allow("1.2.3.4") {
- t.Error("allow() should be blocked immediately after burst exhausted")
+ if !gotOK {
+ t.Fatal("sessionMiddleware(GET, valid cookie) expected session ID in context")
}
-
- // Wait enough time for a token to refill
- time.Sleep(20 * time.Millisecond)
-
- if !rl.allow("1.2.3.4") {
- t.Error("allow() should be allowed after token refill")
+ if gotID != sessionID {
+ t.Errorf("sessionMiddleware(GET, valid cookie) session ID = %s, want %s", gotID, sessionID)
}
}
-func TestRateLimitMiddleware_Returns429(t *testing.T) {
- rl := newRateLimiter(0.001, 1) // Very low rate
+func TestSessionMiddleware_GET_WithoutCookie(t *testing.T) {
logger := discardLogger()
+ sm := &sessionManager{
+ hmacSecret: []byte("test-secret-at-least-32-characters!!"),
+ logger: logger,
+ }
- handler := rateLimitMiddleware(rl, false, logger)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
- w.WriteHeader(http.StatusOK)
+ var gotOK bool
+ called := false
+ handler := sessionMiddleware(sm)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
+ called = true
+ _, gotOK = sessionIDFromContext(r.Context())
}))
- // First request should succeed
w := httptest.NewRecorder()
- r := httptest.NewRequest(http.MethodGet, "/", nil)
- r.RemoteAddr = "10.0.0.1:12345"
- handler.ServeHTTP(w, r)
-
- if w.Code != http.StatusOK {
- t.Fatalf("first request status = %d, want %d", w.Code, http.StatusOK)
- }
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions", nil)
- // Second request should be rate limited
- w = httptest.NewRecorder()
- r = httptest.NewRequest(http.MethodGet, "/", nil)
- r.RemoteAddr = "10.0.0.1:12345"
handler.ServeHTTP(w, r)
- if w.Code != http.StatusTooManyRequests {
- t.Fatalf("rate limited request status = %d, want %d", w.Code, http.StatusTooManyRequests)
+ if !called {
+ t.Fatal("sessionMiddleware(GET, no cookie) did not call next handler")
}
-
- if got := w.Header().Get("Retry-After"); got != "1" {
- t.Errorf("Retry-After = %q, want %q", got, "1")
+ if gotOK {
+ t.Error("sessionMiddleware(GET, no cookie) should not have session ID in context")
}
}
-func TestClientIP(t *testing.T) {
- tests := []struct {
- name string
- trustProxy bool
- remoteAddr string
- xff string
- xri string
- want string
- }{
- {
- name: "remote addr with port",
- trustProxy: true,
- remoteAddr: "10.0.0.1:12345",
- want: "10.0.0.1",
- },
- {
- name: "X-Forwarded-For single when trusted",
- trustProxy: true,
- remoteAddr: "127.0.0.1:80",
- xff: "203.0.113.50",
- want: "203.0.113.50",
- },
- {
- name: "X-Forwarded-For multiple when trusted",
- trustProxy: true,
- remoteAddr: "127.0.0.1:80",
- xff: "203.0.113.50, 70.41.3.18, 150.172.238.178",
- want: "203.0.113.50",
- },
- {
- name: "X-Real-IP when trusted",
- trustProxy: true,
- remoteAddr: "127.0.0.1:80",
- xri: "203.0.113.50",
- want: "203.0.113.50",
- },
- {
- name: "X-Real-IP takes precedence over X-Forwarded-For when trusted",
- trustProxy: true,
- remoteAddr: "127.0.0.1:80",
- xff: "203.0.113.50",
- xri: "198.51.100.1",
- want: "198.51.100.1",
- },
- {
- name: "untrusted ignores X-Forwarded-For",
- trustProxy: false,
- remoteAddr: "10.0.0.1:12345",
- xff: "203.0.113.50",
- want: "10.0.0.1",
- },
- {
- name: "untrusted ignores X-Real-IP",
- trustProxy: false,
- remoteAddr: "10.0.0.1:12345",
- xri: "203.0.113.50",
- want: "10.0.0.1",
- },
- {
- name: "invalid X-Real-IP falls through to XFF",
- trustProxy: true,
- remoteAddr: "127.0.0.1:80",
- xri: "not-an-ip",
- xff: "203.0.113.50",
- want: "203.0.113.50",
- },
- {
- name: "invalid XFF falls through to RemoteAddr",
- trustProxy: true,
- remoteAddr: "127.0.0.1:80",
- xff: "not-an-ip",
- want: "127.0.0.1",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- r := httptest.NewRequest(http.MethodGet, "/", nil)
- r.RemoteAddr = tt.remoteAddr
- if tt.xff != "" {
- r.Header.Set("X-Forwarded-For", tt.xff)
- }
- if tt.xri != "" {
- r.Header.Set("X-Real-IP", tt.xri)
- }
+func TestSessionMiddleware_GET_InvalidCookie(t *testing.T) {
+ logger := discardLogger()
+ sm := &sessionManager{
+ hmacSecret: []byte("test-secret-at-least-32-characters!!"),
+ logger: logger,
+ }
- if got := clientIP(r, tt.trustProxy); got != tt.want {
- t.Errorf("clientIP(r, %v) = %q, want %q", tt.trustProxy, got, tt.want)
- }
- })
+ var gotOK bool
+ handler := sessionMiddleware(sm)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
+ _, gotOK = sessionIDFromContext(r.Context())
+ }))
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions", nil)
+ r.AddCookie(&http.Cookie{Name: "sid", Value: "not-a-uuid"})
+
+ handler.ServeHTTP(w, r)
+
+ if gotOK {
+ t.Error("sessionMiddleware(GET, invalid cookie) should not have session ID in context")
}
}
-func TestSessionIDFromContext(t *testing.T) {
- t.Run("present", func(t *testing.T) {
- id := uuid.New()
- ctx := context.WithValue(context.Background(), ctxKeySessionID, id)
+func TestLoggingMiddleware(t *testing.T) {
+ logger := discardLogger()
- got, ok := SessionIDFromContext(ctx)
- if !ok {
- t.Fatal("expected session ID to be present")
- }
- if got != id {
- t.Errorf("SessionIDFromContext() = %s, want %s", got, id)
+ called := false
+ inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ called = true
+ w.WriteHeader(http.StatusCreated)
+ if _, err := w.Write([]byte("hello")); err != nil {
+ t.Errorf("Write() error: %v", err)
}
})
- t.Run("absent", func(t *testing.T) {
- _, ok := SessionIDFromContext(context.Background())
- if ok {
- t.Error("expected session ID to be absent")
- }
- })
+ handler := loggingMiddleware(logger)(inner)
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodPost, "/test", nil)
+
+ handler.ServeHTTP(w, r)
+
+ if !called {
+ t.Fatal("loggingMiddleware did not call next handler")
+ }
+ if w.Code != http.StatusCreated {
+ t.Errorf("loggingMiddleware status = %d, want %d", w.Code, http.StatusCreated)
+ }
+ if w.Body.String() != "hello" {
+ t.Errorf("loggingMiddleware body = %q, want %q", w.Body.String(), "hello")
+ }
}
diff --git a/internal/api/ratelimit.go b/internal/api/ratelimit.go
new file mode 100644
index 0000000..c937e49
--- /dev/null
+++ b/internal/api/ratelimit.go
@@ -0,0 +1,135 @@
+package api
+
+import (
+ "log/slog"
+ "net"
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "golang.org/x/time/rate"
+)
+
+const (
+ rateLimiterCleanupInterval = 5 * time.Minute
+ rateLimiterStaleThreshold = 10 * time.Minute
+)
+
+// rateLimiter implements per-IP rate limiting using golang.org/x/time/rate.
+// Cleanup of stale entries happens inline during allow() calls.
+type rateLimiter struct {
+ mu sync.Mutex
+ visitors map[string]*visitor
+ limit rate.Limit
+ burst int
+ lastCleanup time.Time
+}
+
+// visitor holds a rate limiter and last-seen time for a single IP.
+type visitor struct {
+ limiter *rate.Limiter
+ lastSeen time.Time
+}
+
+// newRateLimiter creates a rate limiter.
+// r: tokens refilled per second. burst: maximum tokens (and initial allowance).
+func newRateLimiter(r float64, burst int) *rateLimiter {
+ return &rateLimiter{
+ visitors: make(map[string]*visitor),
+ limit: rate.Limit(r),
+ burst: burst,
+ lastCleanup: time.Now(),
+ }
+}
+
+// allow checks if a request from the given IP is allowed.
+// Returns false if the IP has exhausted its tokens.
+func (rl *rateLimiter) allow(ip string) bool {
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+
+ now := time.Now()
+
+ // Periodic cleanup of stale entries
+ if now.Sub(rl.lastCleanup) > rateLimiterCleanupInterval {
+ for k, v := range rl.visitors {
+ if now.Sub(v.lastSeen) > rateLimiterStaleThreshold {
+ delete(rl.visitors, k)
+ }
+ }
+ rl.lastCleanup = now
+ }
+
+ v, exists := rl.visitors[ip]
+ if !exists {
+ limiter := rate.NewLimiter(rl.limit, rl.burst)
+ rl.visitors[ip] = &visitor{
+ limiter: limiter,
+ lastSeen: now,
+ }
+ limiter.Allow() // consume one token for this first request; with burst >= 1 this always succeeds
+ return true
+ }
+
+ v.lastSeen = now
+ return v.limiter.Allow()
+}
+
+// rateLimitMiddleware returns middleware that limits requests per IP.
+// Uses token bucket algorithm: each IP gets `burst` initial tokens,
+// refilling at `rate` tokens per second.
+func rateLimitMiddleware(rl *rateLimiter, trustProxy bool, logger *slog.Logger) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ip := clientIP(r, trustProxy)
+ if !rl.allow(ip) {
+ logger.Warn("rate limit exceeded",
+ "ip", ip,
+ "path", r.URL.Path,
+ "method", r.Method,
+ )
+ w.Header().Set("Retry-After", "1")
+ WriteError(w, http.StatusTooManyRequests, "rate_limited", "too many requests", logger)
+ return
+ }
+ next.ServeHTTP(w, r)
+ })
+ }
+}
+
+// clientIP extracts the client IP from the request.
+//
+// When trustProxy is true, checks X-Real-IP first (set by nginx/HAProxy),
+// then X-Forwarded-For (first IP). Header values are validated with net.ParseIP
+// to prevent injection of non-IP strings into rate limiter keys.
+//
+// When trustProxy is false, only uses RemoteAddr (safe default for direct exposure).
+func clientIP(r *http.Request, trustProxy bool) string {
+ if trustProxy {
+ // Prefer X-Real-IP (single value, set by reverse proxy)
+ if xri := r.Header.Get("X-Real-IP"); xri != "" {
+ if ip := net.ParseIP(strings.TrimSpace(xri)); ip != nil {
+ return ip.String()
+ }
+ }
+
+ // Fall back to X-Forwarded-For (first IP is the client)
+ if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
+ raw := xff
+ if first, _, ok := strings.Cut(xff, ","); ok {
+ raw = first
+ }
+ if ip := net.ParseIP(strings.TrimSpace(raw)); ip != nil {
+ return ip.String()
+ }
+ }
+ }
+
+ // Fall back to RemoteAddr (strip port)
+ ip, _, err := net.SplitHostPort(r.RemoteAddr)
+ if err != nil {
+ return r.RemoteAddr
+ }
+ return ip
+}
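As a usage sketch, wiring the extracted limiter around a mux looks roughly like this; the 5 req/s rate and burst of 10 are illustrative values, not taken from the server's actual configuration.

```go
package api

import (
	"log/slog"
	"net/http"
	"os"
)

// exampleRateLimitedHandler is a hypothetical wiring example; only
// newRateLimiter and rateLimitMiddleware come from ratelimit.go above.
func exampleRateLimitedHandler() http.Handler {
	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))

	mux := http.NewServeMux()
	mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	rl := newRateLimiter(5.0, 10) // 5 tokens/sec per IP, burst of 10
	// trustProxy=false: ignore X-Real-IP / X-Forwarded-For and key on RemoteAddr.
	return rateLimitMiddleware(rl, false, logger)(mux)
}
```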
diff --git a/internal/api/ratelimit_test.go b/internal/api/ratelimit_test.go
new file mode 100644
index 0000000..ff2d26b
--- /dev/null
+++ b/internal/api/ratelimit_test.go
@@ -0,0 +1,204 @@
+package api
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+)
+
+func TestRateLimiter_AllowsWithinBurst(t *testing.T) {
+ rl := newRateLimiter(1.0, 5)
+
+ for i := range 5 {
+ if !rl.allow("1.2.3.4") {
+ t.Fatalf("allow() returned false on request %d (within burst of 5)", i+1)
+ }
+ }
+}
+
+func TestRateLimiter_BlocksAfterBurst(t *testing.T) {
+ rl := newRateLimiter(1.0, 3)
+
+ // Exhaust the burst
+ for range 3 {
+ rl.allow("1.2.3.4")
+ }
+
+ if rl.allow("1.2.3.4") {
+ t.Error("allow() should return false after burst exhausted")
+ }
+}
+
+func TestRateLimiter_SeparateIPs(t *testing.T) {
+ rl := newRateLimiter(1.0, 2)
+
+ // Exhaust IP 1
+ rl.allow("1.1.1.1")
+ rl.allow("1.1.1.1")
+
+ // IP 2 should still be allowed
+ if !rl.allow("2.2.2.2") {
+ t.Error("allow() should allow a different IP")
+ }
+}
+
+func TestRateLimiter_RefillsOverTime(t *testing.T) {
+ rl := newRateLimiter(100.0, 1) // 100 tokens/sec so we can test quickly
+
+ // Use the single token
+ rl.allow("1.2.3.4")
+
+ if rl.allow("1.2.3.4") {
+ t.Error("allow() should be blocked immediately after burst exhausted")
+ }
+
+ // Wait enough time for a token to refill
+ time.Sleep(20 * time.Millisecond)
+
+ if !rl.allow("1.2.3.4") {
+ t.Error("allow() should be allowed after token refill")
+ }
+}
+
+func TestRateLimitMiddleware_Returns429(t *testing.T) {
+ rl := newRateLimiter(0.001, 1) // Very low rate
+ logger := discardLogger()
+
+ handler := rateLimitMiddleware(rl, false, logger)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ // First request should succeed
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest(http.MethodGet, "/", nil)
+ r.RemoteAddr = "10.0.0.1:12345"
+ handler.ServeHTTP(w, r)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("first request status = %d, want %d", w.Code, http.StatusOK)
+ }
+
+ // Second request should be rate limited
+ w = httptest.NewRecorder()
+ r = httptest.NewRequest(http.MethodGet, "/", nil)
+ r.RemoteAddr = "10.0.0.1:12345"
+ handler.ServeHTTP(w, r)
+
+ if w.Code != http.StatusTooManyRequests {
+ t.Fatalf("rate limited request status = %d, want %d", w.Code, http.StatusTooManyRequests)
+ }
+
+ if got := w.Header().Get("Retry-After"); got != "1" {
+ t.Errorf("Retry-After = %q, want %q", got, "1")
+ }
+}
+
+func TestClientIP(t *testing.T) {
+ tests := []struct {
+ name string
+ trustProxy bool
+ remoteAddr string
+ xff string
+ xri string
+ want string
+ }{
+ {
+ name: "remote addr with port",
+ trustProxy: true,
+ remoteAddr: "10.0.0.1:12345",
+ want: "10.0.0.1",
+ },
+ {
+ name: "X-Forwarded-For single when trusted",
+ trustProxy: true,
+ remoteAddr: "127.0.0.1:80",
+ xff: "203.0.113.50",
+ want: "203.0.113.50",
+ },
+ {
+ name: "X-Forwarded-For multiple when trusted",
+ trustProxy: true,
+ remoteAddr: "127.0.0.1:80",
+ xff: "203.0.113.50, 70.41.3.18, 150.172.238.178",
+ want: "203.0.113.50",
+ },
+ {
+ name: "X-Real-IP when trusted",
+ trustProxy: true,
+ remoteAddr: "127.0.0.1:80",
+ xri: "203.0.113.50",
+ want: "203.0.113.50",
+ },
+ {
+ name: "X-Real-IP takes precedence over X-Forwarded-For when trusted",
+ trustProxy: true,
+ remoteAddr: "127.0.0.1:80",
+ xff: "203.0.113.50",
+ xri: "198.51.100.1",
+ want: "198.51.100.1",
+ },
+ {
+ name: "untrusted ignores X-Forwarded-For",
+ trustProxy: false,
+ remoteAddr: "10.0.0.1:12345",
+ xff: "203.0.113.50",
+ want: "10.0.0.1",
+ },
+ {
+ name: "untrusted ignores X-Real-IP",
+ trustProxy: false,
+ remoteAddr: "10.0.0.1:12345",
+ xri: "203.0.113.50",
+ want: "10.0.0.1",
+ },
+ {
+ name: "invalid X-Real-IP falls through to XFF",
+ trustProxy: true,
+ remoteAddr: "127.0.0.1:80",
+ xri: "not-an-ip",
+ xff: "203.0.113.50",
+ want: "203.0.113.50",
+ },
+ {
+ name: "invalid XFF falls through to RemoteAddr",
+ trustProxy: true,
+ remoteAddr: "127.0.0.1:80",
+ xff: "not-an-ip",
+ want: "127.0.0.1",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r := httptest.NewRequest(http.MethodGet, "/", nil)
+ r.RemoteAddr = tt.remoteAddr
+ if tt.xff != "" {
+ r.Header.Set("X-Forwarded-For", tt.xff)
+ }
+ if tt.xri != "" {
+ r.Header.Set("X-Real-IP", tt.xri)
+ }
+
+ if got := clientIP(r, tt.trustProxy); got != tt.want {
+ t.Errorf("clientIP(r, %v) = %q, want %q", tt.trustProxy, got, tt.want)
+ }
+ })
+ }
+}
+
+func BenchmarkRateLimiterAllow(b *testing.B) {
+ rl := newRateLimiter(1e9, 1<<30) // effectively unlimited
+ for b.Loop() {
+ rl.allow("1.2.3.4")
+ }
+}
+
+func BenchmarkClientIP(b *testing.B) {
+ r := httptest.NewRequest(http.MethodGet, "/", nil)
+ r.RemoteAddr = "10.0.0.1:12345"
+ r.Header.Set("X-Real-IP", "203.0.113.50")
+ for b.Loop() {
+ clientIP(r, true)
+ }
+}
diff --git a/internal/api/response.go b/internal/api/response.go
index a040b74..15dd1ff 100644
--- a/internal/api/response.go
+++ b/internal/api/response.go
@@ -1,4 +1,3 @@
-// Package api provides the JSON REST API server for Koopa.
package api
import (
@@ -23,12 +22,16 @@ type envelope struct {
// WriteJSON writes data wrapped in an envelope as JSON.
// For nil data, writes no body (use with 204 No Content).
-func WriteJSON(w http.ResponseWriter, status int, data any) {
+// If logger is nil, falls back to slog.Default().
+func WriteJSON(w http.ResponseWriter, status int, data any, logger *slog.Logger) {
if data != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(envelope{Data: data}); err != nil {
- slog.Error("failed to encode JSON response", "error", err)
+ if logger == nil {
+ logger = slog.Default()
+ }
+ logger.Error("encoding JSON response", "error", err)
}
} else {
w.WriteHeader(status)
@@ -36,10 +39,14 @@ func WriteJSON(w http.ResponseWriter, status int, data any) {
}
// WriteError writes a JSON error response wrapped in an envelope.
-func WriteError(w http.ResponseWriter, status int, code, message string) {
+// If logger is nil, falls back to slog.Default().
+func WriteError(w http.ResponseWriter, status int, code, message string, logger *slog.Logger) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(envelope{Error: &Error{Code: code, Message: message}}); err != nil {
- slog.Error("failed to encode JSON error response", "error", err)
+ if logger == nil {
+ logger = slog.Default()
+ }
+ logger.Error("encoding JSON error response", "error", err)
}
}
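A short sketch of the updated helpers with the extra logger parameter; the envelope shapes in the comments follow the envelope struct used by this file.

```go
package api

import (
	"log/slog"
	"net/http"
)

// exampleRespond is hypothetical; it only demonstrates the new WriteJSON /
// WriteError signatures, where a nil logger falls back to slog.Default().
func exampleRespond(w http.ResponseWriter, logger *slog.Logger, ok bool) {
	if !ok {
		// Produces {"error":{"code":"invalid_input","message":"name is required"}}.
		WriteError(w, http.StatusBadRequest, "invalid_input", "name is required", nil)
		return
	}
	// Produces {"data":{"id":"123"}} and logs encode failures to the given logger.
	WriteJSON(w, http.StatusOK, map[string]string{"id": "123"}, logger)
}
```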
diff --git a/internal/api/response_test.go b/internal/api/response_test.go
index 37243cd..139fb4d 100644
--- a/internal/api/response_test.go
+++ b/internal/api/response_test.go
@@ -14,10 +14,10 @@ func decodeData(t *testing.T, w *httptest.ResponseRecorder, target any) {
Data json.RawMessage `json:"data"`
}
if err := json.NewDecoder(w.Body).Decode(&env); err != nil {
- t.Fatalf("failed to decode envelope: %v", err)
+ t.Fatalf("decoding envelope: %v", err)
}
if err := json.Unmarshal(env.Data, target); err != nil {
- t.Fatalf("failed to decode envelope data: %v", err)
+ t.Fatalf("decoding envelope data: %v", err)
}
}
@@ -28,7 +28,7 @@ func decodeErrorEnvelope(t *testing.T, w *httptest.ResponseRecorder) Error {
Error *Error `json:"error"`
}
if err := json.NewDecoder(w.Body).Decode(&env); err != nil {
- t.Fatalf("failed to decode envelope: %v", err)
+ t.Fatalf("decoding envelope: %v", err)
}
if env.Error == nil {
t.Fatal("expected error in envelope, got nil")
@@ -40,7 +40,7 @@ func TestWriteJSON(t *testing.T) {
w := httptest.NewRecorder()
data := map[string]string{"key": "value"}
- WriteJSON(w, http.StatusOK, data)
+ WriteJSON(w, http.StatusOK, data, nil)
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
t.Errorf("WriteJSON() Content-Type = %q, want %q", ct, "application/json")
@@ -62,7 +62,7 @@ func TestWriteJSON_Envelope(t *testing.T) {
w := httptest.NewRecorder()
items := []string{"a", "b"}
- WriteJSON(w, http.StatusOK, items)
+ WriteJSON(w, http.StatusOK, items, nil)
var body []string
decodeData(t, w, &body)
@@ -75,7 +75,7 @@ func TestWriteJSON_Envelope(t *testing.T) {
func TestWriteJSON_NilData(t *testing.T) {
w := httptest.NewRecorder()
- WriteJSON(w, http.StatusNoContent, nil)
+ WriteJSON(w, http.StatusNoContent, nil, nil)
if w.Code != http.StatusNoContent {
t.Errorf("WriteJSON(nil) status = %d, want %d", w.Code, http.StatusNoContent)
@@ -89,7 +89,7 @@ func TestWriteJSON_NilData(t *testing.T) {
func TestWriteJSON_CustomStatus(t *testing.T) {
w := httptest.NewRecorder()
- WriteJSON(w, http.StatusCreated, map[string]string{"id": "123"})
+ WriteJSON(w, http.StatusCreated, map[string]string{"id": "123"}, nil)
if w.Code != http.StatusCreated {
t.Errorf("WriteJSON() status = %d, want %d", w.Code, http.StatusCreated)
@@ -99,7 +99,7 @@ func TestWriteJSON_CustomStatus(t *testing.T) {
func TestWriteError(t *testing.T) {
w := httptest.NewRecorder()
- WriteError(w, http.StatusBadRequest, "invalid_input", "name is required")
+ WriteError(w, http.StatusBadRequest, "invalid_input", "name is required", nil)
if w.Code != http.StatusBadRequest {
t.Errorf("WriteError() status = %d, want %d", w.Code, http.StatusBadRequest)
diff --git a/internal/api/server.go b/internal/api/server.go
index c75ac2f..df0f20c 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -5,16 +5,14 @@ import (
"log/slog"
"net/http"
- "github.com/firebase/genkit/go/genkit"
- "github.com/koopa0/koopa/internal/agent/chat"
+ "github.com/koopa0/koopa/internal/chat"
"github.com/koopa0/koopa/internal/session"
)
// ServerConfig contains configuration for creating the API server.
type ServerConfig struct {
Logger *slog.Logger
- Genkit *genkit.Genkit // Optional: nil disables AI title generation
- ModelName string // Provider-qualified model name (e.g., "googleai/gemini-2.5-flash")
+ ChatAgent *chat.Agent // Optional: nil disables AI title generation
ChatFlow *chat.Flow // Optional: nil enables simulation mode
SessionStore *session.Store // Required
CSRFSecret []byte // Required: 32+ bytes
@@ -50,11 +48,10 @@ func NewServer(cfg ServerConfig) (*Server, error) {
}
ch := &chatHandler{
- logger: logger,
- genkit: cfg.Genkit,
- modelName: cfg.ModelName,
- flow: cfg.ChatFlow,
- sessions: sm,
+ logger: logger,
+ agent: cfg.ChatAgent,
+ flow: cfg.ChatFlow,
+ sessions: sm,
}
mux := http.NewServeMux()
@@ -79,7 +76,7 @@ func NewServer(cfg ServerConfig) (*Server, error) {
// Build middleware stack: Recovery → Logging → RateLimit → CORS → Session → CSRF → Routes
var handler http.Handler = mux
handler = csrfMiddleware(sm, logger)(handler)
- handler = sessionMiddleware(sm, logger)(handler)
+ handler = sessionMiddleware(sm)(handler)
handler = corsMiddleware(cfg.CORSOrigins)(handler)
handler = rateLimitMiddleware(rl, cfg.TrustProxy, logger)(handler)
handler = loggingMiddleware(logger)(handler)
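A hedged sketch of constructing the server against the new ServerConfig shape; the field values are placeholders, any ServerConfig fields not visible in this diff are elided, and the caller-side store setup is assumed.

```go
package main

import (
	"log/slog"
	"os"

	"github.com/koopa0/koopa/internal/api"
	"github.com/koopa0/koopa/internal/session"
)

// newAPIServer is a hypothetical constructor showing the ChatAgent-based
// config; nil ChatAgent/ChatFlow fall back to the documented degraded modes.
func newAPIServer(store *session.Store) (*api.Server, error) {
	return api.NewServer(api.ServerConfig{
		Logger:       slog.New(slog.NewTextHandler(os.Stderr, nil)),
		ChatAgent:    nil, // nil disables AI title generation
		ChatFlow:     nil, // nil enables simulation mode
		SessionStore: store,
		CSRFSecret:   []byte("replace-with-32-or-more-random-bytes!!"),
		CORSOrigins:  []string{"http://localhost:5173"},
		TrustProxy:   false,
	})
}
```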
diff --git a/internal/api/session.go b/internal/api/session.go
index 4f542ba..d062316 100644
--- a/internal/api/session.go
+++ b/internal/api/session.go
@@ -19,12 +19,18 @@ import (
// Sentinel errors for session/CSRF operations.
var (
+ // ErrSessionCookieNotFound is returned when the session cookie is absent from the request.
ErrSessionCookieNotFound = errors.New("session cookie not found")
- ErrSessionInvalid = errors.New("session ID invalid")
- ErrCSRFRequired = errors.New("csrf token required")
- ErrCSRFInvalid = errors.New("csrf token invalid")
- ErrCSRFExpired = errors.New("csrf token expired")
- ErrCSRFMalformed = errors.New("csrf token malformed")
+ // ErrSessionInvalid is returned when the session cookie value is not a valid UUID.
+ ErrSessionInvalid = errors.New("session ID invalid")
+ // ErrCSRFRequired is returned when a state-changing request has no CSRF token.
+ ErrCSRFRequired = errors.New("csrf token required")
+ // ErrCSRFInvalid is returned when the CSRF token signature does not match.
+ ErrCSRFInvalid = errors.New("csrf token invalid")
+ // ErrCSRFExpired is returned when the CSRF token timestamp exceeds csrfTokenTTL.
+ ErrCSRFExpired = errors.New("csrf token expired")
+ // ErrCSRFMalformed is returned when the CSRF token format cannot be parsed.
+ ErrCSRFMalformed = errors.New("csrf token malformed")
)
// Pre-session CSRF token prefix to distinguish from session-bound tokens.
@@ -32,10 +38,11 @@ const preSessionPrefix = "pre:"
// Cookie and CSRF configuration.
const (
- sessionCookieName = "sid"
- csrfTokenTTL = 24 * time.Hour
- sessionMaxAge = 30 * 24 * 3600 // 30 days in seconds
- csrfClockSkew = 5 * time.Minute
+ sessionCookieName = "sid"
+ csrfTokenTTL = 24 * time.Hour
+ sessionMaxAge = 30 * 24 * 3600 // 30 days in seconds
+ csrfClockSkew = 5 * time.Minute
+ messagesDefaultLimit = 100
)
// sessionManager handles session cookies and CSRF token operations.
@@ -46,31 +53,6 @@ type sessionManager struct {
logger *slog.Logger
}
-// GetOrCreate retrieves session ID from cookie or creates a new session.
-// On success, sets/refreshes the session cookie and returns the session UUID.
-func (sm *sessionManager) GetOrCreate(w http.ResponseWriter, r *http.Request) (uuid.UUID, error) {
- // Try to get existing session from cookie
- cookie, err := r.Cookie(sessionCookieName)
- if err == nil && cookie.Value != "" {
- sessionID, parseErr := uuid.Parse(cookie.Value)
- if parseErr == nil {
- if _, getErr := sm.store.Session(r.Context(), sessionID); getErr == nil {
- sm.setCookie(w, sessionID)
- return sessionID, nil
- }
- }
- }
-
- // Create new session
- sess, err := sm.store.CreateSession(r.Context(), "", "", "")
- if err != nil {
- return uuid.Nil, fmt.Errorf("create session: %w", err)
- }
-
- sm.setCookie(w, sess.ID)
- return sess.ID, nil
-}
-
// ID extracts session ID from cookie without creating a new session.
func (*sessionManager) ID(r *http.Request) (uuid.UUID, error) {
cookie, err := r.Cookie(sessionCookieName)
@@ -197,24 +179,24 @@ func (sm *sessionManager) CheckPreSessionCSRF(token string) error {
func (sm *sessionManager) requireOwnership(w http.ResponseWriter, r *http.Request) (uuid.UUID, bool) {
idStr := r.PathValue("id")
if idStr == "" {
- WriteError(w, http.StatusBadRequest, "missing_id", "session ID required")
+ WriteError(w, http.StatusBadRequest, "missing_id", "session ID required", sm.logger)
return uuid.Nil, false
}
targetID, err := uuid.Parse(idStr)
if err != nil {
- WriteError(w, http.StatusBadRequest, "invalid_id", "invalid session ID")
+ WriteError(w, http.StatusBadRequest, "invalid_id", "invalid session ID", sm.logger)
return uuid.Nil, false
}
- ownerID, ok := SessionIDFromContext(r.Context())
+ ownerID, ok := sessionIDFromContext(r.Context())
if !ok || ownerID != targetID {
sm.logger.Warn("session ownership check failed",
"target", targetID,
"path", r.URL.Path,
"remote_addr", r.RemoteAddr,
)
- WriteError(w, http.StatusForbidden, "forbidden", "session access denied")
+ WriteError(w, http.StatusForbidden, "forbidden", "session access denied", sm.logger)
return uuid.Nil, false
}
@@ -240,13 +222,13 @@ func (sm *sessionManager) csrfToken(w http.ResponseWriter, r *http.Request) {
if err == nil {
WriteJSON(w, http.StatusOK, map[string]string{
"csrfToken": sm.NewCSRFToken(sessionID),
- })
+ }, sm.logger)
return
}
WriteJSON(w, http.StatusOK, map[string]string{
"csrfToken": sm.NewPreSessionCSRFToken(),
- })
+ }, sm.logger)
}
// listSessions handles GET /api/v1/sessions — returns sessions owned by the caller.
@@ -258,21 +240,21 @@ func (sm *sessionManager) listSessions(w http.ResponseWriter, r *http.Request) {
UpdatedAt string `json:"updatedAt"`
}
- sessionID, ok := SessionIDFromContext(r.Context())
+ sessionID, ok := sessionIDFromContext(r.Context())
if !ok {
// No session cookie — return empty list
- WriteJSON(w, http.StatusOK, []sessionItem{})
+ WriteJSON(w, http.StatusOK, []sessionItem{}, sm.logger)
return
}
sess, err := sm.store.Session(r.Context(), sessionID)
if err != nil {
- if errors.Is(err, session.ErrSessionNotFound) {
- WriteJSON(w, http.StatusOK, []sessionItem{})
+ if errors.Is(err, session.ErrNotFound) {
+ WriteJSON(w, http.StatusOK, []sessionItem{}, sm.logger)
return
}
- sm.logger.Error("failed to get session", "error", err, "session_id", sessionID)
- WriteError(w, http.StatusInternalServerError, "list_failed", "failed to list sessions")
+ sm.logger.Error("getting session", "error", err, "session_id", sessionID)
+ WriteError(w, http.StatusInternalServerError, "list_failed", "failed to list sessions", sm.logger)
return
}
@@ -282,15 +264,15 @@ func (sm *sessionManager) listSessions(w http.ResponseWriter, r *http.Request) {
Title: sess.Title,
UpdatedAt: sess.UpdatedAt.Format(time.RFC3339),
},
- })
+ }, sm.logger)
}
// createSession handles POST /api/v1/sessions — creates a new session.
func (sm *sessionManager) createSession(w http.ResponseWriter, r *http.Request) {
- sess, err := sm.store.CreateSession(r.Context(), "", "", "")
+ sess, err := sm.store.CreateSession(r.Context(), "")
if err != nil {
- sm.logger.Error("failed to create session", "error", err)
- WriteError(w, http.StatusInternalServerError, "create_failed", "failed to create session")
+ sm.logger.Error("creating session", "error", err)
+ WriteError(w, http.StatusInternalServerError, "create_failed", "failed to create session", sm.logger)
return
}
@@ -299,7 +281,7 @@ func (sm *sessionManager) createSession(w http.ResponseWriter, r *http.Request)
WriteJSON(w, http.StatusCreated, map[string]string{
"id": sess.ID.String(),
"csrfToken": sm.NewCSRFToken(sess.ID),
- })
+ }, sm.logger)
}
// getSession handles GET /api/v1/sessions/{id} — returns a single session.
@@ -312,12 +294,12 @@ func (sm *sessionManager) getSession(w http.ResponseWriter, r *http.Request) {
sess, err := sm.store.Session(r.Context(), id)
if err != nil {
- if errors.Is(err, session.ErrSessionNotFound) {
- WriteError(w, http.StatusNotFound, "not_found", "session not found")
+ if errors.Is(err, session.ErrNotFound) {
+ WriteError(w, http.StatusNotFound, "not_found", "session not found", sm.logger)
return
}
- sm.logger.Error("failed to get session", "error", err, "session_id", id)
- WriteError(w, http.StatusInternalServerError, "get_failed", "failed to get session")
+ sm.logger.Error("getting session", "error", err, "session_id", id)
+ WriteError(w, http.StatusInternalServerError, "get_failed", "failed to get session", sm.logger)
return
}
@@ -326,7 +308,7 @@ func (sm *sessionManager) getSession(w http.ResponseWriter, r *http.Request) {
"title": sess.Title,
"createdAt": sess.CreatedAt.Format(time.RFC3339),
"updatedAt": sess.UpdatedAt.Format(time.RFC3339),
- })
+ }, sm.logger)
}
// getSessionMessages handles GET /api/v1/sessions/{id}/messages — returns messages for a session.
@@ -337,10 +319,10 @@ func (sm *sessionManager) getSessionMessages(w http.ResponseWriter, r *http.Requ
return
}
- messages, err := sm.store.Messages(r.Context(), id, 100, 0)
+ messages, err := sm.store.Messages(r.Context(), id, messagesDefaultLimit, 0)
if err != nil {
- sm.logger.Error("failed to get messages", "error", err, "session_id", id)
- WriteError(w, http.StatusInternalServerError, "get_failed", "failed to get messages")
+ sm.logger.Error("getting messages", "error", err, "session_id", id)
+ WriteError(w, http.StatusInternalServerError, "get_failed", "failed to get messages", sm.logger)
return
}
@@ -369,7 +351,7 @@ func (sm *sessionManager) getSessionMessages(w http.ResponseWriter, r *http.Requ
}
}
- WriteJSON(w, http.StatusOK, items)
+ WriteJSON(w, http.StatusOK, items, sm.logger)
}
// deleteSession handles DELETE /api/v1/sessions/{id} — deletes a session.
@@ -381,10 +363,10 @@ func (sm *sessionManager) deleteSession(w http.ResponseWriter, r *http.Request)
}
if err := sm.store.DeleteSession(r.Context(), id); err != nil {
- sm.logger.Error("failed to delete session", "error", err, "session_id", id)
- WriteError(w, http.StatusInternalServerError, "delete_failed", "failed to delete session")
+ sm.logger.Error("deleting session", "error", err, "session_id", id)
+ WriteError(w, http.StatusInternalServerError, "delete_failed", "failed to delete session", sm.logger)
return
}
- WriteJSON(w, http.StatusOK, map[string]string{"status": "deleted"})
+ WriteJSON(w, http.StatusOK, map[string]string{"status": "deleted"}, sm.logger)
}
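The newly documented sentinel errors suggest a handling pattern like the following inside package api; the handler and any error codes not visible above are hypothetical, and it assumes sm.ID surfaces ErrSessionCookieNotFound / ErrSessionInvalid for the two documented failure modes.

```go
package api

import (
	"errors"
	"net/http"
)

// exampleCookieCheck is a hypothetical handler sketch for the sentinel errors.
func (sm *sessionManager) exampleCookieCheck(w http.ResponseWriter, r *http.Request) {
	sessionID, err := sm.ID(r)
	if err != nil {
		switch {
		case errors.Is(err, ErrSessionCookieNotFound):
			WriteError(w, http.StatusUnauthorized, "session_required", "no session cookie", sm.logger)
		case errors.Is(err, ErrSessionInvalid):
			WriteError(w, http.StatusBadRequest, "invalid_session", "malformed session cookie", sm.logger)
		default:
			WriteError(w, http.StatusInternalServerError, "session_error", "reading session", sm.logger)
		}
		return
	}
	WriteJSON(w, http.StatusOK, map[string]string{"id": sessionID.String()}, sm.logger)
}
```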
diff --git a/internal/api/session_test.go b/internal/api/session_test.go
index 118ba4c..efdf2b5 100644
--- a/internal/api/session_test.go
+++ b/internal/api/session_test.go
@@ -81,9 +81,9 @@ func TestCSRFToken_Malformed(t *testing.T) {
name string
token string
}{
- {"empty", ""},
- {"no_colon", "justtext"},
- {"bad_timestamp", "notanumber:signature"},
+ {name: "empty", token: ""},
+ {name: "no_colon", token: "justtext"},
+ {name: "bad_timestamp", token: "notanumber:signature"},
}
for _, tt := range tests {
@@ -245,10 +245,6 @@ func TestGetSessionMessages_InvalidUUID(t *testing.T) {
}
}
-// ============================================================================
-// Session Ownership Tests
-// ============================================================================
-
func TestRequireOwnership_NoSession(t *testing.T) {
sm := newTestSessionManager()
targetID := uuid.New()
@@ -348,3 +344,59 @@ func TestListSessions_NoSession(t *testing.T) {
t.Errorf("listSessions(no session) returned %d items, want 0", len(items))
}
}
+
+func FuzzCheckCSRF(f *testing.F) {
+ sm := newTestSessionManager()
+ sessionID := uuid.New()
+ validToken := sm.NewCSRFToken(sessionID)
+
+ f.Add(sessionID.String(), validToken)
+ f.Add(sessionID.String(), "")
+ f.Add(sessionID.String(), "notanumber:signature")
+ f.Add(sessionID.String(), "12345:badsig")
+ f.Add(uuid.New().String(), validToken)
+ f.Add("", "")
+ f.Add("not-a-uuid", "1234:sig")
+
+ f.Fuzz(func(t *testing.T, sessionIDStr, token string) {
+ id, err := uuid.Parse(sessionIDStr)
+ if err != nil {
+ return
+ }
+ _ = sm.CheckCSRF(id, token) // must not panic
+ })
+}
+
+func FuzzCheckPreSessionCSRF(f *testing.F) {
+ sm := newTestSessionManager()
+ validToken := sm.NewPreSessionCSRFToken()
+
+ f.Add(validToken)
+ f.Add("")
+ f.Add("pre:")
+ f.Add("pre:nonce:notanumber:sig")
+ f.Add("pre:abc:12345:sig")
+ f.Add("notpre:abc:123:sig")
+ f.Add("pre:abc:12345:sig:extra")
+
+ f.Fuzz(func(t *testing.T, token string) {
+ _ = sm.CheckPreSessionCSRF(token) // must not panic
+ })
+}
+
+func BenchmarkNewCSRFToken(b *testing.B) {
+ sm := newTestSessionManager()
+ sessionID := uuid.New()
+ for b.Loop() {
+ sm.NewCSRFToken(sessionID)
+ }
+}
+
+func BenchmarkCheckCSRF(b *testing.B) {
+ sm := newTestSessionManager()
+ sessionID := uuid.New()
+ token := sm.NewCSRFToken(sessionID)
+ for b.Loop() {
+ _ = sm.CheckCSRF(sessionID, token)
+ }
+}
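The fuzz targets and benchmarks use standard Go tooling: `go test -run='^$' -fuzz=FuzzCheckCSRF -fuzztime=30s ./internal/api/` exercises the CSRF parser beyond the `f.Add` seed corpus, and `go test -run='^$' -bench=. ./internal/api/` runs the token benchmarks.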
diff --git a/internal/app/app.go b/internal/app/app.go
index a13ef60..ad26490 100644
--- a/internal/app/app.go
+++ b/internal/app/app.go
@@ -1,108 +1,97 @@
-// Package app provides application initialization and dependency injection.
+// Package app provides application initialization and lifecycle management.
//
-// App is the core container that orchestrates all application components with struct-based DI.
-// It initializes Genkit, database connection, DocStore (via Genkit PostgreSQL Plugin),
-// and creates the agent with all necessary dependencies.
+// App is the core container that holds all application components.
+// Created by Setup, released by Close.
package app
import (
- "context"
"fmt"
"log/slog"
+ "sync"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/plugins/postgresql"
"github.com/jackc/pgx/v5/pgxpool"
- "github.com/koopa0/koopa/internal/agent/chat"
+
+ "github.com/koopa0/koopa/internal/chat"
"github.com/koopa0/koopa/internal/config"
"github.com/koopa0/koopa/internal/security"
"github.com/koopa0/koopa/internal/session"
- "golang.org/x/sync/errgroup"
+ "github.com/koopa0/koopa/internal/tools"
)
-// App is the core application container.
+// App is the application instance.
+// Created by Setup, closed by Close. All entry points (CLI, HTTP, MCP)
+// use this struct to access shared resources.
type App struct {
- // Configuration
- Config *config.Config
-
+ Config *config.Config
Genkit *genkit.Genkit
Embedder ai.Embedder
- DBPool *pgxpool.Pool // Database connection pool
- DocStore *postgresql.DocStore // Genkit PostgreSQL DocStore for indexing
- Retriever ai.Retriever // Genkit Retriever for searching
- SessionStore *session.Store // Session persistence (concrete type, not interface)
- PathValidator *security.Path // Path validator for security
- Tools []ai.Tool // Pre-registered tools
+ DBPool *pgxpool.Pool
+ DocStore *postgresql.DocStore
+ Retriever ai.Retriever
+ SessionStore *session.Store
+ PathValidator *security.Path
+ Tools []ai.Tool // Pre-registered Genkit tools (for chat agent)
- // Lifecycle management
- ctx context.Context
- cancel context.CancelFunc
+ // Concrete toolsets shared by CLI and MCP entry points.
+ File *tools.File
+ System *tools.System
+ Network *tools.Network
+ Knowledge *tools.Knowledge // nil if retriever unavailable
- // errgroup for background goroutine lifecycle management
- eg *errgroup.Group
+ // Lifecycle management (unexported)
+ cancel func()
+ dbCleanup func()
+ otelCleanup func()
+ closeOnce sync.Once
}
-// Close gracefully shuts down App-managed resources.
-// Cleanup function handles DB pool and OTel (single owner principle).
+// Close gracefully shuts down all resources. Safe for concurrent and
+// repeated calls — cleanup runs exactly once via sync.Once.
//
// Shutdown order:
-// 1. Cancel context (signals background tasks to stop)
-// 2. Wait for background goroutines (errgroup)
+// 1. Cancel context (signals background tasks to stop)
+// 2. Close DB pool
+// 3. Flush OTel spans
func (a *App) Close() error {
- slog.Info("shutting down application")
-
- // 1. Cancel context (signals all background tasks to stop)
- if a.cancel != nil {
- a.cancel()
- }
+ a.closeOnce.Do(func() {
+ slog.Info("shutting down application")
- // 2. Wait for background goroutines
- if a.eg != nil {
- if err := a.eg.Wait(); err != nil {
- return fmt.Errorf("background task error: %w", err)
+ // 1. Cancel context (signals all background tasks to stop)
+ if a.cancel != nil {
+ a.cancel()
}
- slog.Debug("background tasks completed")
- }
- // Pool is closed by cleanup function, NOT here (single owner principle)
- return nil
-}
+ // 2. Close DB pool
+ if a.dbCleanup != nil {
+ a.dbCleanup()
+ }
-// Wait blocks until all background goroutines complete.
-// This is useful for waiting on background tasks without closing resources.
-func (a *App) Wait() error {
- if a.eg == nil {
- return nil
- }
- if err := a.eg.Wait(); err != nil {
- return fmt.Errorf("errgroup failed: %w", err)
- }
+ // 3. Flush OTel spans
+ if a.otelCleanup != nil {
+ a.otelCleanup()
+ }
+ })
return nil
}
-// Go starts a new background goroutine tracked by the app's errgroup.
-// Use this for any background tasks that should be waited on during shutdown.
-func (a *App) Go(f func() error) {
- if a.eg != nil {
- a.eg.Go(f)
- }
-}
-
// CreateAgent creates a Chat Agent using pre-registered tools.
-// Tools are registered once at App construction (not lazily).
-// InitializeApp guarantees all dependencies are non-nil.
-func (a *App) CreateAgent(_ context.Context) (*chat.Chat, error) {
- // No nil checks - InitializeApp guarantees injection
- return chat.New(chat.Config{
+// Tools are registered once at Setup (not lazily).
+// Setup guarantees all dependencies are non-nil.
+func (a *App) CreateAgent() (*chat.Agent, error) {
+ agent, err := chat.New(chat.Config{
Genkit: a.Genkit,
- Retriever: a.Retriever,
SessionStore: a.SessionStore,
Logger: slog.Default(),
Tools: a.Tools,
ModelName: a.Config.FullModelName(),
MaxTurns: a.Config.MaxTurns,
- RAGTopK: a.Config.RAGTopK,
Language: a.Config.Language,
})
+ if err != nil {
+ return nil, fmt.Errorf("creating chat agent: %w", err)
+ }
+ return agent, nil
}
diff --git a/internal/app/app_test.go b/internal/app/app_test.go
index 8247a40..dd8049b 100644
--- a/internal/app/app_test.go
+++ b/internal/app/app_test.go
@@ -2,115 +2,81 @@ package app
import (
"context"
- "fmt"
"os"
"path/filepath"
- "runtime"
"testing"
"time"
"github.com/firebase/genkit/go/genkit"
+
"github.com/koopa0/koopa/internal/config"
"github.com/koopa0/koopa/internal/security"
- "golang.org/x/sync/errgroup"
+ "github.com/koopa0/koopa/internal/testutil"
)
-// ============================================================================
-// App.Close() Tests
-// ============================================================================
-
func TestApp_Close(t *testing.T) {
tests := []struct {
name string
- setupApp func() *App
+ setup func() *App
expectError bool
}{
{
name: "close with cancel function",
- setupApp: func() *App {
- ctx, cancel := context.WithCancel(context.Background())
- return &App{
- ctx: ctx,
- cancel: cancel,
- DBPool: nil, // Don't mock pgxpool as it causes panic on close
- }
+ setup: func() *App {
+ _, cancel := context.WithCancel(context.Background())
+ return &App{cancel: cancel}
},
- expectError: false,
},
{
- name: "close with nil DBPool",
- setupApp: func() *App {
- ctx, cancel := context.WithCancel(context.Background())
- return &App{
- ctx: ctx,
- cancel: cancel,
- DBPool: nil,
- }
+ name: "close with nil cancel",
+ setup: func() *App {
+ return &App{cancel: nil}
},
- expectError: false,
},
{
- name: "close with nil cancel function",
- setupApp: func() *App {
+ name: "close with cleanup functions",
+ setup: func() *App {
return &App{
- ctx: context.Background(),
- cancel: nil,
- DBPool: nil,
+ dbCleanup: func() {},
+ otelCleanup: func() {},
}
},
- expectError: false,
},
{
name: "close minimal app",
- setupApp: func() *App {
+ setup: func() *App {
return &App{}
},
- expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- app := tt.setupApp()
- err := app.Close()
+ a := tt.setup()
+ err := a.Close()
if tt.expectError && err == nil {
t.Error("expected error but got none")
}
-
if !tt.expectError && err != nil {
- t.Errorf("unexpected error: %v", err)
- }
-
- // Verify context was canceled if cancel function existed
- if app.cancel != nil && app.ctx != nil {
- select {
- case <-app.ctx.Done():
- // Context was properly canceled
- default:
- t.Error("context was not canceled")
- }
+ t.Errorf("Close() unexpected error: %v", err)
}
})
}
}
-// ============================================================================
-// App Struct Field Tests
-// ============================================================================
-
func TestApp_Fields(t *testing.T) {
- t.Run("app with all fields set", func(t *testing.T) {
+ t.Run("all fields set", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
g := genkit.Init(ctx)
pathValidator, err := security.NewPath([]string{"."})
if err != nil {
- t.Fatalf("failed to create path validator: %v", err)
+ t.Fatalf("security.NewPath() error: %v", err)
}
- app := &App{
+ a := &App{
Config: &config.Config{
ModelName: "gemini-2.0-flash-exp",
},
@@ -120,90 +86,111 @@ func TestApp_Fields(t *testing.T) {
DocStore: nil,
Retriever: nil,
PathValidator: pathValidator,
- ctx: ctx,
cancel: cancel,
}
- // Verify fields are set
- if app.Config == nil {
- t.Error("expected Config to be set")
+ if a.Config == nil {
+ t.Error("App.Config = nil, want non-nil")
}
- if app.Genkit == nil {
- t.Error("expected Genkit to be set")
+ if a.Genkit == nil {
+ t.Error("App.Genkit = nil, want non-nil")
}
- if app.PathValidator == nil {
- t.Error("expected PathValidator to be set")
+ if a.PathValidator == nil {
+ t.Error("App.PathValidator = nil, want non-nil")
}
- if app.ctx == nil {
- t.Error("expected ctx to be set")
- }
- if app.cancel == nil {
- t.Error("expected cancel to be set")
+ if a.cancel == nil {
+ t.Error("App.cancel = nil, want non-nil")
}
})
}
-// ============================================================================
-// Nil Safety Tests
-// ============================================================================
-
func TestApp_NilSafety(t *testing.T) {
tests := []struct {
name string
- app *App
+ a *App
}{
{
- name: "close nil app fields",
- app: &App{},
+ name: "close nil fields",
+ a: &App{},
},
{
- name: "close with only ctx",
- app: &App{
- ctx: context.Background(),
- },
+ name: "close with only cancel",
+ a: &App{cancel: func() {}},
},
{
- name: "close with only cancel",
- app: &App{
- cancel: func() {},
- },
+ name: "close with only cleanup",
+ a: &App{dbCleanup: func() {}, otelCleanup: func() {}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- // Should not panic
- err := tt.app.Close()
+ err := tt.a.Close()
if err != nil {
- t.Errorf("unexpected error: %v", err)
+ t.Errorf("Close() unexpected error: %v", err)
}
})
}
}
-// ============================================================================
-// InitializeApp Integration Tests
-// ============================================================================
+func TestApp_Close_ShutdownOrder(t *testing.T) {
+ t.Run("cleanup called after cancel", func(t *testing.T) {
+ var order []string
+
+ a := &App{
+ cancel: func() { order = append(order, "cancel") },
+ dbCleanup: func() { order = append(order, "db") },
+ otelCleanup: func() { order = append(order, "otel") },
+ }
+
+ _ = a.Close()
+
+ if len(order) != 3 {
+ t.Fatalf("Close() operations = %d, want 3: %v", len(order), order)
+ }
+ if order[0] != "cancel" {
+ t.Errorf("order[0] = %q, want %q", order[0], "cancel")
+ }
+ if order[1] != "db" {
+ t.Errorf("order[1] = %q, want %q", order[1], "db")
+ }
+ if order[2] != "otel" {
+ t.Errorf("order[2] = %q, want %q", order[2], "otel")
+ }
+ })
+
+ t.Run("idempotent close", func(t *testing.T) {
+ callCount := 0
+ a := &App{
+ cancel: func() { callCount++ },
+ dbCleanup: func() { callCount++ },
+ otelCleanup: func() { callCount++ },
+ }
+
+ _ = a.Close()
+ _ = a.Close() // second call should be no-op
-func TestInitializeApp_Success(t *testing.T) {
+ if callCount != 3 {
+ t.Errorf("Close() twice: call count = %d, want 3 (second Close should be no-op)", callCount)
+ }
+ })
+}
+
+func TestSetup_Success(t *testing.T) {
if testing.Short() {
- t.Skip("Skipping integration test in short mode")
+ t.Skip("skipping integration test in short mode")
}
-
- // Check required environment variables
if os.Getenv("GEMINI_API_KEY") == "" {
- t.Skip("GEMINI_API_KEY not set - skipping integration test")
+ t.Skip("GEMINI_API_KEY not set")
}
-
- // Skip if database is not available (PostgreSQL required)
if os.Getenv("DATABASE_URL") == "" {
- t.Skip("DATABASE_URL not set - skipping integration test")
+ t.Skip("DATABASE_URL not set")
}
ctx := context.Background()
cfg := &config.Config{
ModelName: "gemini-2.0-flash-exp",
- EmbedderModel: "text-embedding-004",
+ EmbedderModel: "gemini-embedding-001",
Temperature: 0.7,
MaxTokens: 8192,
PostgresHost: "localhost",
@@ -215,68 +202,58 @@ func TestInitializeApp_Success(t *testing.T) {
PromptDir: getPromptsDir(t),
}
- // Test: InitializeApp should successfully create all components
- app, cleanup, err := InitializeApp(ctx, cfg)
+ a, err := Setup(ctx, cfg)
if err != nil {
- t.Fatalf("InitializeApp failed: %v", err)
+ t.Fatalf("Setup() error: %v", err)
}
- defer cleanup()
- defer func() { _ = app.Close() }()
+ defer func() { _ = a.Close() }()
- // Verify all components are initialized
- if app == nil {
- t.Fatal("expected non-nil app")
- return
+ if a.Config == nil {
+ t.Error("Setup().Config = nil, want non-nil")
}
- if app.Config == nil {
- t.Error("expected Config to be set")
+ if a.Genkit == nil {
+ t.Error("Setup().Genkit = nil, want non-nil")
}
- if app.Genkit == nil {
- t.Error("expected Genkit to be set")
+ if a.Embedder == nil {
+ t.Error("Setup().Embedder = nil, want non-nil")
}
- if app.Embedder == nil {
- t.Error("expected Embedder to be set")
+ if a.DBPool == nil {
+ t.Error("Setup().DBPool = nil, want non-nil")
}
- if app.DBPool == nil {
- t.Error("expected DBPool to be set")
+ if a.DocStore == nil {
+ t.Error("Setup().DocStore = nil, want non-nil")
}
- if app.DocStore == nil {
- t.Error("expected DocStore to be set")
+ if a.Retriever == nil {
+ t.Error("Setup().Retriever = nil, want non-nil")
}
- if app.Retriever == nil {
- t.Error("expected Retriever to be set")
+ if a.SessionStore == nil {
+ t.Error("Setup().SessionStore = nil, want non-nil")
}
- if app.SessionStore == nil {
- t.Error("expected SessionStore to be set")
+ if a.PathValidator == nil {
+ t.Error("Setup().PathValidator = nil, want non-nil")
}
- if app.PathValidator == nil {
- t.Error("expected PathValidator to be set")
- }
- // Note: SystemIndexer was removed - system knowledge indexing is now
- // done via rag.IndexSystemKnowledge() function in newApp()
// Verify database connection is functional
- if err := app.DBPool.Ping(ctx); err != nil {
- t.Errorf("database ping failed: %v", err)
+ if err := a.DBPool.Ping(ctx); err != nil {
+ t.Errorf("pinging database: %v", err)
}
}
-func TestInitializeApp_CleanupFunction(t *testing.T) {
+func TestSetup_CleanupOnClose(t *testing.T) {
if testing.Short() {
- t.Skip("Skipping integration test in short mode")
+ t.Skip("skipping integration test in short mode")
}
-
if os.Getenv("GEMINI_API_KEY") == "" {
- t.Skip("GEMINI_API_KEY not set - skipping integration test")
+ t.Skip("GEMINI_API_KEY not set")
}
if os.Getenv("DATABASE_URL") == "" {
- t.Skip("DATABASE_URL not set - skipping integration test")
+ t.Skip("DATABASE_URL not set")
}
ctx := context.Background()
cfg := &config.Config{
ModelName: "gemini-2.0-flash-exp",
- EmbedderModel: "text-embedding-004",
+ EmbedderModel: "gemini-embedding-001",
PostgresHost: "localhost",
PostgresPort: 5432,
PostgresUser: "postgres",
@@ -286,235 +263,59 @@ func TestInitializeApp_CleanupFunction(t *testing.T) {
PromptDir: getPromptsDir(t),
}
- app, cleanup, err := InitializeApp(ctx, cfg)
+ a, err := Setup(ctx, cfg)
if err != nil {
- t.Fatalf("InitializeApp failed: %v", err)
+ t.Fatalf("Setup() error: %v", err)
}
- // Test: cleanup function should close database pool
- cleanup()
+ pool := a.DBPool
+
+ // Close should release DB pool
+ if err := a.Close(); err != nil {
+ t.Fatalf("Close() error: %v", err)
+ }
// Verify pool is closed (ping should fail)
- if err := app.DBPool.Ping(ctx); err == nil {
- t.Error("expected database ping to fail after cleanup")
+ if err := pool.Ping(ctx); err == nil {
+ t.Error("pool.Ping() after Close() = nil, want error")
}
}
func TestProvidePathValidator_Success(t *testing.T) {
validator, err := providePathValidator()
-
if err != nil {
- t.Fatalf("providePathValidator failed: %v", err)
+ t.Fatalf("providePathValidator() error: %v", err)
}
if validator == nil {
- t.Fatal("expected non-nil path validator")
+ t.Fatal("providePathValidator() returned nil")
}
}
-// ============================================================================
-// Shutdown and Lifecycle Tests
-// ============================================================================
-
-// TestApp_ShutdownTimeout tests that shutdown completes within reasonable time.
-// Safety: Prevents hang during shutdown if background tasks don't respond to cancellation.
func TestApp_ShutdownTimeout(t *testing.T) {
t.Run("graceful shutdown completes quickly", func(t *testing.T) {
- ctx, cancel := context.WithCancel(context.Background())
- app := &App{
- ctx: ctx,
- cancel: cancel,
- DBPool: nil,
- }
+ _, cancel := context.WithCancel(context.Background())
+ a := &App{cancel: cancel}
- // Shutdown should complete quickly (no background tasks)
done := make(chan struct{})
go func() {
- _ = app.Close()
+ _ = a.Close()
close(done)
}()
select {
case <-done:
- // Success: shutdown completed
+ // Success
case <-time.After(5 * time.Second):
- t.Fatal("shutdown timed out - potential deadlock")
+ t.Fatal("shutdown timed out")
}
})
-
- t.Run("shutdown with background goroutine", func(t *testing.T) {
- ctx, cancel := context.WithCancel(context.Background())
- eg, egCtx := errgroup.WithContext(ctx)
-
- app := &App{
- ctx: ctx,
- cancel: cancel,
- DBPool: nil,
- eg: eg,
- }
-
- // Start a background task that respects context cancellation
- taskDone := make(chan struct{})
- app.Go(func() error {
- defer close(taskDone)
- <-egCtx.Done()
- return nil
- })
-
- // Shutdown should complete after background task exits
- done := make(chan struct{})
- go func() {
- _ = app.Close()
- close(done)
- }()
-
- select {
- case <-done:
- // Success: shutdown completed
- case <-time.After(5 * time.Second):
- t.Fatal("shutdown timed out with background goroutine")
- }
-
- // Verify background task was properly terminated
- select {
- case <-taskDone:
- // Task completed
- default:
- t.Error("background task was not properly terminated")
- }
- })
-
- t.Run("shutdown timeout safety", func(t *testing.T) {
- // This test documents the expected behavior:
- // If a background task doesn't respond to context cancellation,
- // shutdown will block. This is intentional to prevent data loss.
- //
- // In production, consider adding a hard timeout:
- // - Use context.WithTimeout for background tasks
- // - Or implement a watchdog timer in Close()
-
- ctx, cancel := context.WithCancel(context.Background())
- app := &App{
- ctx: ctx,
- cancel: cancel,
- }
-
- // Verify cancel is called during Close
- _ = app.Close()
-
- select {
- case <-ctx.Done():
- // Context was properly canceled
- default:
- t.Error("context was not canceled during shutdown")
- }
- })
-}
-
-// TestApp_Wait tests the Wait() method for background task completion.
-func TestApp_Wait(t *testing.T) {
- t.Run("wait with nil errgroup returns nil", func(t *testing.T) {
- app := &App{eg: nil}
- err := app.Wait()
- if err != nil {
- t.Errorf("expected nil error, got: %v", err)
- }
- })
-
- t.Run("wait blocks until tasks complete", func(t *testing.T) {
- ctx := context.Background()
- eg, _ := errgroup.WithContext(ctx)
-
- app := &App{eg: eg}
-
- taskStarted := make(chan struct{})
- taskDone := make(chan struct{})
-
- app.Go(func() error {
- close(taskStarted)
- time.Sleep(50 * time.Millisecond)
- close(taskDone)
- return nil
- })
-
- <-taskStarted // Wait for task to start
-
- // Wait should block until task completes
- err := app.Wait()
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
-
- select {
- case <-taskDone:
- // Task completed before Wait returned
- default:
- t.Error("Wait returned before task completed")
- }
- })
-}
-
-// TestApp_Go tests the Go() method for starting background tasks.
-func TestApp_Go(t *testing.T) {
- t.Run("go with nil errgroup does not panic", func(t *testing.T) {
- app := &App{eg: nil}
-
- // Should not panic
- app.Go(func() error {
- return nil
- })
- })
-
- t.Run("go tracks task in errgroup", func(t *testing.T) {
- ctx := context.Background()
- eg, _ := errgroup.WithContext(ctx)
-
- app := &App{eg: eg}
-
- executed := false
- app.Go(func() error {
- executed = true
- return nil
- })
-
- // Wait for task
- _ = app.Wait()
-
- if !executed {
- t.Error("task was not executed")
- }
- })
-}
-
-// ============================================================================
-// Test Helpers
-// ============================================================================
-
-// findProjectRoot finds the project root directory by looking for go.mod.
-func findProjectRoot() (string, error) {
- _, filename, _, ok := runtime.Caller(0)
- if !ok {
- return "", fmt.Errorf("could not determine caller filename")
- }
-
- dir := filepath.Dir(filename)
- for {
- if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
- return dir, nil
- }
- parent := filepath.Dir(dir)
- if parent == dir {
- return "", fmt.Errorf("project root (go.mod) not found")
- }
- dir = parent
- }
}
-// getPromptsDir returns the absolute path to the prompts directory.
func getPromptsDir(t *testing.T) string {
t.Helper()
- root, err := findProjectRoot()
+ root, err := testutil.FindProjectRoot()
if err != nil || root == "" {
- t.Skip("Could not find project root")
+ t.Skip("could not find project root")
}
return filepath.Join(root, "prompts")
}
diff --git a/internal/app/runtime.go b/internal/app/runtime.go
deleted file mode 100644
index c5fff33..0000000
--- a/internal/app/runtime.go
+++ /dev/null
@@ -1,87 +0,0 @@
-package app
-
-import (
- "context"
- "errors"
- "fmt"
- "log/slog"
-
- "github.com/koopa0/koopa/internal/agent/chat"
- "github.com/koopa0/koopa/internal/config"
-)
-
-// ChatRuntime provides a fully initialized application runtime with all components ready to use.
-// It encapsulates the common initialization logic used by CLI and HTTP server entry points.
-// MCP mode uses InitializeApp directly (no chat flow needed).
-// Implements io.Closer for resource cleanup.
-type ChatRuntime struct {
- App *App
- Flow *chat.Flow
- cleanup func() // cleanup (unexported) - handles DB pool, OTel
-}
-
-// Close releases all resources. Implements io.Closer.
-// Shutdown order: App.Close (goroutines) → cleanup (DB pool, OTel).
-func (r *ChatRuntime) Close() error {
- var errs []error
-
- // 1. App shutdown (cancel context, wait for goroutines)
- if r.App != nil {
- if err := r.App.Close(); err != nil {
- errs = append(errs, fmt.Errorf("app close: %w", err))
- }
- }
-
- // 2. Cleanup (DB pool, OTel)
- if r.cleanup != nil {
- r.cleanup()
- }
-
- return errors.Join(errs...)
-}
-
-// NewChatRuntime creates a fully initialized runtime with all components ready for use.
-// This is the recommended way to initialize Koopa for CLI and HTTP entry points.
-//
-// Usage:
-//
-// runtime, err := app.NewChatRuntime(ctx, cfg)
-// if err != nil { ... }
-// defer runtime.Close() // Single cleanup method (implements io.Closer)
-// // Use runtime.Flow for agent interactions
-func NewChatRuntime(ctx context.Context, cfg *config.Config) (*ChatRuntime, error) {
- // Initialize application
- application, cleanup, err := InitializeApp(ctx, cfg)
- if err != nil {
- return nil, fmt.Errorf("failed to initialize application: %w", err)
- }
-
- // Create Chat Agent (uses pre-registered tools)
- chatAgent, err := application.CreateAgent(ctx)
- if err != nil {
- // Must close application first (stops background goroutines)
- // then cleanup (closes DB pool, OTel)
- if closeErr := application.Close(); closeErr != nil {
- slog.Warn("app close failed during CreateAgent recovery", "error", closeErr)
- }
- cleanup()
- return nil, fmt.Errorf("failed to create agent: %w", err)
- }
-
- // Initialize Chat Flow (singleton pattern with explicit lifecycle)
- chatFlow, err := chat.InitFlow(application.Genkit, chatAgent)
- if err != nil {
- // InitFlow failed (likely called twice) - cleanup and return error
- if closeErr := application.Close(); closeErr != nil {
- slog.Warn("app close failed during InitFlow recovery", "error", closeErr)
- }
- cleanup()
- return nil, fmt.Errorf("failed to init flow: %w", err)
- }
-
- return &ChatRuntime{
- App: application,
- Flow: chatFlow,
- cleanup: cleanup,
- }, nil
-}
diff --git a/internal/app/runtime_test.go b/internal/app/runtime_test.go
deleted file mode 100644
index e34c138..0000000
--- a/internal/app/runtime_test.go
+++ /dev/null
@@ -1,177 +0,0 @@
-package app
-
-import (
- "errors"
- "testing"
-)
-
-// ============================================================================
-// ChatChatRuntime.Close() Tests
-// ============================================================================
-
-func TestChatRuntime_Close(t *testing.T) {
- t.Run("close with nil app", func(t *testing.T) {
- cleanupCalled := false
- r := &ChatRuntime{
- App: nil,
- cleanup: func() { cleanupCalled = true },
- }
-
- err := r.Close()
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
- if !cleanupCalled {
- t.Error("cleanup function should be called")
- }
- })
-
- t.Run("close with nil cleanup", func(t *testing.T) {
- r := &ChatRuntime{
- App: nil,
- cleanup: nil,
- }
-
- err := r.Close()
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
- })
-
- t.Run("close propagates app error", func(t *testing.T) {
- // Create an app that will return an error on Close
- app := &App{
- eg: nil, // nil errgroup means Close returns nil
- }
-
- cleanupCalled := false
- r := &ChatRuntime{
- App: app,
- cleanup: func() { cleanupCalled = true },
- }
-
- err := r.Close()
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
- if !cleanupCalled {
- t.Error("cleanup should be called even after app.Close succeeds")
- }
- })
-
- t.Run("cleanup called after app close", func(t *testing.T) {
- // This test verifies the shutdown order:
- // 1. App.Close() (cancel context, wait for goroutines)
- // 2. Cleanup (DB pool, OTel)
- var order []string
-
- app := &App{
- cancel: func() { order = append(order, "cancel") },
- }
-
- r := &ChatRuntime{
- App: app,
- cleanup: func() { order = append(order, "cleanup") },
- }
-
- _ = r.Close()
-
- if len(order) != 2 {
- t.Fatalf("expected 2 operations, got %d", len(order))
- }
- if order[0] != "cancel" {
- t.Errorf("expected cancel first, got %s", order[0])
- }
- if order[1] != "cleanup" {
- t.Errorf("expected cleanup second, got %s", order[1])
- }
- })
-}
-
-// TestChatRuntime_Close_ErrorAggregation tests that errors are properly joined.
-func TestChatRuntime_Close_ErrorAggregation(t *testing.T) {
- t.Run("errors from app close are returned", func(t *testing.T) {
- // We can't easily make App.Close() return an error without
- // setting up an errgroup that returns an error.
- // This documents the expected behavior.
-
- // When App.Close() returns an error, ChatRuntime.Close() should:
- // 1. Still call cleanup()
- // 2. Return the error
-
- cleanupCalled := false
- r := &ChatRuntime{
- App: &App{
- cancel: nil,
- eg: nil,
- },
- cleanup: func() { cleanupCalled = true },
- }
-
- err := r.Close()
- if err != nil {
- t.Errorf("unexpected error: %v", err)
- }
- if !cleanupCalled {
- t.Error("cleanup should always be called")
- }
- })
-}
-
-// TestNewChatRuntime_CleanupOnFailure documents the expected behavior when
-// CreateAgent fails during NewChatRuntime initialization.
-//
-// The fix in runtime.go:60-66 ensures:
-// 1. application.Close() is called to stop background goroutines
-// 2. cleanup() is called to release DB pool and OTel resources
-// 3. The original error is returned
-//
-// This test cannot easily verify the behavior without mocking,
-// but documents the contract for future maintainers.
-func TestNewChatRuntime_CleanupOnFailure(t *testing.T) {
- t.Run("documented behavior on CreateAgent failure", func(t *testing.T) {
- // When CreateAgent fails:
- // 1. application is already created (with background goroutine)
- // 2. We must call application.Close() to:
- // - Cancel context
- // - Wait for errgroup (background IndexSystemKnowledge)
- // 3. Then call cleanup() to close DB pool
- //
- // Without this order:
- // - Background goroutine may use closed DB pool
- // - Goroutine leak if context not canceled
- //
- // Integration testing is required to fully verify this behavior.
- t.Log("See runtime.go:60-66 for implementation")
- })
-}
-
-// TestErrors_Join verifies errors.Join behavior used in Close().
-func TestErrors_Join(t *testing.T) {
- t.Run("nil errors return nil", func(t *testing.T) {
- err := errors.Join(nil, nil)
- if err != nil {
- t.Errorf("expected nil, got %v", err)
- }
- })
-
- t.Run("empty slice returns nil", func(t *testing.T) {
- var errs []error
- err := errors.Join(errs...)
- if err != nil {
- t.Errorf("expected nil, got %v", err)
- }
- })
-
- t.Run("single error preserved", func(t *testing.T) {
- original := errors.New("test error")
- errs := []error{original}
- err := errors.Join(errs...)
- if err == nil {
- t.Fatal("expected error")
- }
- if !errors.Is(err, original) {
- t.Error("error should wrap original")
- }
- })
-}
diff --git a/internal/app/setup.go b/internal/app/setup.go
index d172497..fc1e071 100644
--- a/internal/app/setup.go
+++ b/internal/app/setup.go
@@ -2,24 +2,26 @@ package app
import (
"context"
+ "errors"
"fmt"
"log/slog"
+ "os"
"time"
- "golang.org/x/sync/errgroup"
-
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/core/api"
+ "github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/plugins/compat_oai/openai"
"github.com/firebase/genkit/go/plugins/googlegenai"
"github.com/firebase/genkit/go/plugins/ollama"
"github.com/firebase/genkit/go/plugins/postgresql"
"github.com/jackc/pgx/v5/pgxpool"
+ "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
+ sdktrace "go.opentelemetry.io/otel/sdk/trace"
"github.com/koopa0/koopa/db"
"github.com/koopa0/koopa/internal/config"
- "github.com/koopa0/koopa/internal/observability"
"github.com/koopa0/koopa/internal/rag"
"github.com/koopa0/koopa/internal/security"
"github.com/koopa0/koopa/internal/session"
@@ -27,96 +29,126 @@ import (
"github.com/koopa0/koopa/internal/tools"
)
-// InitializeApp creates and initializes all application dependencies.
-// Returns the App, a cleanup function for infrastructure resources (DB pool, OTel),
-// and any initialization error.
-func InitializeApp(ctx context.Context, cfg *config.Config) (*App, func(), error) {
- otelCleanup, err := provideOtelShutdown(ctx, cfg)
- if err != nil {
- return nil, nil, err
- }
+// Setup creates and initializes the application.
+// Returns an App with embedded cleanup — call Close() to release.
+func Setup(ctx context.Context, cfg *config.Config) (_ *App, retErr error) {
+ a := &App{Config: cfg}
+
+ // On error, clean up everything already initialized
+ defer func() {
+ if retErr != nil {
+ if err := a.Close(); err != nil {
+ slog.Warn("cleanup during setup failure", "error", err)
+ }
+ }
+ }()
+
+ a.otelCleanup = provideOtelShutdown(ctx, cfg)
+
pool, dbCleanup, err := provideDBPool(ctx, cfg)
if err != nil {
- otelCleanup()
- return nil, nil, err
+ return nil, err
}
+ a.dbCleanup = dbCleanup
+ a.DBPool = pool
+
postgres, err := providePostgresPlugin(ctx, pool, cfg)
if err != nil {
- dbCleanup()
- otelCleanup()
- return nil, nil, err
+ return nil, err
}
+
g, err := provideGenkit(ctx, cfg, postgres)
if err != nil {
- dbCleanup()
- otelCleanup()
- return nil, nil, err
+ return nil, err
}
+ a.Genkit = g
+
embedder := provideEmbedder(g, cfg)
+ if embedder == nil {
+ return nil, fmt.Errorf("embedder %q not found for provider %q", cfg.EmbedderModel, cfg.Provider)
+ }
+ a.Embedder = embedder
+
docStore, retriever, err := provideRAGComponents(ctx, g, postgres, embedder)
if err != nil {
- dbCleanup()
- otelCleanup()
- return nil, nil, err
+ return nil, err
}
- store := provideSessionStore(pool)
+ a.DocStore = docStore
+ a.Retriever = retriever
+
+ a.SessionStore = provideSessionStore(pool)
+
path, err := providePathValidator()
if err != nil {
- dbCleanup()
- otelCleanup()
- return nil, nil, err
+ return nil, err
}
- v, err := provideTools(g, path, retriever, docStore, cfg)
- if err != nil {
- dbCleanup()
- otelCleanup()
- return nil, nil, err
- }
- application := newApp(ctx, cfg, g, embedder, pool, docStore, retriever, store, path, v)
-
- // Start background system knowledge indexing.
- // Launched here (not in newApp) to keep the constructor side-effect free.
- //nolint:contextcheck // Independent context: indexing must complete even if parent is canceled
- application.Go(func() error {
- indexCtx, indexCancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer indexCancel()
-
- count, err := rag.IndexSystemKnowledge(indexCtx, docStore, pool)
- if err != nil {
- slog.Debug("system knowledge indexing failed (non-critical)", "error", err)
- return nil
- }
- slog.Debug("system knowledge indexed successfully", "count", count)
- return nil
- })
-
- return application, func() {
- dbCleanup()
- otelCleanup()
- }, nil
+ a.PathValidator = path
+
+ if err := provideTools(a); err != nil {
+ return nil, err
+ }
+
+ // Set up lifecycle management
+ _, cancel := context.WithCancel(ctx)
+ a.cancel = cancel
+
+ return a, nil
}
// provideOtelShutdown sets up Datadog tracing before Genkit initialization.
// Must be called before provideGenkit to ensure TracerProvider is ready.
-func provideOtelShutdown(ctx context.Context, cfg *config.Config) (func(), error) {
- shutdown, err := observability.SetupDatadog(ctx, observability.Config{
- AgentHost: cfg.Datadog.AgentHost,
- Environment: cfg.Datadog.Environment,
- ServiceName: cfg.Datadog.ServiceName,
- })
+//
+// Traces are exported to a local Datadog Agent via OTLP HTTP (localhost:4318).
+// The Agent handles authentication, buffering, and forwarding to Datadog backend.
+func provideOtelShutdown(ctx context.Context, cfg *config.Config) func() {
+ dd := cfg.Datadog
+
+ agentHost := dd.AgentHost
+ if agentHost == "" {
+ agentHost = "localhost:4318"
+ }
+
+ // Set OTEL env vars for Genkit's TracerProvider to pick up.
+ // SAFETY: os.Setenv is not concurrent-safe, but this function is called
+ // exactly once during startup in Setup, before goroutines are spawned.
+ if dd.ServiceName != "" {
+ _ = os.Setenv("OTEL_SERVICE_NAME", dd.ServiceName)
+ }
+ if dd.Environment != "" {
+ _ = os.Setenv("OTEL_RESOURCE_ATTRIBUTES", "deployment.environment="+dd.Environment)
+ }
+
+ // Create OTLP HTTP exporter pointing to local Datadog Agent.
+ // Agent handles authentication and forwarding to Datadog backend.
+ exporter, err := otlptracehttp.New(ctx,
+ otlptracehttp.WithEndpoint(agentHost),
+ otlptracehttp.WithInsecure(), // localhost doesn't need TLS
+ )
if err != nil {
- return nil, err
+ slog.Warn("creating datadog exporter, tracing disabled", "error", err)
+ return func() {}
}
+ // Register BatchSpanProcessor with Genkit's TracerProvider.
+ processor := sdktrace.NewBatchSpanProcessor(exporter)
+ tracing.TracerProvider().RegisterSpanProcessor(processor)
+
+ slog.Debug("datadog tracing enabled",
+ "agent", agentHost,
+ "service", dd.ServiceName,
+ "environment", dd.Environment,
+ )
+
+ shutdown := tracing.TracerProvider().Shutdown
+
//nolint:contextcheck // Independent context: shutdown runs during teardown when parent is canceled
- cleanupFn := func() {
+ return func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := shutdown(shutdownCtx); err != nil {
- slog.Warn("failed to shutdown tracer provider", "error", err)
+ slog.Warn("shutting down tracer provider", "error", err)
}
}
- return cleanupFn, nil
}
// providePostgresPlugin creates the Genkit PostgreSQL plugin.
@@ -124,7 +156,7 @@ func provideOtelShutdown(ctx context.Context, cfg *config.Config) (func(), error
func providePostgresPlugin(ctx context.Context, pool *pgxpool.Pool, cfg *config.Config) (*postgresql.Postgres, error) {
pEngine, err := postgresql.NewPostgresEngine(ctx, postgresql.WithPool(pool), postgresql.WithDatabase(cfg.PostgresDBName))
if err != nil {
- return nil, fmt.Errorf("failed to create postgres engine: %w", err)
+ return nil, fmt.Errorf("creating postgres engine: %w", err)
}
return &postgresql.Postgres{Engine: pEngine}, nil
@@ -132,7 +164,7 @@ func providePostgresPlugin(ctx context.Context, pool *pgxpool.Pool, cfg *config.
// provideGenkit initializes Genkit with the configured AI provider and PostgreSQL plugins.
// Supports gemini (default), ollama, and openai providers.
-// Call ordering in InitializeApp ensures tracing is set up first.
+// Call ordering in Setup ensures tracing is set up first.
func provideGenkit(ctx context.Context, cfg *config.Config, postgres *postgresql.Postgres) (*genkit.Genkit, error) {
promptDir := cfg.PromptDir
if promptDir == "" {
@@ -154,7 +186,7 @@ func provideGenkit(ctx context.Context, cfg *config.Config, postgres *postgresql
genkit.WithPromptDir(promptDir),
)
if g == nil {
- return nil, fmt.Errorf("failed to initialize Genkit with ollama provider")
+ return nil, errors.New("initializing genkit with ollama provider")
}
// Ollama requires explicit model registration (no auto-discovery)
ollamaPlugin.DefineModel(g, ollama.ModelDefinition{
@@ -172,7 +204,7 @@ func provideGenkit(ctx context.Context, cfg *config.Config, postgres *postgresql
genkit.WithPromptDir(promptDir),
)
if g == nil {
- return nil, fmt.Errorf("failed to initialize Genkit with openai provider")
+ return nil, errors.New("initializing genkit with openai provider")
}
slog.Info("initialized Genkit with openai provider", "model", cfg.ModelName)
@@ -182,7 +214,7 @@ func provideGenkit(ctx context.Context, cfg *config.Config, postgres *postgresql
genkit.WithPromptDir(promptDir),
)
if g == nil {
- return nil, fmt.Errorf("failed to initialize Genkit with gemini provider")
+ return nil, errors.New("initializing genkit with gemini provider")
}
slog.Info("initialized Genkit with gemini provider", "model", cfg.ModelName)
}
@@ -217,12 +249,12 @@ func provideEmbedder(g *genkit.Genkit, cfg *config.Config) ai.Embedder {
// Pool is configured with sensible defaults for connection management.
func provideDBPool(ctx context.Context, cfg *config.Config) (*pgxpool.Pool, func(), error) {
if err := db.Migrate(cfg.PostgresURL()); err != nil {
- return nil, nil, fmt.Errorf("failed to run migrations: %w", err)
+ return nil, nil, fmt.Errorf("running migrations: %w", err)
}
poolCfg, err := pgxpool.ParseConfig(cfg.PostgresConnectionString())
if err != nil {
- return nil, nil, fmt.Errorf("failed to parse connection config: %w", err)
+ return nil, nil, fmt.Errorf("parsing connection config: %w", err)
}
poolCfg.MaxConns = 10
@@ -233,14 +265,14 @@ func provideDBPool(ctx context.Context, cfg *config.Config) (*pgxpool.Pool, func
pool, err := pgxpool.NewWithConfig(ctx, poolCfg)
if err != nil {
- return nil, nil, fmt.Errorf("failed to create connection pool: %w", err)
+ return nil, nil, fmt.Errorf("creating connection pool: %w", err)
}
pingCtx, pingCancel := context.WithTimeout(ctx, 5*time.Second)
defer pingCancel()
if err := pool.Ping(pingCtx); err != nil {
pool.Close()
- return nil, nil, fmt.Errorf("failed to ping database: %w", err)
+ return nil, nil, fmt.Errorf("pinging database: %w", err)
}
cleanup := func() {
@@ -256,7 +288,7 @@ func provideRAGComponents(ctx context.Context, g *genkit.Genkit, postgres *postg
cfg := rag.NewDocStoreConfig(embedder)
docStore, retriever, err := postgresql.DefineRetriever(ctx, g, postgres, cfg)
if err != nil {
- return nil, nil, fmt.Errorf("failed to define retriever: %w", err)
+ return nil, nil, fmt.Errorf("defining retriever: %w", err)
}
return docStore, retriever, nil
@@ -272,96 +304,65 @@ func providePathValidator() (*security.Path, error) {
return security.NewPath([]string{"."})
}
-// provideTools registers all tools at construction time.
-// Tools are registered once here, not lazily in CreateAgent.
-func provideTools(g *genkit.Genkit, pathValidator *security.Path, retriever ai.Retriever, docStore *postgresql.DocStore, cfg *config.Config) ([]ai.Tool, error) {
+// provideTools creates toolsets, registers them with Genkit, and stores both
+// the concrete toolsets and the Genkit-wrapped references in a.
+func provideTools(a *App) error {
logger := slog.Default()
+ cfg := a.Config
var allTools []ai.Tool
- ft, err := tools.NewFileTools(pathValidator, logger)
+ ft, err := tools.NewFile(a.PathValidator, logger)
if err != nil {
- return nil, fmt.Errorf("creating file tools: %w", err)
+ return fmt.Errorf("creating file tools: %w", err)
}
- fileTools, err := tools.RegisterFileTools(g, ft)
+ a.File = ft
+ fileTools, err := tools.RegisterFile(a.Genkit, ft)
if err != nil {
- return nil, fmt.Errorf("registering file tools: %w", err)
+ return fmt.Errorf("registering file tools: %w", err)
}
allTools = append(allTools, fileTools...)
cmdValidator := security.NewCommand()
envValidator := security.NewEnv()
- st, err := tools.NewSystemTools(cmdValidator, envValidator, logger)
+ st, err := tools.NewSystem(cmdValidator, envValidator, logger)
if err != nil {
- return nil, fmt.Errorf("creating system tools: %w", err)
+ return fmt.Errorf("creating system tools: %w", err)
}
- systemTools, err := tools.RegisterSystemTools(g, st)
+ a.System = st
+ systemTools, err := tools.RegisterSystem(a.Genkit, st)
if err != nil {
- return nil, fmt.Errorf("registering system tools: %w", err)
+ return fmt.Errorf("registering system tools: %w", err)
}
allTools = append(allTools, systemTools...)
- nt, err := tools.NewNetworkTools(tools.NetworkConfig{
+ nt, err := tools.NewNetwork(tools.NetConfig{
SearchBaseURL: cfg.SearXNG.BaseURL,
FetchParallelism: cfg.WebScraper.Parallelism,
FetchDelay: time.Duration(cfg.WebScraper.DelayMs) * time.Millisecond,
FetchTimeout: time.Duration(cfg.WebScraper.TimeoutMs) * time.Millisecond,
}, logger)
if err != nil {
- return nil, fmt.Errorf("creating network tools: %w", err)
+ return fmt.Errorf("creating network tools: %w", err)
}
- networkTools, err := tools.RegisterNetworkTools(g, nt)
+ a.Network = nt
+ networkTools, err := tools.RegisterNetwork(a.Genkit, nt)
if err != nil {
- return nil, fmt.Errorf("registering network tools: %w", err)
+ return fmt.Errorf("registering network tools: %w", err)
}
allTools = append(allTools, networkTools...)
- kt, err := tools.NewKnowledgeTools(retriever, docStore, logger)
+ kt, err := tools.NewKnowledge(a.Retriever, a.DocStore, logger)
if err != nil {
- return nil, fmt.Errorf("creating knowledge tools: %w", err)
+ return fmt.Errorf("creating knowledge tools: %w", err)
}
- knowledgeTools, err := tools.RegisterKnowledgeTools(g, kt)
+ a.Knowledge = kt
+ knowledgeTools, err := tools.RegisterKnowledge(a.Genkit, kt)
if err != nil {
- return nil, fmt.Errorf("registering knowledge tools: %w", err)
+ return fmt.Errorf("registering knowledge tools: %w", err)
}
allTools = append(allTools, knowledgeTools...)
- slog.Info("tools registered at construction", "count", len(allTools))
- return allTools, nil
-}
-// newApp constructs an App instance with all dependencies.
-// All dependencies are injected by InitializeApp.
-// Tools are pre-registered by provideTools.
-// NOTE: This constructor has no side effects. Background tasks are started by the caller.
-func newApp(
- ctx context.Context,
- cfg *config.Config,
- g *genkit.Genkit,
- embedder ai.Embedder,
- pool *pgxpool.Pool,
- docStore *postgresql.DocStore,
- retriever ai.Retriever,
- sessionStore *session.Store,
- pathValidator *security.Path, registeredTools []ai.Tool,
-) *App {
-
- appCtx, cancel := context.WithCancel(ctx)
-
- eg, _ := errgroup.WithContext(appCtx)
-
- app := &App{
- Config: cfg,
- ctx: appCtx,
- cancel: cancel,
- eg: eg,
- Genkit: g,
- Embedder: embedder,
- DBPool: pool,
- DocStore: docStore,
- Retriever: retriever,
- SessionStore: sessionStore,
- PathValidator: pathValidator,
- Tools: registeredTools,
- }
-
- return app
+ a.Tools = allTools
+ slog.Info("tools registered at construction", "count", len(allTools))
+ return nil
}
diff --git a/internal/agent/chat/chat.go b/internal/chat/chat.go
similarity index 69%
rename from internal/agent/chat/chat.go
rename to internal/chat/chat.go
index de4163d..680b3bc 100644
--- a/internal/agent/chat/chat.go
+++ b/internal/chat/chat.go
@@ -4,17 +4,15 @@ import (
"context"
"errors"
"fmt"
+ "log/slog"
"strings"
"time"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
- "github.com/firebase/genkit/go/plugins/postgresql"
"github.com/google/uuid"
"golang.org/x/time/rate"
- "github.com/koopa0/koopa/internal/log"
- "github.com/koopa0/koopa/internal/rag"
"github.com/koopa0/koopa/internal/session"
)
@@ -31,8 +29,17 @@ const (
// NOTE: The LLM model is configured in the Dotprompt file, not via Config.
KoopaPromptName = "koopa"
- // FallbackResponseMessage is the message returned when the model produces an empty response.
- FallbackResponseMessage = "I apologize, but I couldn't generate a response. Please try rephrasing your question."
+ // fallbackResponseMessage is the message returned when the model produces an empty response.
+ fallbackResponseMessage = "I apologize, but I couldn't generate a response. Please try rephrasing your question."
+)
+
+// Sentinel errors for agent operations.
+var (
+ // ErrInvalidSession indicates the session ID is invalid or malformed.
+ ErrInvalidSession = errors.New("invalid session")
+
+ // ErrExecutionFailed indicates agent execution failed.
+ ErrExecutionFailed = errors.New("execution failed")
)
// Response represents the complete result of an agent execution.
@@ -49,15 +56,13 @@ type StreamCallback func(ctx context.Context, chunk *ai.ModelResponseChunk) erro
// Config contains all required parameters for Chat agent.
type Config struct {
Genkit *genkit.Genkit
- Retriever ai.Retriever // Genkit Retriever for RAG context
SessionStore *session.Store
- Logger log.Logger
+ Logger *slog.Logger
Tools []ai.Tool // Pre-registered tools from RegisterXxxTools()
// Configuration values
ModelName string // Provider-qualified model name (e.g., "googleai/gemini-2.5-flash", "ollama/llama3.3")
MaxTurns int // Maximum agentic loop turns
- RAGTopK int // Number of RAG documents to retrieve
Language string // Response language preference
// Resilience configuration
@@ -74,9 +79,6 @@ func (cfg Config) validate() error {
if cfg.Genkit == nil {
return errors.New("genkit instance is required")
}
- if cfg.Retriever == nil {
- return errors.New("retriever is required")
- }
if cfg.SessionStore == nil {
return errors.New("session store is required")
}
@@ -89,20 +91,19 @@ func (cfg Config) validate() error {
return nil
}
-// Chat is Koopa's main conversational agent.
+// Agent is Koopa's main conversational agent.
// It provides LLM-powered conversations with tool calling and knowledge base integration.
//
-// Chat is stateless and uses dependency injection.
+// Agent is stateless and uses dependency injection.
// Required parameters are provided via Config struct.
//
// All configuration values are captured immutably at construction time
// to ensure thread-safe concurrent access.
-type Chat struct {
+type Agent struct {
// Immutable configuration (captured at construction)
modelName string // Provider-qualified model name (overrides Dotprompt model)
languagePrompt string // Resolved language for prompt template
maxTurns int
- ragTopK int
// Resilience (captured at construction)
retryConfig RetryConfig
@@ -114,32 +115,32 @@ type Chat struct {
// Dependencies (read-only after construction)
g *genkit.Genkit
- retriever ai.Retriever // Genkit Retriever for RAG context
sessions *session.Store
- logger log.Logger
+ logger *slog.Logger
tools []ai.Tool // Pre-registered tools (passed in via Config)
toolRefs []ai.ToolRef // Cached at construction (ai.Tool implements ai.ToolRef)
toolNames string // Cached as comma-separated for logging
prompt ai.Prompt // Cached Dotprompt instance (model configured in prompt file)
}
-// New creates a new Chat agent with required configuration.
+// New creates a new Agent with required configuration.
+//
+// RAG context is provided by knowledge tools (search_documents, search_history,
+// search_system_knowledge) which the LLM calls when it determines context is needed.
//
// NOTE: The LLM model is configured in prompts/koopa.prompt, not via Config.
//
// Example:
//
-// chat, err := chat.New(chat.Config{
+// agent, err := chat.New(chat.Config{
// Genkit: g,
-// Retriever: retriever, // Genkit Retriever from postgresql.DefineRetriever
// SessionStore: sessionStore,
// Logger: logger,
// Tools: tools, // Pre-registered via RegisterXxxTools()
// MaxTurns: cfg.MaxTurns,
-// RAGTopK: cfg.RAGTopK,
// Language: cfg.Language,
// })
-func New(cfg Config) (*Chat, error) {
+func New(cfg Config) (*Agent, error) {
if err := cfg.validate(); err != nil {
return nil, err
}
@@ -187,12 +188,11 @@ func New(cfg Config) (*Chat, error) {
names[i] = t.Name()
}
- c := &Chat{
+ a := &Agent{
// Immutable configuration
modelName: cfg.ModelName,
languagePrompt: languagePrompt,
maxTurns: maxTurns,
- ragTopK: cfg.RAGTopK,
// Resilience
retryConfig: retryConfig,
@@ -204,7 +204,6 @@ func New(cfg Config) (*Chat, error) {
// Dependencies
g: cfg.Genkit,
- retriever: cfg.Retriever,
sessions: cfg.SessionStore,
logger: cfg.Logger,
tools: cfg.Tools, // Already registered with Genkit
@@ -214,44 +213,44 @@ func New(cfg Config) (*Chat, error) {
// Load Dotprompt (koopa.prompt) - REQUIRED
// NOTE: Model is configured in the prompt file, not via Config
- c.prompt = genkit.LookupPrompt(c.g, KoopaPromptName)
- if c.prompt == nil {
+ a.prompt = genkit.LookupPrompt(a.g, KoopaPromptName)
+ if a.prompt == nil {
return nil, fmt.Errorf("dotprompt '%s' not found: ensure prompts directory is configured correctly", KoopaPromptName)
}
- c.logger.Debug("loaded dotprompt successfully", "prompt_name", KoopaPromptName)
+ a.logger.Debug("loaded dotprompt successfully", "prompt_name", KoopaPromptName)
- c.logger.Info("chat agent initialized",
- "totalTools", len(c.tools),
- "maxTurns", c.maxTurns,
+ a.logger.Info("chat agent initialized",
+ "totalTools", len(a.tools),
+ "maxTurns", a.maxTurns,
)
- return c, nil
+ return a, nil
}
// Execute runs the chat agent with the given input (non-streaming).
// This is a convenience wrapper around ExecuteStream with nil callback.
-func (c *Chat) Execute(ctx context.Context, sessionID uuid.UUID, input string) (*Response, error) {
- return c.ExecuteStream(ctx, sessionID, input, nil)
+func (a *Agent) Execute(ctx context.Context, sessionID uuid.UUID, input string) (*Response, error) {
+ return a.ExecuteStream(ctx, sessionID, input, nil)
}
// ExecuteStream runs the chat agent with optional streaming output.
// If callback is non-nil, it is called for each chunk of the response as it's generated.
// If callback is nil, the response is generated without streaming (equivalent to Execute).
// The final response is always returned after generation completes.
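+//
+// Streaming sketch (assumes chunk.Text() on ai.ModelResponseChunk; prints
+// chunks as they arrive):
+//
+// resp, err := agent.ExecuteStream(ctx, sessionID, "hello",
+// func(ctx context.Context, chunk *ai.ModelResponseChunk) error {
+// fmt.Print(chunk.Text())
+// return nil
+// })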
-func (c *Chat) ExecuteStream(ctx context.Context, sessionID uuid.UUID, input string, callback StreamCallback) (*Response, error) {
+func (a *Agent) ExecuteStream(ctx context.Context, sessionID uuid.UUID, input string, callback StreamCallback) (*Response, error) {
streaming := callback != nil
- c.logger.Debug("executing chat agent",
+ a.logger.Debug("executing chat agent",
"session_id", sessionID,
"streaming", streaming)
// Load session history
- history, err := c.sessions.History(ctx, sessionID)
+ historyMessages, err := a.sessions.History(ctx, sessionID)
if err != nil {
- return nil, fmt.Errorf("failed to get history: %w", err)
+ return nil, fmt.Errorf("getting history: %w", err)
}
// Generate response using unified core logic
- resp, err := c.generateResponse(ctx, input, history.Messages(), callback)
+ resp, err := a.generateResponse(ctx, input, historyMessages, callback)
if err != nil {
return nil, err
}
@@ -261,22 +260,18 @@ func (c *Chat) ExecuteStream(ctx context.Context, sessionID uuid.UUID, input str
// Only apply fallback when truly empty (no text AND no tool requests)
// When LLM returns empty text but has tool requests, this is valid agentic behavior
if strings.TrimSpace(responseText) == "" && len(resp.ToolRequests()) == 0 {
- c.logger.Warn("model returned empty response with no tool requests",
+ a.logger.Warn("model returned empty response with no tool requests",
"session_id", sessionID)
- responseText = FallbackResponseMessage
+ responseText = fallbackResponseMessage
}
- // Update history with user input and response
- history.Add(input, responseText)
-
- // Save updated history to session store using AppendMessages (preferred)
+ // Save new messages to session store
newMessages := []*ai.Message{
ai.NewUserMessage(ai.NewTextPart(input)),
ai.NewModelMessage(ai.NewTextPart(responseText)),
}
- if err := c.sessions.AppendMessages(ctx, sessionID, newMessages); err != nil {
- c.logger.Error("failed to append messages to history", "error", err)
- // Don't fail the request, just log the error
+ if err := a.sessions.AppendMessages(ctx, sessionID, newMessages); err != nil {
+ a.logger.Error("appending messages to history", "error", err) // best-effort: don't fail the request
}
// Return formatted response
@@ -288,7 +283,7 @@ func (c *Chat) ExecuteStream(ctx context.Context, sessionID uuid.UUID, input str
// generateResponse is the unified response generation logic for both streaming and non-streaming modes.
// If callback is non-nil, streaming is enabled; otherwise, standard generation is used.
-func (c *Chat) generateResponse(ctx context.Context, input string, historyMessages []*ai.Message, callback StreamCallback) (*ai.ModelResponse, error) {
+func (a *Agent) generateResponse(ctx context.Context, input string, historyMessages []*ai.Message, callback StreamCallback) (*ai.ModelResponse, error) {
// Build messages: deep copy history and append current user input
// CRITICAL: Deep copy is required to prevent DATA RACE in Genkit's renderMessages()
// Genkit modifies msg.Content in-place, so concurrent executions sharing the same
@@ -297,33 +292,25 @@ func (c *Chat) generateResponse(ctx context.Context, input string, historyMessag
// Apply token budget before adding new message
// This ensures we don't exceed context window limits
- messages = c.truncateHistory(messages, c.tokenBudget.MaxHistoryTokens)
+ messages = a.truncateHistory(messages, a.tokenBudget.MaxHistoryTokens)
messages = append(messages, ai.NewUserMessage(ai.NewTextPart(input)))
- // Retrieve relevant documents for RAG context (graceful fallback on error)
- ragDocs := c.retrieveRAGContext(ctx, input)
-
// Build execute options (using cached toolRefs and languagePrompt)
opts := []ai.PromptExecuteOption{
ai.WithInput(map[string]any{
- "language": c.languagePrompt,
+ "language": a.languagePrompt,
}),
ai.WithMessagesFn(func(_ context.Context, _ any) ([]*ai.Message, error) {
return messages, nil
}),
- ai.WithTools(c.toolRefs...),
- ai.WithMaxTurns(c.maxTurns),
+ ai.WithTools(a.toolRefs...),
+ ai.WithMaxTurns(a.maxTurns),
}
// Override model from Dotprompt if configured (supports multi-provider)
- if c.modelName != "" {
- opts = append(opts, ai.WithModelName(c.modelName))
- }
-
- // Add RAG documents if available
- if len(ragDocs) > 0 {
- opts = append(opts, ai.WithDocs(ragDocs...))
+ if a.modelName != "" {
+ opts = append(opts, ai.WithModelName(a.modelName))
}
// Add streaming callback if provided
@@ -332,84 +319,31 @@ func (c *Chat) generateResponse(ctx context.Context, input string, historyMessag
}
// Diagnostic logging (using cached toolNames - zero allocation)
- c.logger.Debug("executing prompt",
- "toolCount", len(c.tools),
- "tools", c.toolNames,
- "maxTurns", c.maxTurns,
+ a.logger.Debug("executing prompt",
+ "toolCount", len(a.tools),
+ "tools", a.toolNames,
+ "maxTurns", a.maxTurns,
"queryLength", len(input),
)
// Check circuit breaker before attempting request
- if err := c.circuitBreaker.Allow(); err != nil {
- c.logger.Warn("circuit breaker is open, rejecting request",
- "state", c.circuitBreaker.State().String())
+ if err := a.circuitBreaker.Allow(); err != nil {
+ a.logger.Warn("circuit breaker is open, rejecting request",
+ "state", a.circuitBreaker.State().String())
return nil, fmt.Errorf("service unavailable: %w", err)
}
// Execute prompt with retry mechanism
- resp, err := c.executeWithRetry(ctx, opts)
+ resp, err := a.executeWithRetry(ctx, opts)
if err != nil {
- c.circuitBreaker.Failure()
+ a.circuitBreaker.Failure()
return nil, err
}
- c.circuitBreaker.Success()
+ a.circuitBreaker.Success()
return resp, nil
}
-// ragRetrievalTimeout is the maximum time allowed for RAG document retrieval.
-// This prevents slow queries from blocking the entire chat request.
-const ragRetrievalTimeout = 5 * time.Second
-
-// retrieveRAGContext retrieves relevant documents from the knowledge base.
-// Returns empty slice on error (graceful degradation).
-func (c *Chat) retrieveRAGContext(ctx context.Context, query string) []*ai.Document {
- // Skip RAG if topK is not configured or zero
- if c.ragTopK <= 0 {
- return nil
- }
-
- // Add dedicated timeout for RAG retrieval to prevent slow queries
- // from blocking the entire chat request
- ragCtx, cancel := context.WithTimeout(ctx, ragRetrievalTimeout)
- defer cancel()
-
- // Build retriever request with source_type filter for files (documents)
- req := &ai.RetrieverRequest{
- Query: ai.DocumentFromText(query, nil),
- Options: &postgresql.RetrieverOptions{
- Filter: "source_type = '" + rag.SourceTypeFile + "'",
- K: c.ragTopK,
- },
- }
-
- // Retrieve documents with timeout
- resp, err := c.retriever.Retrieve(ragCtx, req)
- if err != nil {
- // Use Debug for expected errors (timeout, cancellation)
- // Use Warn for unexpected errors (DB issues, etc.) that ops should know about
- if ctx.Err() != nil || ragCtx.Err() != nil {
- c.logger.Debug("RAG retrieval canceled or timed out (continuing without context)",
- "error", err,
- "timeout", ragRetrievalTimeout,
- "query_length", len(query))
- } else {
- c.logger.Warn("RAG retrieval failed (continuing without context)",
- "error", err,
- "query_length", len(query))
- }
- return nil
- }
-
- if len(resp.Documents) > 0 {
- c.logger.Debug("retrieved RAG context",
- "document_count", len(resp.Documents),
- "query_length", len(query))
- }
-
- return resp.Documents
-}
-
// deepCopyMessages creates independent copies of Message and Part structs.
//
// WORKAROUND: Genkit's renderMessages() modifies msg.Content in-place,
@@ -420,7 +354,7 @@ func (c *Chat) retrieveRAGContext(ctx context.Context, query string) []*ai.Docum
//
// To remove this workaround:
// 1. Upgrade Genkit: go get -u github.com/firebase/genkit/go@latest
-// 2. Run: go test -race ./internal/agent/chat/...
+// 2. Run: go test -race ./internal/chat/...
// 3. If race detector passes, remove deepCopyMessages() calls
// 4. If race still fails, update version in this comment
func deepCopyMessages(msgs []*ai.Message) []*ai.Message {
@@ -492,3 +426,56 @@ func shallowCopyMap(m map[string]any) map[string]any {
}
return cp
}
+
+// Title generation constants.
+const (
+ titleMaxLength = 50
+ titleGenerationTimeout = 5 * time.Second
+ titleInputMaxRunes = 500
+)
+
+const titlePrompt = `Generate a concise title (max 50 characters) for a chat session based on this first message.
+The title should capture the main topic or intent.
+Return ONLY the title text, no quotes, no explanations, no punctuation at the end.
+
+Message: %s
+
+Title:`
+
+// GenerateTitle generates a concise session title from the user's first message.
+// It asks the model for a title and truncates results longer than titleMaxLength runes.
+// Returns an empty string on failure so the caller can apply its own fallback (best-effort).
+func (a *Agent) GenerateTitle(ctx context.Context, userMessage string) string {
+ ctx, cancel := context.WithTimeout(ctx, titleGenerationTimeout)
+ defer cancel()
+
+ inputRunes := []rune(userMessage)
+ if len(inputRunes) > titleInputMaxRunes {
+ userMessage = string(inputRunes[:titleInputMaxRunes]) + "..."
+ }
+
+ opts := []ai.GenerateOption{
+ ai.WithPrompt(titlePrompt, userMessage),
+ }
+ if a.modelName != "" {
+ opts = append(opts, ai.WithModelName(a.modelName))
+ }
+
+ response, err := genkit.Generate(ctx, a.g, opts...)
+ if err != nil {
+ a.logger.Debug("AI title generation failed", "error", err)
+ return ""
+ }
+
+ title := strings.TrimSpace(response.Text())
+ if title == "" {
+ return ""
+ }
+
+ titleRunes := []rune(title)
+ if len(titleRunes) > titleMaxLength {
+ title = string(titleRunes[:titleMaxLength-3]) + "..."
+ }
+
+ return title
+}
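+
+// Editor's illustrative sketch (not part of this change): GenerateTitle is best-effort,
+// so one reasonable pattern is for the caller to fall back to a truncated copy of the
+// first message when it returns "". The firstMessage variable below is a hypothetical
+// name used only for illustration.
+//
+//	title := agent.GenerateTitle(ctx, firstMessage)
+//	if title == "" {
+//	    title = firstMessage
+//	    if len([]rune(title)) > titleMaxLength {
+//	        title = string([]rune(title)[:titleMaxLength-3]) + "..."
+//	    }
+//	}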
diff --git a/internal/chat/chat_test.go b/internal/chat/chat_test.go
new file mode 100644
index 0000000..f0b4c70
--- /dev/null
+++ b/internal/chat/chat_test.go
@@ -0,0 +1,310 @@
+package chat
+
+import (
+ "log/slog"
+ "strings"
+ "testing"
+
+ "github.com/firebase/genkit/go/ai"
+ "github.com/firebase/genkit/go/genkit"
+
+ "github.com/koopa0/koopa/internal/session"
+)
+
+// TestConfig_validate tests that each validation check in Config.validate()
+// fires independently. Each case provides enough deps to pass prior checks.
+func TestConfig_validate(t *testing.T) {
+ t.Parallel()
+
+ // Minimal non-nil stubs — validate() only checks nil, never dereferences.
+ stubG := new(genkit.Genkit)
+ stubS := new(session.Store)
+ stubL := slog.New(slog.DiscardHandler)
+
+ tests := []struct {
+ name string
+ cfg Config
+ errContains string
+ }{
+ {
+ name: "nil genkit",
+ cfg: Config{},
+ errContains: "genkit instance is required",
+ },
+ {
+ name: "nil session store",
+ cfg: Config{
+ Genkit: stubG,
+ },
+ errContains: "session store is required",
+ },
+ {
+ name: "nil logger",
+ cfg: Config{
+ Genkit: stubG,
+ SessionStore: stubS,
+ },
+ errContains: "logger is required",
+ },
+ {
+ name: "empty tools",
+ cfg: Config{
+ Genkit: stubG,
+ SessionStore: stubS,
+ Logger: stubL,
+ Tools: []ai.Tool{},
+ },
+ errContains: "at least one tool is required",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ err := tt.cfg.validate()
+ if err == nil {
+ t.Fatal("validate() expected error, got nil")
+ }
+ if !strings.Contains(err.Error(), tt.errContains) {
+ t.Errorf("validate() error = %q, want to contain %q", err.Error(), tt.errContains)
+ }
+ })
+ }
+}
+
+func TestDeepCopyMessages_NilInput(t *testing.T) {
+ t.Parallel()
+ got := deepCopyMessages(nil)
+ if got != nil {
+ t.Errorf("deepCopyMessages(nil) = %v, want nil", got)
+ }
+}
+
+func TestDeepCopyMessages_EmptySlice(t *testing.T) {
+ t.Parallel()
+ got := deepCopyMessages([]*ai.Message{})
+ if got == nil {
+ t.Fatal("deepCopyMessages(empty) = nil, want non-nil empty slice")
+ }
+ if len(got) != 0 {
+ t.Errorf("deepCopyMessages(empty) len = %d, want 0", len(got))
+ }
+}
+
+func TestDeepCopyMessages_MutateOriginalText(t *testing.T) {
+ t.Parallel()
+
+ original := []*ai.Message{
+ ai.NewUserMessage(ai.NewTextPart("hello world")),
+ }
+
+ copied := deepCopyMessages(original)
+
+ // Mutate the original message's content slice
+ original[0].Content[0].Text = "MUTATED"
+
+ if copied[0].Content[0].Text != "hello world" {
+ t.Errorf("deepCopyMessages() copy was affected by original mutation: got %q, want %q",
+ copied[0].Content[0].Text, "hello world")
+ }
+}
+
+func TestDeepCopyMessages_MutateOriginalContentSlice(t *testing.T) {
+ t.Parallel()
+
+ original := []*ai.Message{
+ ai.NewUserMessage(ai.NewTextPart("first"), ai.NewTextPart("second")),
+ }
+
+ copied := deepCopyMessages(original)
+
+ // Append to original's content slice — should not affect copy
+ original[0].Content = append(original[0].Content, ai.NewTextPart("third"))
+
+ if len(copied[0].Content) != 2 {
+ t.Errorf("deepCopyMessages() copy content len = %d, want 2", len(copied[0].Content))
+ }
+}
+
+func TestDeepCopyMessages_PreservesRole(t *testing.T) {
+ t.Parallel()
+
+ original := []*ai.Message{
+ ai.NewUserMessage(ai.NewTextPart("q")),
+ ai.NewModelMessage(ai.NewTextPart("a")),
+ }
+
+ copied := deepCopyMessages(original)
+
+ if copied[0].Role != ai.RoleUser {
+ t.Errorf("deepCopyMessages()[0].Role = %q, want %q", copied[0].Role, ai.RoleUser)
+ }
+ if copied[1].Role != ai.RoleModel {
+ t.Errorf("deepCopyMessages()[1].Role = %q, want %q", copied[1].Role, ai.RoleModel)
+ }
+}
+
+func TestDeepCopyMessages_Metadata(t *testing.T) {
+ t.Parallel()
+
+ original := []*ai.Message{{
+ Role: ai.RoleUser,
+ Content: []*ai.Part{ai.NewTextPart("test")},
+ Metadata: map[string]any{"key": "value"},
+ }}
+
+ copied := deepCopyMessages(original)
+
+ // Mutate original metadata
+ original[0].Metadata["key"] = "MUTATED"
+
+ if copied[0].Metadata["key"] != "value" {
+ t.Errorf("deepCopyMessages() metadata was affected by mutation: got %q, want %q",
+ copied[0].Metadata["key"], "value")
+ }
+}
+
+func TestDeepCopyPart_NilInput(t *testing.T) {
+ t.Parallel()
+ got := deepCopyPart(nil)
+ if got != nil {
+ t.Errorf("deepCopyPart(nil) = %v, want nil", got)
+ }
+}
+
+func TestDeepCopyPart_TextPart(t *testing.T) {
+ t.Parallel()
+
+ original := ai.NewTextPart("hello")
+ copied := deepCopyPart(original)
+
+ original.Text = "MUTATED"
+
+ if copied.Text != "hello" {
+ t.Errorf("deepCopyPart() text affected by mutation: got %q, want %q", copied.Text, "hello")
+ }
+}
+
+func TestDeepCopyPart_ToolRequest(t *testing.T) {
+ t.Parallel()
+
+ original := &ai.Part{
+ Kind: ai.PartToolRequest,
+ ToolRequest: &ai.ToolRequest{
+ Name: "read_file",
+ Input: map[string]any{"path": "/tmp/test"},
+ },
+ }
+
+ copied := deepCopyPart(original)
+
+ // Mutate original ToolRequest name
+ original.ToolRequest.Name = "MUTATED"
+
+ if copied.ToolRequest.Name != "read_file" {
+ t.Errorf("deepCopyPart() ToolRequest.Name affected by mutation: got %q, want %q",
+ copied.ToolRequest.Name, "read_file")
+ }
+}
+
+func TestDeepCopyPart_ToolResponse(t *testing.T) {
+ t.Parallel()
+
+ original := &ai.Part{
+ Kind: ai.PartToolResponse,
+ ToolResponse: &ai.ToolResponse{
+ Name: "read_file",
+ Output: "file contents",
+ },
+ }
+
+ copied := deepCopyPart(original)
+
+ original.ToolResponse.Name = "MUTATED"
+
+ if copied.ToolResponse.Name != "read_file" {
+ t.Errorf("deepCopyPart() ToolResponse.Name affected by mutation: got %q, want %q",
+ copied.ToolResponse.Name, "read_file")
+ }
+}
+
+func TestDeepCopyPart_Resource(t *testing.T) {
+ t.Parallel()
+
+ original := &ai.Part{
+ Kind: ai.PartMedia,
+ Resource: &ai.ResourcePart{Uri: "https://example.com/image.png"},
+ }
+
+ copied := deepCopyPart(original)
+
+ original.Resource.Uri = "MUTATED"
+
+ if copied.Resource.Uri != "https://example.com/image.png" {
+ t.Errorf("deepCopyPart() Resource.Uri affected by mutation: got %q, want %q",
+ copied.Resource.Uri, "https://example.com/image.png")
+ }
+}
+
+func TestDeepCopyPart_PartMetadata(t *testing.T) {
+ t.Parallel()
+
+ original := &ai.Part{
+ Kind: ai.PartText,
+ Text: "test",
+ Custom: map[string]any{"c": "custom"},
+ Metadata: map[string]any{"m": "meta"},
+ }
+
+ copied := deepCopyPart(original)
+
+ original.Custom["c"] = "MUTATED"
+ original.Metadata["m"] = "MUTATED"
+
+ if copied.Custom["c"] != "custom" {
+ t.Errorf("deepCopyPart() Custom map affected: got %q, want %q", copied.Custom["c"], "custom")
+ }
+ if copied.Metadata["m"] != "meta" {
+ t.Errorf("deepCopyPart() Metadata map affected: got %q, want %q", copied.Metadata["m"], "meta")
+ }
+}
+
+func TestShallowCopyMap_NilInput(t *testing.T) {
+ t.Parallel()
+ got := shallowCopyMap(nil)
+ if got != nil {
+ t.Errorf("shallowCopyMap(nil) = %v, want nil", got)
+ }
+}
+
+func TestShallowCopyMap_IndependentKeys(t *testing.T) {
+ t.Parallel()
+
+ original := map[string]any{"a": "1", "b": "2"}
+ copied := shallowCopyMap(original)
+
+ // Add new key to original
+ original["c"] = "3"
+
+ if _, ok := copied["c"]; ok {
+ t.Error("shallowCopyMap() new key in original appeared in copy")
+ }
+ if len(copied) != 2 {
+ t.Errorf("shallowCopyMap() copy len = %d, want 2", len(copied))
+ }
+}
+
+func TestShallowCopyMap_MutateValue(t *testing.T) {
+ t.Parallel()
+
+ original := map[string]any{"key": "value"}
+ copied := shallowCopyMap(original)
+
+ // Overwrite original value
+ original["key"] = "MUTATED"
+
+ if copied["key"] != "value" {
+ t.Errorf("shallowCopyMap() value affected by mutation: got %q, want %q",
+ copied["key"], "value")
+ }
+}
diff --git a/internal/agent/chat/circuit.go b/internal/chat/circuit.go
similarity index 100%
rename from internal/agent/chat/circuit.go
rename to internal/chat/circuit.go
diff --git a/internal/agent/chat/circuit_test.go b/internal/chat/circuit_test.go
similarity index 94%
rename from internal/agent/chat/circuit_test.go
rename to internal/chat/circuit_test.go
index c9a470f..4fad3b7 100644
--- a/internal/agent/chat/circuit_test.go
+++ b/internal/chat/circuit_test.go
@@ -243,21 +243,21 @@ func TestCircuitState_String(t *testing.T) {
t.Parallel()
tests := []struct {
- state CircuitState
- expected string
+ state CircuitState
+ want string
}{
- {CircuitClosed, "closed"},
- {CircuitOpen, "open"},
- {CircuitHalfOpen, "half-open"},
- {CircuitState(99), "unknown"},
+ {state: CircuitClosed, want: "closed"},
+ {state: CircuitOpen, want: "open"},
+ {state: CircuitHalfOpen, want: "half-open"},
+ {state: CircuitState(99), want: "unknown"},
}
for _, tt := range tests {
- t.Run(tt.expected, func(t *testing.T) {
+ t.Run(tt.want, func(t *testing.T) {
t.Parallel()
- if got := tt.state.String(); got != tt.expected {
- t.Errorf("String() = %q, want %q", got, tt.expected)
+ if got := tt.state.String(); got != tt.want {
+ t.Errorf("String() = %q, want %q", got, tt.want)
}
})
}
diff --git a/internal/agent/chat/doc.go b/internal/chat/doc.go
similarity index 74%
rename from internal/agent/chat/doc.go
rename to internal/chat/doc.go
index e0c5187..0f00798 100644
--- a/internal/agent/chat/doc.go
+++ b/internal/chat/doc.go
@@ -1,17 +1,17 @@
// Package chat implements Koopa's main conversational agent.
//
-// Chat is a stateless, LLM-powered agent that provides conversational interactions
+// Agent is a stateless, LLM-powered agent that provides conversational interactions
// with tool calling and knowledge base integration. It uses the Google Genkit framework
// for LLM inference and tool orchestration.
//
// # Architecture
//
-// The Chat agent follows a stateless design pattern with dependency injection:
+// The Agent follows a stateless design pattern with dependency injection:
//
// InvocationContext (input)
// |
// v
-// Chat.ExecuteStream() or Chat.Execute()
+// Agent.ExecuteStream() or Agent.Execute()
// |
// +-- Load session history from SessionStore
// |
@@ -30,27 +30,33 @@
//
// # Configuration
//
-// Chat requires configuration via the Config struct at construction time:
+// Agent requires configuration via the Config struct at construction time:
//
// type Config struct {
// Genkit *genkit.Genkit
-// Retriever ai.Retriever
// SessionStore *session.Store
-// Logger log.Logger
+// Logger *slog.Logger
// Tools []ai.Tool
//
// // Configuration values
// ModelName string // e.g., "googleai/gemini-2.5-flash"
// MaxTurns int // Maximum agentic loop turns
-// RAGTopK int // Number of RAG documents to retrieve
// Language string // Response language preference
+//
+// // Resilience configuration
+// RetryConfig RetryConfig
+// CircuitBreakerConfig CircuitBreakerConfig
+// RateLimiter *rate.Limiter
+//
+// // Token management
+// TokenBudget TokenBudget
// }
//
// Required fields are validated during construction.
//
// # Streaming Support
//
-// Chat supports both streaming and non-streaming execution modes:
+// Agent supports both streaming and non-streaming execution modes:
//
// - Execute(): Non-streaming, returns complete response
// - ExecuteStream(): Streaming with optional callback for real-time output
@@ -66,18 +72,14 @@
//
// The package provides a Genkit Flow for HTTP and observability:
//
-// - InitFlow(): Initializes singleton streaming Flow (must be called once at startup)
-// - GetFlow(): Returns initialized Flow (panics if InitFlow not called)
+// - NewFlow(): Returns singleton streaming Flow (idempotent, safe to call multiple times)
// - Flow supports both Run() and Stream() methods
// - Stream() enables Server-Sent Events (SSE) for real-time responses
//
// Example Flow usage:
//
-// // Initialize Flow once during application startup
-// chatFlow, err := chat.InitFlow(g, chatAgent)
-// if err != nil {
-// return err
-// }
+// // Create Flow (idempotent — first call initializes, subsequent calls return cached)
+// chatFlow := chat.NewFlow(g, chatAgent)
//
// // Non-streaming
// output, err := chatFlow.Run(ctx, chat.Input{Query: "Hello", SessionID: "..."})
@@ -113,16 +115,14 @@
//
// # Example Usage
//
-// // Create Chat agent with required configuration
-// chatAgent, err := chat.New(chat.Config{
+// // Create agent with required configuration
+// agent, err := chat.New(chat.Config{
// Genkit: g,
-// Retriever: retriever,
// SessionStore: sessionStore,
// Logger: slog.Default(),
// Tools: tools,
// ModelName: "googleai/gemini-2.5-flash",
// MaxTurns: 10,
-// RAGTopK: 5,
// Language: "auto",
// })
// if err != nil {
@@ -130,10 +130,10 @@
// }
//
// // Non-streaming execution
-// resp, err := chatAgent.Execute(ctx, sessionID, "What is the weather?")
+// resp, err := agent.Execute(ctx, sessionID, "What is the weather?")
//
// // Streaming execution with callback
-// resp, err := chatAgent.ExecuteStream(ctx, sessionID, "What is the weather?",
+// resp, err := agent.ExecuteStream(ctx, sessionID, "What is the weather?",
// func(ctx context.Context, chunk *ai.ModelResponseChunk) error {
// fmt.Print(chunk.Text()) // Real-time output
// return nil
@@ -143,14 +143,14 @@
//
// The package uses sentinel errors for categorization:
//
-// - agent.ErrInvalidSession: Invalid session ID format
-// - agent.ErrExecutionFailed: LLM or tool execution failed
+// - ErrInvalidSession: Invalid session ID format
+// - ErrExecutionFailed: LLM or tool execution failed
//
// Empty responses are handled with a fallback message to improve UX.
//
// # Testing
//
-// Chat is designed for testability:
+// Agent is designed for testability:
//
// - Dependencies are concrete types with clear interfaces
// - Stateless design eliminates test ordering issues
@@ -158,6 +158,6 @@
//
// # Thread Safety
//
-// Chat is safe for concurrent use. The underlying dependencies (SessionStore,
+// Agent is safe for concurrent use. The underlying dependencies (SessionStore,
// Genkit) must also be thread-safe.
package chat
diff --git a/internal/agent/chat/flow.go b/internal/chat/flow.go
similarity index 67%
rename from internal/agent/chat/flow.go
rename to internal/chat/flow.go
index 145c4c6..59423fe 100644
--- a/internal/agent/chat/flow.go
+++ b/internal/chat/flow.go
@@ -1,27 +1,23 @@
-// Package chat provides Flow definition for Chat Agent
package chat
import (
"context"
"fmt"
"sync"
- "sync/atomic"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/genkit"
"github.com/google/uuid"
-
- "github.com/koopa0/koopa/internal/agent"
)
-// Input is the input for Chat Agent Flow
+// Input defines the request payload for the chat agent flow.
type Input struct {
Query string `json:"query"`
SessionID string `json:"sessionId"` // Required field: session ID
}
-// Output is the output for Chat Agent Flow
+// Output defines the response payload from the chat agent flow.
type Output struct {
Response string `json:"response"`
SessionID string `json:"sessionId"`
@@ -40,39 +36,20 @@ const FlowName = "koopa/chat"
// Exported for use in api package with genkit.Handler().
type Flow = core.Flow[Input, Output, StreamChunk]
-// Package-level singleton for Flow to prevent Panic on re-registration.
-// sync.Once ensures the Flow is defined only once, even in tests.
+// Package-level singleton for Flow, preventing a panic on re-registration.
+// sync.Once ensures genkit.DefineStreamingFlow is called only once.
var (
- flowOnce sync.Once
- flow *Flow
- flowInitDone atomic.Bool
+ flowOnce sync.Once
+ flow *Flow
)
-// InitFlow initializes the Chat Flow singleton.
-// Must be called exactly once during application startup.
-// Returns error if called more than once.
-//
-// This explicit API prevents the dangerous pattern where parameters
-// are silently ignored on subsequent calls (as in the old GetFlow).
-func InitFlow(g *genkit.Genkit, chatAgent *Chat) (*Flow, error) {
- var initialized bool
+// NewFlow returns the Chat Flow singleton, initializing it on first call.
+// Subsequent calls return the existing Flow and ignore their parameters; the
+// cached Flow must be reused because genkit.DefineStreamingFlow panics on re-registration.
+func NewFlow(g *genkit.Genkit, agent *Agent) *Flow {
flowOnce.Do(func() {
- flow = chatAgent.DefineFlow(g)
- flowInitDone.Store(true)
- initialized = true
+ flow = agent.DefineFlow(g)
})
- if !initialized && flowInitDone.Load() {
- return nil, fmt.Errorf("InitFlow called more than once")
- }
- return flow, nil
-}
-
-// GetFlow returns the initialized Flow singleton.
-// Panics if InitFlow was not called - this indicates a programming error.
-func GetFlow() *Flow {
- if !flowInitDone.Load() {
- panic("GetFlow called before InitFlow")
- }
return flow
}
@@ -82,14 +59,13 @@ func GetFlow() *Flow {
func ResetFlowForTesting() {
flowOnce = sync.Once{}
flow = nil
- flowInitDone.Store(false)
}
// DefineFlow defines the Genkit Streaming Flow for Chat Agent.
// Supports both streaming (via callback) and non-streaming modes.
//
-// IMPORTANT: Use GetFlow() instead of calling DefineFlow() directly.
-// DefineFlow registers a global Flow; calling it twice causes Panic.
+// IMPORTANT: Use NewFlow() instead of calling DefineFlow() directly.
+// DefineFlow registers a global Flow; calling it twice causes a panic.
//
// Each Agent has its own dedicated Flow, responsible for:
// 1. Observability (Genkit DevUI tracing)
@@ -103,15 +79,13 @@ func ResetFlowForTesting() {
// - Errors are now properly returned using sentinel errors from agent package
// - Genkit tracing will correctly show error spans
// - HTTP handlers can use errors.Is() to determine error type and HTTP status
-//
-//nolint:gocognit // Genkit Flow requires orchestration logic in single function
-func (c *Chat) DefineFlow(g *genkit.Genkit) *Flow {
+func (a *Agent) DefineFlow(g *genkit.Genkit) *Flow {
return genkit.DefineStreamingFlow(g, FlowName,
func(ctx context.Context, input Input, streamCb func(context.Context, StreamChunk) error) (Output, error) {
// Parse session ID from input
sessionID, err := uuid.Parse(input.SessionID)
if err != nil {
- return Output{SessionID: input.SessionID}, fmt.Errorf("%w: %w", agent.ErrInvalidSession, err)
+ return Output{SessionID: input.SessionID}, fmt.Errorf("%w: %w", ErrInvalidSession, err)
}
// Create StreamCallback wrapper if streaming is enabled
@@ -135,10 +109,10 @@ func (c *Chat) DefineFlow(g *genkit.Genkit) *Flow {
}
// Execute with streaming callback (or non-streaming if callback is nil)
- resp, err := c.ExecuteStream(ctx, sessionID, input.Query, agentCallback)
+ resp, err := a.ExecuteStream(ctx, sessionID, input.Query, agentCallback)
if err != nil {
// Genkit will mark this span as failed, enabling proper observability
- return Output{SessionID: input.SessionID}, fmt.Errorf("%w: %w", agent.ErrExecutionFailed, err)
+ return Output{SessionID: input.SessionID}, fmt.Errorf("%w: %w", ErrExecutionFailed, err)
}
return Output{
diff --git a/internal/chat/flow_test.go b/internal/chat/flow_test.go
new file mode 100644
index 0000000..c18251b
--- /dev/null
+++ b/internal/chat/flow_test.go
@@ -0,0 +1,52 @@
+package chat
+
+import (
+ "errors"
+ "testing"
+)
+
+// TestSentinelErrors_CanBeChecked tests that sentinel errors work correctly with errors.Is
+func TestSentinelErrors_CanBeChecked(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ err error
+ sentinel error
+ }{
+ {name: "ErrInvalidSession", err: ErrInvalidSession, sentinel: ErrInvalidSession},
+ {name: "ErrExecutionFailed", err: ErrExecutionFailed, sentinel: ErrExecutionFailed},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if !errors.Is(tt.err, tt.sentinel) {
+ t.Errorf("errors.Is(%v, %v) = false, want true", tt.err, tt.sentinel)
+ }
+ })
+ }
+}
+
+// TestWrappedErrors_PreserveSentinel tests that wrapped errors preserve sentinel checking
+func TestWrappedErrors_PreserveSentinel(t *testing.T) {
+ t.Parallel()
+
+ t.Run("wrapped invalid session error", func(t *testing.T) {
+ t.Parallel()
+ err := errors.New("original error")
+ wrapped := errors.Join(ErrInvalidSession, err)
+ if !errors.Is(wrapped, ErrInvalidSession) {
+ t.Errorf("errors.Is(wrapped, ErrInvalidSession) = false, want true")
+ }
+ })
+
+ t.Run("wrapped execution failed error", func(t *testing.T) {
+ t.Parallel()
+ err := errors.New("LLM timeout")
+ wrapped := errors.Join(ErrExecutionFailed, err)
+ if !errors.Is(wrapped, ErrExecutionFailed) {
+ t.Errorf("errors.Is(wrapped, ErrExecutionFailed) = false, want true")
+ }
+ })
+}
diff --git a/internal/agent/chat/integration_rag_test.go b/internal/chat/integration_rag_test.go
similarity index 71%
rename from internal/agent/chat/integration_rag_test.go
rename to internal/chat/integration_rag_test.go
index 73582aa..c825f00 100644
--- a/internal/agent/chat/integration_rag_test.go
+++ b/internal/chat/integration_rag_test.go
@@ -13,10 +13,6 @@ import (
"github.com/koopa0/koopa/internal/rag"
)
-// =============================================================================
-// Phase 0.2: RAG Integration Tests
-// =============================================================================
-
// TestChatAgent_RAGIntegration_EndToEnd verifies the complete RAG workflow:
// 1. Index document into knowledge store
// 2. Query triggers retrieval
@@ -25,16 +21,9 @@ import (
// Per Proposal 030: topK > 0 with real retriever, verify documents returned
// and integrated into prompt.
func TestChatAgent_RAGIntegration_EndToEnd(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
-
+ framework := SetupTest(t)
ctx := context.Background()
- // Ensure RAG is enabled
- if framework.Config.RAGTopK <= 0 {
- t.Fatal("RAG must be enabled for this test (RAGTopK > 0)")
- }
-
// STEP 1: Index test document using DocStore
docID := uuid.New()
testContent := "The secret password is KOOPA_TEST_123. This is a unique test string."
@@ -46,8 +35,7 @@ func TestChatAgent_RAGIntegration_EndToEnd(t *testing.T) {
t.Logf("Indexed document %s with test content", docID)
// STEP 2: Query should trigger RAG retrieval
- invCtx, sessionID := newInvocationContext(ctx, framework.SessionID)
- resp, err := framework.Agent.ExecuteStream(invCtx, sessionID,
+ resp, err := framework.Agent.ExecuteStream(ctx, framework.SessionID,
"What is the secret password from the test document?",
nil,
)
@@ -73,9 +61,7 @@ func TestChatAgent_RAGIntegration_EndToEnd(t *testing.T) {
// returns documents when topK > 0.
// Per Proposal 030: Verify documents returned and integrated into prompt.
func TestRetrieveRAGContext_ActualRetrieval(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
-
+ framework := SetupTest(t)
ctx := context.Background()
// Index multiple documents for retrieval
@@ -95,8 +81,7 @@ func TestRetrieveRAGContext_ActualRetrieval(t *testing.T) {
t.Log("Indexed 2 test documents")
// Query that should trigger retrieval
- invCtx, sessionID := newInvocationContext(ctx, framework.SessionID)
- resp, err := framework.Agent.ExecuteStream(invCtx, sessionID,
+ resp, err := framework.Agent.ExecuteStream(ctx, framework.SessionID,
"Tell me about Go programming language",
nil,
)
@@ -119,62 +104,14 @@ func TestRetrieveRAGContext_ActualRetrieval(t *testing.T) {
t.Logf("Response with RAG: %s", resp.FinalText)
}
-// TestRetrieveRAGContext_DisabledWhenTopKZero verifies that RAG is skipped
-// when topK is 0.
-func TestRetrieveRAGContext_DisabledWhenTopKZero(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
-
- ctx := context.Background()
-
- // Override config to disable RAG
- originalRAGTopK := framework.Config.RAGTopK
- framework.Config.RAGTopK = 0
- defer func() { framework.Config.RAGTopK = originalRAGTopK }()
-
- // Index document (should NOT be retrieved)
- docID := uuid.New()
- framework.IndexDocument(t, "This content should be ignored because RAG is disabled.", map[string]any{
- "id": docID.String(),
- "filename": "ignored.txt",
- "source_type": rag.SourceTypeFile,
- })
-
- // Query - should NOT trigger retrieval
- invCtx, sessionID := newInvocationContext(ctx, framework.SessionID)
- resp, err := framework.Agent.ExecuteStream(invCtx, sessionID,
- "What does the ignored document say?",
- nil,
- )
-
- if err != nil {
- t.Fatalf("ExecuteStream() unexpected error: %v", err)
- }
- if resp == nil {
- t.Fatal("ExecuteStream() response is nil, want non-nil when error is nil")
- }
- if resp.FinalText == "" {
- t.Error("ExecuteStream() response.FinalText is empty, want non-empty")
- }
-
- // Response should NOT contain the ignored content
- if strings.Contains(resp.FinalText, "should be ignored") {
- t.Errorf("ExecuteStream() response = %q, should not contain %q (RAG disabled)", resp.FinalText, "should be ignored")
- }
- t.Logf("Response without RAG (topK=0): %s", resp.FinalText)
-}
-
// TestRetrieveRAGContext_EmptyKnowledgeBase verifies graceful handling
// when knowledge base has no matching documents.
func TestRetrieveRAGContext_EmptyKnowledgeBase(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
-
+ framework := SetupTest(t)
ctx := context.Background()
// Query with empty knowledge base (no documents indexed)
- invCtx, sessionID := newInvocationContext(ctx, framework.SessionID)
- resp, err := framework.Agent.ExecuteStream(invCtx, sessionID,
+ resp, err := framework.Agent.ExecuteStream(ctx, framework.SessionID,
"What is in the knowledge base?",
nil,
)
@@ -194,9 +131,7 @@ func TestRetrieveRAGContext_EmptyKnowledgeBase(t *testing.T) {
// TestRetrieveRAGContext_MultipleRelevantDocuments verifies that when multiple
// documents match the query, the RAG system retrieves and uses them.
func TestRetrieveRAGContext_MultipleRelevantDocuments(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
-
+ framework := SetupTest(t)
ctx := context.Background()
// Index 5 related documents
@@ -222,8 +157,7 @@ func TestRetrieveRAGContext_MultipleRelevantDocuments(t *testing.T) {
t.Logf("Indexed %d related documents", len(topics))
// Query should retrieve multiple relevant documents
- invCtx, sessionID := newInvocationContext(ctx, framework.SessionID)
- resp, err := framework.Agent.ExecuteStream(invCtx, sessionID,
+ resp, err := framework.Agent.ExecuteStream(ctx, framework.SessionID,
"Summarize the key features of Go programming language",
nil,
)
diff --git a/internal/agent/chat/integration_streaming_test.go b/internal/chat/integration_streaming_test.go
similarity index 85%
rename from internal/agent/chat/integration_streaming_test.go
rename to internal/chat/integration_streaming_test.go
index 8ec1b04..784c94c 100644
--- a/internal/agent/chat/integration_streaming_test.go
+++ b/internal/chat/integration_streaming_test.go
@@ -12,19 +12,15 @@ import (
"github.com/firebase/genkit/go/ai"
)
-// =============================================================================
-// Phase 0.3: Streaming Integration Tests
-// =============================================================================
-
// TestChatAgent_StreamingCallbackError verifies that errors from streaming
// callbacks are properly propagated and stop the stream.
// Per Proposal 030: Callback returns error after N chunks, verify stream stops,
// error propagated.
func TestChatAgent_StreamingCallbackError(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
+ framework := SetupTest(t)
+ ctx := context.Background()
+ sessionID := framework.SessionID
- ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID)
chunks := 0
maxChunks := 3
@@ -66,10 +62,10 @@ func TestChatAgent_StreamingCallbackError(t *testing.T) {
// TestChatAgent_StreamingCallbackSuccess verifies that streaming works
// correctly when callback always succeeds.
func TestChatAgent_StreamingCallbackSuccess(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
+ framework := SetupTest(t)
+ ctx := context.Background()
+ sessionID := framework.SessionID
- ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID)
chunks := 0
var receivedTexts []string
@@ -105,9 +101,7 @@ func TestChatAgent_StreamingCallbackSuccess(t *testing.T) {
// TestChatAgent_StreamingVsNonStreaming verifies that streaming and non-streaming
// modes produce equivalent results.
func TestChatAgent_StreamingVsNonStreaming(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
-
+ framework := SetupTest(t)
ctx := context.Background()
query := "What is 2+2? Answer with just the number."
@@ -116,8 +110,7 @@ func TestChatAgent_StreamingVsNonStreaming(t *testing.T) {
// This is a standard Go idiom (nil function = skip optional behavior).
// Contract: When callback is nil, the method returns only after full completion.
session1 := framework.CreateTestSession(t, "Non-streaming test")
- invCtx1, sessionID1 := newInvocationContext(ctx, session1)
- respNoStream, err := framework.Agent.ExecuteStream(invCtx1, sessionID1,
+ respNoStream, err := framework.Agent.ExecuteStream(ctx, session1,
query,
nil, // No callback = non-streaming mode (returns complete response)
)
@@ -133,14 +126,13 @@ func TestChatAgent_StreamingVsNonStreaming(t *testing.T) {
// Streaming execution
session2 := framework.CreateTestSession(t, "Streaming test")
- invCtx2, sessionID2 := newInvocationContext(ctx, session2)
var streamedResponse string
callback := func(ctx context.Context, chunk *ai.ModelResponseChunk) error {
streamedResponse += chunk.Text()
return nil
}
- respStream, err := framework.Agent.ExecuteStream(invCtx2, sessionID2,
+ respStream, err := framework.Agent.ExecuteStream(ctx, session2,
query,
callback,
)
@@ -169,17 +161,14 @@ func TestChatAgent_StreamingVsNonStreaming(t *testing.T) {
// TestChatAgent_StreamingContextCancellation verifies that canceling the
// context stops streaming.
func TestChatAgent_StreamingContextCancellation(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
+ framework := SetupTest(t)
// Cancel BEFORE starting stream to guarantee cancellation
// This is deterministic - no race between stream completion and cancellation
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
- _, sessionID := newInvocationContext(context.Background(), framework.SessionID)
-
- resp, err := framework.Agent.ExecuteStream(ctx, sessionID,
+ resp, err := framework.Agent.ExecuteStream(ctx, framework.SessionID,
"Write a very long story",
nil,
)
@@ -200,10 +189,10 @@ func TestChatAgent_StreamingContextCancellation(t *testing.T) {
// TestChatAgent_StreamingEmptyChunks verifies handling of empty chunks
// in streaming mode.
func TestChatAgent_StreamingEmptyChunks(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
+ framework := SetupTest(t)
+ ctx := context.Background()
+ sessionID := framework.SessionID
- ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID)
totalChunks := 0
emptyChunks := 0
diff --git a/internal/agent/chat/integration_test.go b/internal/chat/integration_test.go
similarity index 80%
rename from internal/agent/chat/integration_test.go
rename to internal/chat/integration_test.go
index e69e6e0..7d9906f 100644
--- a/internal/agent/chat/integration_test.go
+++ b/internal/chat/integration_test.go
@@ -16,15 +16,14 @@ import (
"github.com/firebase/genkit/go/ai"
- "github.com/koopa0/koopa/internal/agent/chat"
+ "github.com/koopa0/koopa/internal/chat"
)
// TestChatAgent_BasicExecution tests basic chat agent execution
func TestChatAgent_BasicExecution(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
-
- ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID)
+ framework := SetupTest(t)
+ ctx := context.Background()
+ sessionID := framework.SessionID
t.Run("simple question", func(t *testing.T) {
resp, err := framework.Agent.Execute(ctx, sessionID, "Hello, how are you?")
@@ -42,10 +41,9 @@ func TestChatAgent_BasicExecution(t *testing.T) {
// TestChatAgent_SessionPersistence tests conversation history persistence
func TestChatAgent_SessionPersistence(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
-
- ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID)
+ framework := SetupTest(t)
+ ctx := context.Background()
+ sessionID := framework.SessionID
t.Run("first message creates history", func(t *testing.T) {
resp, err := framework.Agent.Execute(ctx, sessionID, "My name is Koopa")
@@ -77,10 +75,9 @@ func TestChatAgent_SessionPersistence(t *testing.T) {
// TestChatAgent_ToolIntegration tests tool calling capability
func TestChatAgent_ToolIntegration(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
-
- ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID)
+ framework := SetupTest(t)
+ ctx := context.Background()
+ sessionID := framework.SessionID
t.Run("can use file tools", func(t *testing.T) {
// Create unique marker file to verify tool was actually invoked
@@ -114,19 +111,24 @@ func TestChatAgent_ToolIntegration(t *testing.T) {
// TestChatAgent_ErrorHandling tests error scenarios
func TestChatAgent_ErrorHandling(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
+ framework := SetupTest(t)
t.Run("handles empty input gracefully", func(t *testing.T) {
- ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID)
+ ctx := context.Background()
+ sessionID := framework.SessionID
resp, err := framework.Agent.Execute(ctx, sessionID, "")
- // Should handle empty input without crashing
- // Either returns error or empty response
- if err == nil {
- if resp == nil {
- t.Error("Execute() with empty input returned nil response and nil error")
- }
+ // Agent should handle empty input without panicking.
+ // The LLM may return a valid response or an error — both are acceptable.
+ if err != nil {
+ t.Logf("Execute(\"\") returned error (acceptable): %v", err)
+ return
+ }
+ if resp == nil {
+ t.Fatal("Execute(\"\") = nil, nil — want non-nil response or non-nil error")
+ }
+ if resp.FinalText == "" {
+ t.Error("Execute(\"\") response.FinalText is empty, want non-empty (at minimum the fallback message)")
}
})
}
@@ -134,12 +136,10 @@ func TestChatAgent_ErrorHandling(t *testing.T) {
// TestChatAgent_NewChatValidation tests constructor validation
func TestChatAgent_NewChatValidation(t *testing.T) {
// Setup test framework once for all validation tests
- framework, cleanup := SetupTest(t)
- defer cleanup()
+ framework := SetupTest(t)
t.Run("requires genkit", func(t *testing.T) {
_, err := chat.New(chat.Config{
- Retriever: framework.Retriever,
SessionStore: framework.SessionStore,
Logger: slog.Default(),
Tools: []ai.Tool{},
@@ -152,27 +152,11 @@ func TestChatAgent_NewChatValidation(t *testing.T) {
}
})
- t.Run("requires retriever", func(t *testing.T) {
- _, err := chat.New(chat.Config{
- Genkit: framework.Genkit,
- SessionStore: framework.SessionStore,
- Logger: slog.Default(),
- Tools: []ai.Tool{},
- })
- if err == nil {
- t.Fatal("New() expected error, got nil")
- }
- if !strings.Contains(err.Error(), "retriever is required") {
- t.Errorf("New() error = %q, want to contain %q", err.Error(), "retriever is required")
- }
- })
-
t.Run("requires session store", func(t *testing.T) {
_, err := chat.New(chat.Config{
- Genkit: framework.Genkit,
- Retriever: framework.Retriever,
- Logger: slog.Default(),
- Tools: []ai.Tool{},
+ Genkit: framework.Genkit,
+ Logger: slog.Default(),
+ Tools: []ai.Tool{},
})
if err == nil {
t.Fatal("New() expected error, got nil")
@@ -185,7 +169,6 @@ func TestChatAgent_NewChatValidation(t *testing.T) {
t.Run("requires logger", func(t *testing.T) {
_, err := chat.New(chat.Config{
Genkit: framework.Genkit,
- Retriever: framework.Retriever,
SessionStore: framework.SessionStore,
Tools: []ai.Tool{},
})
@@ -200,7 +183,6 @@ func TestChatAgent_NewChatValidation(t *testing.T) {
t.Run("requires at least one tool", func(t *testing.T) {
_, err := chat.New(chat.Config{
Genkit: framework.Genkit,
- Retriever: framework.Retriever,
SessionStore: framework.SessionStore,
Logger: slog.Default(),
Tools: []ai.Tool{},
@@ -218,14 +200,14 @@ func TestChatAgent_NewChatValidation(t *testing.T) {
// Uses mutex-protected error collection instead of assert/require in goroutines
// to avoid test reliability issues with t.FailNow() from goroutines.
func TestChatAgent_ConcurrentExecution(t *testing.T) {
- framework, cleanup := SetupTest(t)
- defer cleanup()
+ framework := SetupTest(t)
numConcurrentQueries := 5
var wg sync.WaitGroup
wg.Add(numConcurrentQueries)
- ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID)
+ ctx := context.Background()
+ sessionID := framework.SessionID
// Collect results safely
type result struct {
diff --git a/internal/agent/chat/retry.go b/internal/chat/retry.go
similarity index 63%
rename from internal/agent/chat/retry.go
rename to internal/chat/retry.go
index 9b19f33..3dbc70f 100644
--- a/internal/agent/chat/retry.go
+++ b/internal/chat/retry.go
@@ -25,29 +25,31 @@ func DefaultRetryConfig() RetryConfig {
}
}
-// retryableError determines if an error should trigger a retry.
+// retryablePatterns groups error substrings by category.
+// Matched case-insensitively against err.Error().
+//
+// NOTE: This uses string matching because Genkit and LLM provider SDKs
+// do not expose typed/sentinel errors for transient failures.
+// This is a documented exception to the project rule against
+// strings.Contains(err.Error(), ...).
+// Re-evaluate if Genkit adds structured error types in a future version.
+var retryablePatterns = [][]string{
+ {"rate limit", "quota exceeded", "429"}, // rate limiting
+ {"500", "502", "503", "504", "unavailable"}, // transient server errors
+ {"connection reset", "timeout", "temporary"}, // network errors
+}
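+
+// For example (editor's note), an error string such as "googleapi: Error 429: quota
+// exceeded" matches the rate-limiting group and is retried, while "HTTP 400 Bad
+// Request" matches no group and fails immediately.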
+
+// retryableError reports whether err is transient and should trigger a retry.
func retryableError(err error) bool {
if err == nil {
return false
}
-
errStr := err.Error()
-
- // Rate limit errors - always retry
- if containsAny(errStr, "rate limit", "quota exceeded", "429") {
- return true
- }
-
- // Transient server errors - retry
- if containsAny(errStr, "500", "502", "503", "504", "unavailable") {
- return true
- }
-
- // Network errors - retry
- if containsAny(errStr, "connection reset", "timeout", "temporary") {
- return true
+ for _, group := range retryablePatterns {
+ if containsAny(errStr, group...) {
+ return true
+ }
}
-
return false
}
@@ -69,27 +71,27 @@ func containsAny(s string, substrs ...string) bool {
// - Rate limits EACH attempt (per golang-master review)
// - Tracks elapsed time for observability (per Rob Pike review)
// - Exponential backoff with configurable max interval
-func (c *Chat) executeWithRetry(
+func (a *Agent) executeWithRetry(
ctx context.Context,
opts []ai.PromptExecuteOption,
) (*ai.ModelResponse, error) {
var lastErr error
- delay := c.retryConfig.InitialInterval
+ delay := a.retryConfig.InitialInterval
start := time.Now() // Track elapsed time
- for attempt := 0; attempt <= c.retryConfig.MaxRetries; attempt++ {
+ for attempt := 0; attempt <= a.retryConfig.MaxRetries; attempt++ {
// Rate limit EACH attempt (per golang-master review)
- if c.rateLimiter != nil {
- if err := c.rateLimiter.Wait(ctx); err != nil {
+ if a.rateLimiter != nil {
+ if err := a.rateLimiter.Wait(ctx); err != nil {
return nil, fmt.Errorf("rate limit wait: %w", err)
}
}
// Attempt execution
- resp, err := c.prompt.Execute(ctx, opts...)
+ resp, err := a.prompt.Execute(ctx, opts...)
if err == nil {
elapsed := time.Since(start)
- c.logger.Debug("prompt executed successfully",
+ a.logger.Debug("prompt executed successfully",
"attempts", attempt+1,
"elapsed", elapsed,
)
@@ -104,12 +106,12 @@ func (c *Chat) executeWithRetry(
}
// Last attempt - don't sleep
- if attempt == c.retryConfig.MaxRetries {
+ if attempt == a.retryConfig.MaxRetries {
break
}
// Exponential backoff with context cancellation check
- c.logger.Debug("retrying after error",
+ a.logger.Debug("retrying after error",
"attempt", attempt+1,
"delay", delay,
"elapsed", time.Since(start),
@@ -120,11 +122,11 @@ func (c *Chat) executeWithRetry(
case <-ctx.Done():
return nil, fmt.Errorf("context canceled during retry: %w", ctx.Err())
case <-time.After(delay):
- delay = min(delay*2, c.retryConfig.MaxInterval)
+ delay = min(delay*2, a.retryConfig.MaxInterval)
}
}
elapsed := time.Since(start)
return nil, fmt.Errorf("prompt execute after %d retries (elapsed: %v): %w",
- c.retryConfig.MaxRetries, elapsed, lastErr)
+ a.retryConfig.MaxRetries, elapsed, lastErr)
}
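+
+// Illustrative backoff timeline (editor's sketch; the interval values here are
+// assumptions, see DefaultRetryConfig for the actual defaults): with
+// InitialInterval=1s and MaxInterval=30s, successive retries wait 1s, 2s, 4s, 8s, ...
+// capped at 30s, because delay = min(delay*2, MaxInterval).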
diff --git a/internal/chat/retry_test.go b/internal/chat/retry_test.go
new file mode 100644
index 0000000..425c41d
--- /dev/null
+++ b/internal/chat/retry_test.go
@@ -0,0 +1,196 @@
+package chat
+
+import (
+ "errors"
+ "testing"
+)
+
+func TestDefaultRetryConfig(t *testing.T) {
+ t.Parallel()
+
+ cfg := DefaultRetryConfig()
+
+ if cfg.MaxRetries <= 0 {
+ t.Errorf("MaxRetries should be positive, got %d", cfg.MaxRetries)
+ }
+ if cfg.InitialInterval <= 0 {
+ t.Errorf("InitialInterval should be positive, got %v", cfg.InitialInterval)
+ }
+ if cfg.MaxInterval <= 0 {
+ t.Errorf("MaxInterval should be positive, got %v", cfg.MaxInterval)
+ }
+ if cfg.MaxInterval < cfg.InitialInterval {
+ t.Error("MaxInterval should be >= InitialInterval")
+ }
+}
+
+func TestRetryableError(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ err error
+ want bool
+ }{
+ {
+ name: "nil error",
+ err: nil,
+ want: false,
+ },
+ {
+ name: "rate limit error",
+ err: errors.New("rate limit exceeded"),
+ want: true,
+ },
+ {
+ name: "quota exceeded error",
+ err: errors.New("quota exceeded for project"),
+ want: true,
+ },
+ {
+ name: "429 status code",
+ err: errors.New("HTTP 429: Too Many Requests"),
+ want: true,
+ },
+ {
+ name: "500 server error",
+ err: errors.New("HTTP 500 Internal Server Error"),
+ want: true,
+ },
+ {
+ name: "502 bad gateway",
+ err: errors.New("502 Bad Gateway"),
+ want: true,
+ },
+ {
+ name: "503 unavailable",
+ err: errors.New("503 Service Unavailable"),
+ want: true,
+ },
+ {
+ name: "504 gateway timeout",
+ err: errors.New("504 Gateway Timeout"),
+ want: true,
+ },
+ {
+ name: "unavailable keyword",
+ err: errors.New("service unavailable"),
+ want: true,
+ },
+ {
+ name: "connection reset",
+ err: errors.New("connection reset by peer"),
+ want: true,
+ },
+ {
+ name: "timeout error",
+ err: errors.New("request timeout"),
+ want: true,
+ },
+ {
+ name: "temporary error",
+ err: errors.New("temporary failure"),
+ want: true,
+ },
+ {
+ name: "non-retryable error",
+ err: errors.New("invalid API key"),
+ want: false,
+ },
+ {
+ name: "non-retryable 400 error",
+ err: errors.New("HTTP 400 Bad Request"),
+ want: false,
+ },
+ {
+ name: "non-retryable 401 error",
+ err: errors.New("HTTP 401 Unauthorized"),
+ want: false,
+ },
+ {
+ name: "non-retryable 403 error",
+ err: errors.New("HTTP 403 Forbidden"),
+ want: false,
+ },
+ {
+ name: "case insensitive rate limit",
+ err: errors.New("RATE LIMIT reached"),
+ want: true,
+ },
+ {
+ name: "case insensitive timeout",
+ err: errors.New("TIMEOUT occurred"),
+ want: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ got := retryableError(tt.err)
+ if got != tt.want {
+ t.Errorf("retryableError(%v) = %v, want %v", tt.err, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestContainsAny(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ s string
+ substrs []string
+ want bool
+ }{
+ {
+ name: "empty string",
+ s: "",
+ substrs: []string{"foo"},
+ want: false,
+ },
+ {
+ name: "empty substrs",
+ s: "foo bar",
+ substrs: []string{},
+ want: false,
+ },
+ {
+ name: "contains first substr",
+ s: "foo bar baz",
+ substrs: []string{"foo", "qux"},
+ want: true,
+ },
+ {
+ name: "contains last substr",
+ s: "foo bar baz",
+ substrs: []string{"qux", "baz"},
+ want: true,
+ },
+ {
+ name: "case insensitive match",
+ s: "FOO BAR BAZ",
+ substrs: []string{"foo"},
+ want: true,
+ },
+ {
+ name: "no match",
+ s: "foo bar baz",
+ substrs: []string{"qux", "quux"},
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ got := containsAny(tt.s, tt.substrs...)
+ if got != tt.want {
+ t.Errorf("containsAny(%q, %v) = %v, want %v", tt.s, tt.substrs, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/internal/agent/chat/setup_test.go b/internal/chat/setup_test.go
similarity index 73%
rename from internal/agent/chat/setup_test.go
rename to internal/chat/setup_test.go
index eef3123..8cd1bed 100644
--- a/internal/agent/chat/setup_test.go
+++ b/internal/chat/setup_test.go
@@ -22,7 +22,7 @@ import (
"github.com/firebase/genkit/go/plugins/postgresql"
"github.com/google/uuid"
- "github.com/koopa0/koopa/internal/agent/chat"
+ "github.com/koopa0/koopa/internal/chat"
"github.com/koopa0/koopa/internal/config"
"github.com/koopa0/koopa/internal/rag"
"github.com/koopa0/koopa/internal/security"
@@ -34,9 +34,10 @@ import (
// TestFramework provides a complete test environment for chat integration tests.
// This is the chat-specific equivalent of testutil.AgentTestFramework.
+// Cleanup is automatic via tb.Cleanup — no manual cleanup needed.
type TestFramework struct {
// Core components
- Agent *chat.Chat
+ Agent *chat.Agent
Flow *chat.Flow
DocStore *postgresql.DocStore // For indexing documents in tests
Retriever ai.Retriever // Genkit Retriever for RAG
@@ -50,8 +51,6 @@ type TestFramework struct {
// Test session (fresh per framework instance)
SessionID uuid.UUID
-
- cleanup func()
}
// SetupTest creates a complete chat test environment.
@@ -66,13 +65,12 @@ type TestFramework struct {
// Example:
//
// func TestChatFeature(t *testing.T) {
-// framework, cleanup := SetupTest(t)
-// defer cleanup()
+// framework := SetupTest(t)
//
-//	ctx, sessionID := newInvocationContext(context.Background(), framework.SessionID)
-//	resp, err := framework.Agent.Execute(ctx, sessionID, "test query")
+//	resp, err := framework.Agent.Execute(context.Background(), framework.SessionID, "test query")
// }
-func SetupTest(t *testing.T) (*TestFramework, func()) {
+func SetupTest(t *testing.T) *TestFramework {
t.Helper()
apiKey := os.Getenv("GEMINI_API_KEY")
@@ -82,8 +80,8 @@ func SetupTest(t *testing.T) (*TestFramework, func()) {
ctx := context.Background()
- // Layer 1: Use testutil primitives
- dbContainer, dbCleanup := testutil.SetupTestDB(t)
+ // Layer 1: Use testutil primitives (cleanup is automatic via tb.Cleanup)
+ dbContainer := testutil.SetupTestDB(t)
// Setup RAG with Genkit PostgreSQL plugin
ragSetup := testutil.SetupRAG(t, dbContainer.Pool)
@@ -94,10 +92,9 @@ func SetupTest(t *testing.T) (*TestFramework, func()) {
cfg := &config.Config{
ModelName: "googleai/gemini-2.5-flash",
- EmbedderModel: "text-embedding-004",
+ EmbedderModel: "gemini-embedding-001",
Temperature: 0.7,
MaxTokens: 8192,
- RAGTopK: 5,
PostgresHost: "localhost",
PostgresPort: 5432,
PostgresUser: "koopa_test",
@@ -109,58 +106,47 @@ func SetupTest(t *testing.T) (*TestFramework, func()) {
}
// Create test session
- testSession, err := sessionStore.CreateSession(ctx, "Chat Integration Test", cfg.ModelName, "")
+ testSession, err := sessionStore.CreateSession(ctx, "Chat Integration Test")
if err != nil {
- dbCleanup()
- t.Fatalf("Failed to create test session: %v", err)
+ t.Fatalf("creating test session: %v", err)
}
// Create toolsets
pathValidator, err := security.NewPath([]string{os.TempDir()})
if err != nil {
- dbCleanup()
- t.Fatalf("Failed to create path validator: %v", err)
+ t.Fatalf("creating path validator: %v", err)
}
// Create file tools instance
- fileToolset, err := tools.NewFileTools(pathValidator, slog.Default())
+ fileToolset, err := tools.NewFile(pathValidator, slog.Default())
if err != nil {
- dbCleanup()
- t.Fatalf("Failed to create file tools: %v", err)
+ t.Fatalf("creating file tools: %v", err)
}
// Register file tools with Genkit
- fileTools, err := tools.RegisterFileTools(ragSetup.Genkit, fileToolset)
+ fileTools, err := tools.RegisterFile(ragSetup.Genkit, fileToolset)
if err != nil {
- dbCleanup()
- t.Fatalf("Failed to register file tools: %v", err)
+ t.Fatalf("registering file tools: %v", err)
}
// Create Chat Agent
chatAgent, err := chat.New(chat.Config{
Genkit: ragSetup.Genkit,
- Retriever: ragSetup.Retriever,
SessionStore: sessionStore,
Logger: slog.Default(),
Tools: fileTools,
MaxTurns: cfg.MaxTurns,
- RAGTopK: cfg.RAGTopK,
Language: cfg.Language,
})
if err != nil {
- dbCleanup()
- t.Fatalf("Failed to create chat agent: %v", err)
+ t.Fatalf("creating chat agent: %v", err)
}
// Initialize Flow singleton (reset first for test isolation)
chat.ResetFlowForTesting()
- flow, err := chat.InitFlow(ragSetup.Genkit, chatAgent)
- if err != nil {
- dbCleanup()
- t.Fatalf("Failed to init chat flow: %v", err)
- }
+ flow := chat.NewFlow(ragSetup.Genkit, chatAgent)
- framework := &TestFramework{
+ return &TestFramework{
Agent: chatAgent,
Flow: flow,
DocStore: ragSetup.DocStore,
@@ -171,30 +157,20 @@ func SetupTest(t *testing.T) (*TestFramework, func()) {
Genkit: ragSetup.Genkit,
Embedder: ragSetup.Embedder,
SessionID: testSession.ID,
- cleanup: dbCleanup,
}
-
- return framework, dbCleanup
}
// CreateTestSession creates a new isolated session for test isolation.
func (f *TestFramework) CreateTestSession(t *testing.T, name string) uuid.UUID {
t.Helper()
ctx := context.Background()
- sess, err := f.SessionStore.CreateSession(ctx, name, f.Config.ModelName, "")
+ sess, err := f.SessionStore.CreateSession(ctx, name)
if err != nil {
- t.Fatalf("Failed to create test session: %v", err)
+ t.Fatalf("creating test session: %v", err)
}
return sess.ID
}
-// newInvocationContext creates a simple invocation context for integration tests.
-// This helper exists to maintain test readability after removing agent.InvocationContext.
-// NOTE: Branch parameter was removed in Proposal 051.
-func newInvocationContext(ctx context.Context, sessionID uuid.UUID) (context.Context, uuid.UUID) {
- return ctx, sessionID
-}
-
// IndexDocument indexes a document using the Genkit DocStore.
// This is a test helper for adding documents to the RAG knowledge base.
func (f *TestFramework) IndexDocument(t *testing.T, content string, metadata map[string]any) {
@@ -211,6 +187,6 @@ func (f *TestFramework) IndexDocument(t *testing.T, content string, metadata map
doc := ai.DocumentFromText(content, metadata)
if err := f.DocStore.Index(ctx, []*ai.Document{doc}); err != nil {
- t.Fatalf("Failed to index document: %v", err)
+ t.Fatalf("indexing document: %v", err)
}
}
diff --git a/internal/agent/chat/tokens.go b/internal/chat/tokens.go
similarity index 87%
rename from internal/agent/chat/tokens.go
rename to internal/chat/tokens.go
index 24c7dde..1572177 100644
--- a/internal/agent/chat/tokens.go
+++ b/internal/chat/tokens.go
@@ -12,10 +12,11 @@ type TokenBudget struct {
MaxHistoryTokens int // Maximum tokens for conversation history
}
-// DefaultTokenBudget returns conservative defaults for Gemini models.
+// DefaultTokenBudget returns defaults for modern large-context models.
+// A 32K-token history budget (roughly 64 conversation turns) balances context retention with cost.
func DefaultTokenBudget() TokenBudget {
return TokenBudget{
- MaxHistoryTokens: 8000, // ~8K tokens for history
+ MaxHistoryTokens: 32000,
}
}
@@ -49,7 +50,7 @@ func estimateMessagesTokens(msgs []*ai.Message) int {
// truncateHistory removes oldest messages to fit within budget.
// Preserves system message (if present) and keeps most recent messages.
-func (c *Chat) truncateHistory(msgs []*ai.Message, budget int) []*ai.Message {
+func (a *Agent) truncateHistory(msgs []*ai.Message, budget int) []*ai.Message {
if len(msgs) == 0 {
return msgs
}
@@ -59,7 +60,7 @@ func (c *Chat) truncateHistory(msgs []*ai.Message, budget int) []*ai.Message {
return msgs
}
- c.logger.Debug("truncating history",
+ a.logger.Debug("truncating history",
"current_tokens", currentTokens,
"budget", budget,
"message_count", len(msgs),
@@ -92,7 +93,7 @@ func (c *Chat) truncateHistory(msgs []*ai.Message, budget int) []*ai.Message {
// Append kept messages after system message
result = append(result, kept...)
- c.logger.Debug("history truncated",
+ a.logger.Debug("history truncated",
"original_count", len(msgs),
"new_count", len(result),
"tokens_used", budget-remaining,
diff --git a/internal/agent/chat/tokens_test.go b/internal/chat/tokens_test.go
similarity index 82%
rename from internal/agent/chat/tokens_test.go
rename to internal/chat/tokens_test.go
index fbcaa35..650b1d9 100644
--- a/internal/agent/chat/tokens_test.go
+++ b/internal/chat/tokens_test.go
@@ -1,11 +1,10 @@
package chat
import (
+ "log/slog"
"testing"
"github.com/firebase/genkit/go/ai"
-
- "github.com/koopa0/koopa/internal/log"
)
func TestDefaultTokenBudget(t *testing.T) {
@@ -22,39 +21,39 @@ func TestEstimateTokens(t *testing.T) {
t.Parallel()
tests := []struct {
- name string
- text string
- expected int
+ name string
+ text string
+ want int
}{
{
- name: "empty string",
- text: "",
- expected: 0,
+ name: "empty string",
+ text: "",
+ want: 0,
},
{
- name: "single char returns 1",
- text: "a",
- expected: 1, // 1 rune / 2 = 0, but min 1 for non-empty
+ name: "single char returns 1",
+ text: "a",
+ want: 1, // 1 rune / 2 = 0, but min 1 for non-empty
},
{
- name: "short english",
- text: "hello",
- expected: 2, // 5 runes / 2 = 2
+ name: "short english",
+ text: "hello",
+ want: 2, // 5 runes / 2 = 2
},
{
- name: "longer english",
- text: "This is a longer test message with multiple words.",
- expected: 25, // 50 runes / 2 = 25
+ name: "longer english",
+ text: "This is a longer test message with multiple words.",
+ want: 25, // 50 runes / 2 = 25
},
{
- name: "cjk text",
- text: "你好世界",
- expected: 2, // 4 runes / 2 = 2
+ name: "cjk text",
+ text: "你好世界",
+ want: 2, // 4 runes / 2 = 2
},
{
- name: "mixed text",
- text: "Hello 世界",
- expected: 4, // 8 runes / 2 = 4
+ name: "mixed text",
+ text: "Hello 世界",
+ want: 4, // 8 runes / 2 = 4
},
}
@@ -63,8 +62,8 @@ func TestEstimateTokens(t *testing.T) {
t.Parallel()
got := estimateTokens(tt.text)
- if got != tt.expected {
- t.Errorf("estimateTokens(%q) = %d, want %d", tt.text, got, tt.expected)
+ if got != tt.want {
+ t.Errorf("estimateTokens(%q) = %d, want %d", tt.text, got, tt.want)
}
})
}
@@ -74,26 +73,26 @@ func TestEstimateMessagesTokens(t *testing.T) {
t.Parallel()
tests := []struct {
- name string
- msgs []*ai.Message
- expected int
+ name string
+ msgs []*ai.Message
+ want int
}{
{
- name: "nil messages",
- msgs: nil,
- expected: 0,
+ name: "nil messages",
+ msgs: nil,
+ want: 0,
},
{
- name: "empty messages",
- msgs: []*ai.Message{},
- expected: 0,
+ name: "empty messages",
+ msgs: []*ai.Message{},
+ want: 0,
},
{
name: "single message",
msgs: []*ai.Message{
ai.NewUserMessage(ai.NewTextPart("hello world")), // 11 runes / 2 = 5
},
- expected: 5,
+ want: 5,
},
{
name: "multiple messages",
@@ -102,7 +101,7 @@ func TestEstimateMessagesTokens(t *testing.T) {
ai.NewModelMessage(ai.NewTextPart("world")), // 5 / 2 = 2
ai.NewUserMessage(ai.NewTextPart("how are you")), // 11 / 2 = 5
},
- expected: 9,
+ want: 9,
},
}
@@ -111,8 +110,8 @@ func TestEstimateMessagesTokens(t *testing.T) {
t.Parallel()
got := estimateMessagesTokens(tt.msgs)
- if got != tt.expected {
- t.Errorf("estimateMessagesTokens() = %d, want %d", got, tt.expected)
+ if got != tt.want {
+ t.Errorf("estimateMessagesTokens() = %d, want %d", got, tt.want)
}
})
}
@@ -122,9 +121,9 @@ func TestTruncateHistory(t *testing.T) {
t.Parallel()
// Helper to create a Chat with nop logger for testing
- makeChat := func() *Chat {
- return &Chat{
- logger: log.NewNop(),
+ makeAgent := func() *Agent {
+ return &Agent{
+ logger: slog.New(slog.DiscardHandler),
}
}
@@ -234,8 +233,8 @@ func TestTruncateHistory(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
- chat := makeChat()
- got := chat.truncateHistory(tt.msgs, tt.budget)
+ agent := makeAgent()
+ got := agent.truncateHistory(tt.msgs, tt.budget)
if len(got) != tt.wantLen {
t.Errorf("truncateHistory() len = %d, want %d", len(got), tt.wantLen)
@@ -248,7 +247,7 @@ func TestTruncateHistory(t *testing.T) {
// Check system message preservation
if tt.wantHasSystem {
if got[0].Role != ai.RoleSystem {
- t.Errorf("expected first message to be system, got %s", got[0].Role)
+ t.Errorf("want first message to be system, got %s", got[0].Role)
}
}
@@ -267,7 +266,7 @@ func TestTruncateHistory(t *testing.T) {
// wantTexts[i] must match got[i] - this implicitly verifies ordering
if len(tt.wantTexts) > 0 {
if len(got) != len(tt.wantTexts) {
- t.Fatalf("got %d messages but expected %d texts to verify", len(got), len(tt.wantTexts))
+ t.Fatalf("got %d messages but want %d texts to verify", len(got), len(tt.wantTexts))
}
for i, want := range tt.wantTexts {
if len(got[i].Content) == 0 {
@@ -286,7 +285,7 @@ func TestTruncateHistory(t *testing.T) {
func TestTruncateHistory_ChronologicalOrder(t *testing.T) {
t.Parallel()
- chat := &Chat{logger: log.NewNop()}
+ agent := &Agent{logger: slog.New(slog.DiscardHandler)}
// Create a conversation with alternating user/model messages
msgs := []*ai.Message{
@@ -298,7 +297,7 @@ func TestTruncateHistory_ChronologicalOrder(t *testing.T) {
}
// Budget should keep only last 2-3 messages
- result := chat.truncateHistory(msgs, 6)
+ result := agent.truncateHistory(msgs, 6)
// Verify messages are still in chronological order
for i := 1; i < len(result); i++ {
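
The expected values in the tests above pin down a simple heuristic: roughly one token per two runes, with a floor of one for non-empty text, summed across message parts. The real estimator lives in internal/chat/tokens.go and is not shown in this diff; the following is a sketch consistent with these tests only.

package chat

import (
	"unicode/utf8"

	"github.com/firebase/genkit/go/ai"
)

// estimateTokensSketch: ~1 token per 2 runes, minimum 1 for non-empty text.
func estimateTokensSketch(text string) int {
	if text == "" {
		return 0
	}
	if n := utf8.RuneCountInString(text) / 2; n >= 1 {
		return n
	}
	return 1
}

// estimateMessagesTokensSketch sums the per-part estimates over all messages.
func estimateMessagesTokensSketch(msgs []*ai.Message) int {
	total := 0
	for _, m := range msgs {
		for _, p := range m.Content {
			total += estimateTokensSketch(p.Text)
		}
	}
	return total
}
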
diff --git a/internal/config/ai.go b/internal/config/ai.go
deleted file mode 100644
index 7823a8d..0000000
--- a/internal/config/ai.go
+++ /dev/null
@@ -1,14 +0,0 @@
-package config
-
-// AIConfig holds AI model configuration.
-// Fields are embedded in the main Config struct for backward compatibility.
-// Documented separately for clarity.
-//
-// Configuration options:
-// - Provider: AI provider ("gemini", "ollama", "openai")
-// - ModelName: Model identifier (e.g., "gemini-2.5-flash", "llama3.3", "gpt-4o")
-// - Temperature: 0.0 (deterministic) to 2.0 (creative)
-// - MaxTokens: 1 to 2,097,152 (Gemini 2.5 max context)
-// - Language: Response language ("auto", "English", "zh-TW")
-// - PromptDir: Directory for .prompt files (Dotprompt)
-// - OllamaHost: Ollama server address (default: "http://localhost:11434")
diff --git a/internal/config/ai_test.go b/internal/config/ai_test.go
deleted file mode 100644
index 5921c5c..0000000
--- a/internal/config/ai_test.go
+++ /dev/null
@@ -1,33 +0,0 @@
-package config
-
-import "testing"
-
-// TestFullModelName tests that FullModelName derives correct provider-qualified names.
-func TestFullModelName(t *testing.T) {
- tests := []struct {
- name string
- provider string
- modelName string
- want string
- }{
- {name: "gemini default", provider: "", modelName: "gemini-2.5-flash", want: "googleai/gemini-2.5-flash"},
- {name: "gemini explicit", provider: "gemini", modelName: "gemini-2.5-pro", want: "googleai/gemini-2.5-pro"},
- {name: "ollama", provider: "ollama", modelName: "llama3.3", want: "ollama/llama3.3"},
- {name: "openai", provider: "openai", modelName: "gpt-4o", want: "openai/gpt-4o"},
- {name: "already qualified", provider: "ollama", modelName: "ollama/llama3.3", want: "ollama/llama3.3"},
- {name: "cross-qualified", provider: "gemini", modelName: "openai/gpt-4o", want: "openai/gpt-4o"},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- cfg := &Config{
- Provider: tt.provider,
- ModelName: tt.modelName,
- }
- got := cfg.FullModelName()
- if got != tt.want {
- t.Errorf("FullModelName() = %q, want %q", got, tt.want)
- }
- })
- }
-}
diff --git a/internal/config/config.go b/internal/config/config.go
index 109b973..2f10ade 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -6,9 +6,9 @@
// 3. Default values (sensible defaults for quick start)
//
// Main configuration categories:
-// - AI: Model selection, temperature, max tokens, embedder (see ai.go)
-// - Storage: SQLite database path, PostgreSQL connection (see storage.go)
-// - RAG: Number of documents to retrieve (RAGTopK)
+// - AI: Model selection, temperature, max tokens, embedder
+// - Storage: PostgreSQL connection (see storage.go)
+// - RAG: Embedder model for vector embeddings
// - MCP: Model Context Protocol server management (see tools.go)
// - Observability: Datadog APM tracing (see observability.go)
//
@@ -32,10 +32,6 @@ import (
"github.com/spf13/viper"
)
-// ============================================================================
-// Sentinel Errors
-// ============================================================================
-
var (
// ErrConfigNil indicates the configuration is nil.
ErrConfigNil = errors.New("configuration is nil")
@@ -52,12 +48,12 @@ var (
// ErrInvalidMaxTokens indicates the max tokens value is out of range.
ErrInvalidMaxTokens = errors.New("invalid max tokens")
- // ErrInvalidRAGTopK indicates the RAG top-k value is out of range.
- ErrInvalidRAGTopK = errors.New("invalid RAG top-k")
-
// ErrInvalidEmbedderModel indicates the embedder model is invalid.
ErrInvalidEmbedderModel = errors.New("invalid embedder model")
+ // ErrInvalidEmbedderDimension indicates the embedder produces incompatible vector dimensions.
+ ErrInvalidEmbedderDimension = errors.New("incompatible embedder dimension")
+
// ErrInvalidPostgresHost indicates the PostgreSQL host is invalid.
ErrInvalidPostgresHost = errors.New("invalid PostgreSQL host")
@@ -67,9 +63,6 @@ var (
// ErrInvalidPostgresDBName indicates the PostgreSQL database name is invalid.
ErrInvalidPostgresDBName = errors.New("invalid PostgreSQL database name")
- // ErrConfigParse indicates configuration parsing failed.
- ErrConfigParse = errors.New("failed to parse configuration")
-
// ErrInvalidProvider indicates the AI provider is not supported.
ErrInvalidProvider = errors.New("invalid provider")
@@ -81,15 +74,20 @@ var (
// ErrInvalidPostgresSSLMode indicates the PostgreSQL SSL mode is invalid.
ErrInvalidPostgresSSLMode = errors.New("invalid PostgreSQL SSL mode")
-)
-// ============================================================================
-// Constants
-// ============================================================================
+ // ErrMissingHMACSecret indicates the HMAC secret is not set.
+ ErrMissingHMACSecret = errors.New("missing HMAC secret")
+
+ // ErrInvalidHMACSecret indicates the HMAC secret is too short.
+ ErrInvalidHMACSecret = errors.New("invalid HMAC secret")
+)
const (
- // DefaultEmbedderModel is the default embedder model for vector embeddings.
- DefaultEmbedderModel = "text-embedding-004"
+ // DefaultGeminiEmbedderModel is the default Gemini embedder model.
+ // gemini-embedding-001 outputs 3072 dimensions by default, but supports
+ // truncation to 768 via OutputDimensionality (Matryoshka Representation Learning).
+ // Our pgvector schema uses 768 dimensions; see rag.VectorDimension.
+ DefaultGeminiEmbedderModel = "gemini-embedding-001"
// DefaultMaxHistoryMessages is the default number of messages to load.
DefaultMaxHistoryMessages int32 = 100
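
As the comment above notes, gemini-embedding-001 natively emits 3072-dimensional vectors while the pgvector schema stores 768; the smaller size is requested via OutputDimensionality at embed time (handled in the rag package, per the comment). Conceptually, Matryoshka-style truncation keeps a leading prefix of the vector and is typically followed by re-normalization. The helper below is a hypothetical illustration of that idea, not project code.

package example

import "math"

// truncateEmbedding keeps the first dim components and re-normalizes to unit
// length. Hypothetical illustration only; the real truncation is performed by
// the embedding API when OutputDimensionality is set.
func truncateEmbedding(vec []float32, dim int) []float32 {
	if dim >= len(vec) {
		return vec
	}
	out := make([]float32, dim)
	copy(out, vec[:dim])
	var sum float64
	for _, v := range out {
		sum += float64(v) * float64(v)
	}
	norm := math.Sqrt(sum)
	if norm == 0 {
		return out
	}
	for i := range out {
		out[i] = float32(float64(out[i]) / norm)
	}
	return out
}
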
@@ -101,9 +99,13 @@ const (
MinHistoryMessages int32 = 10
)
-// ============================================================================
-// Config Struct
-// ============================================================================
+// AI provider identifiers used in Config.Provider.
+const (
+ ProviderGemini = "gemini"
+ ProviderOllama = "ollama"
+ ProviderOpenAI = "openai"
+ ProviderGoogleAI = "googleai"
+)
// Config stores application configuration.
// SECURITY: Sensitive fields are explicitly masked in MarshalJSON().
@@ -133,13 +135,8 @@ type Config struct {
PostgresSSLMode string `mapstructure:"postgres_ssl_mode" json:"postgres_ssl_mode"`
// RAG configuration
- RAGTopK int `mapstructure:"rag_top_k" json:"rag_top_k"`
EmbedderModel string `mapstructure:"embedder_model" json:"embedder_model"`
- // MCP configuration (see tools.go for type definitions)
- MCP MCPConfig `mapstructure:"mcp" json:"mcp"`
- MCPServers map[string]MCPServer `mapstructure:"mcp_servers" json:"mcp_servers"`
-
// Tool configuration (see tools.go for type definitions)
SearXNG SearXNGConfig `mapstructure:"searxng" json:"searxng"`
WebScraper WebScraperConfig `mapstructure:"web_scraper" json:"web_scraper"`
@@ -153,24 +150,20 @@ type Config struct {
TrustProxy bool `mapstructure:"trust_proxy" json:"trust_proxy"` // Trust X-Real-IP/X-Forwarded-For headers (set true behind reverse proxy)
}
-// ============================================================================
-// Load Function
-// ============================================================================
-
// Load loads configuration.
// Priority: Environment variables > Configuration file > Default values
func Load() (*Config, error) {
// Configuration directory: ~/.koopa/
home, err := os.UserHomeDir()
if err != nil {
- return nil, fmt.Errorf("failed to get user home directory: %w", err)
+ return nil, fmt.Errorf("getting user home directory: %w", err)
}
configDir := filepath.Join(home, ".koopa")
// Ensure directory exists (use 0750 permission for better security)
if err := os.MkdirAll(configDir, 0o750); err != nil {
- return nil, fmt.Errorf("failed to create config directory: %w", err)
+ return nil, fmt.Errorf("creating config directory: %w", err)
}
// Configure Viper
@@ -180,7 +173,7 @@ func Load() (*Config, error) {
viper.AddConfigPath(".") // Also support current directory
// Set default values
- setDefaults(configDir)
+ setDefaults()
// Bind environment variables
bindEnvVariables()
@@ -190,7 +183,7 @@ func Load() (*Config, error) {
// Configuration file not found is not an error, use default values
var configNotFound viper.ConfigFileNotFoundError
if !errors.As(err, &configNotFound) {
- return nil, fmt.Errorf("failed to read config file: %w", err)
+ return nil, fmt.Errorf("reading config file: %w", err)
}
slog.Debug("configuration file not found, using default values",
"search_paths", []string{configDir, "."},
@@ -200,26 +193,26 @@ func Load() (*Config, error) {
// Use Unmarshal to automatically map to struct (type-safe)
var cfg Config
if err := viper.Unmarshal(&cfg); err != nil {
- return nil, fmt.Errorf("failed to parse configuration: %w", err)
+ return nil, fmt.Errorf("parsing configuration: %w", err)
}
// Parse DATABASE_URL if set (highest priority for PostgreSQL config)
if err := cfg.parseDatabaseURL(); err != nil {
- return nil, fmt.Errorf("failed to parse DATABASE_URL: %w", err)
+ return nil, fmt.Errorf("parsing DATABASE_URL: %w", err)
}
// CRITICAL: Validate immediately (fail-fast)
if err := cfg.Validate(); err != nil {
- return nil, fmt.Errorf("configuration validation failed: %w", err)
+ return nil, fmt.Errorf("validating configuration: %w", err)
}
return &cfg, nil
}
// setDefaults sets all default configuration values.
-func setDefaults(configDir string) {
+func setDefaults() {
// AI defaults
- viper.SetDefault("provider", "gemini")
+ viper.SetDefault("provider", ProviderGemini)
viper.SetDefault("model_name", "gemini-2.5-flash")
viper.SetDefault("temperature", 0.7)
viper.SetDefault("max_tokens", 2048)
@@ -229,8 +222,6 @@ func setDefaults(configDir string) {
// Ollama defaults
viper.SetDefault("ollama_host", "http://localhost:11434")
- viper.SetDefault("database_path", filepath.Join(configDir, "koopa.db"))
-
// PostgreSQL defaults (matching docker-compose.yml)
viper.SetDefault("postgres_host", "localhost")
viper.SetDefault("postgres_port", 5432)
@@ -240,8 +231,7 @@ func setDefaults(configDir string) {
viper.SetDefault("postgres_ssl_mode", "disable")
// RAG defaults
- viper.SetDefault("rag_top_k", 3)
- viper.SetDefault("embedder_model", DefaultEmbedderModel)
+ viper.SetDefault("embedder_model", DefaultGeminiEmbedderModel)
// MCP defaults
viper.SetDefault("mcp.timeout", 5)
@@ -302,10 +292,6 @@ func bindEnvVariables() {
// Validation checks their presence based on the selected provider in cfg.Validate()
}
-// ============================================================================
-// Sensitive Data Masking
-// ============================================================================
-
// maskedValue is the placeholder for masked sensitive data.
// Using ████████ (U+2588 FULL BLOCK characters) to avoid substring matching
// Previous attempts:
@@ -346,7 +332,6 @@ func maskSecret(s string) string {
// - PostgresPassword
// - HMACSecret
// - Datadog.APIKey (via DatadogConfig.MarshalJSON)
-// - MCPServers[*].Env (via MCPServer.MarshalJSON)
//
// When adding new sensitive fields, update this method or the nested struct's MarshalJSON.
// The compiler will remind you when tests fail.
@@ -355,7 +340,7 @@ func (c Config) MarshalJSON() ([]byte, error) {
a := alias(c)
a.PostgresPassword = maskSecret(a.PostgresPassword)
a.HMACSecret = maskSecret(a.HMACSecret)
- // Note: Datadog.APIKey and MCPServers[*].Env are handled by their own MarshalJSON
+ // Note: Datadog.APIKey is handled by its own MarshalJSON
data, err := json.Marshal(a)
if err != nil {
return nil, fmt.Errorf("marshal config: %w", err)
@@ -371,12 +356,12 @@ func (c *Config) FullModelName() string {
return c.ModelName
}
switch c.Provider {
- case "ollama":
- return "ollama/" + c.ModelName
- case "openai":
- return "openai/" + c.ModelName
+ case ProviderOllama:
+ return ProviderOllama + "/" + c.ModelName
+ case ProviderOpenAI:
+ return ProviderOpenAI + "/" + c.ModelName
default:
- return "googleai/" + c.ModelName
+ return ProviderGoogleAI + "/" + c.ModelName
}
}
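
For reference, the provider constants introduced above combine with FullModelName as follows; these cases mirror TestFullModelName, which this diff moves into config_test.go below. Assumes import "github.com/koopa0/koopa/internal/config".

cfg := &config.Config{Provider: config.ProviderOllama, ModelName: "llama3.3"}
_ = cfg.FullModelName() // "ollama/llama3.3"

cfg = &config.Config{Provider: "", ModelName: "gemini-2.5-flash"}
_ = cfg.FullModelName() // empty provider defaults to googleai: "googleai/gemini-2.5-flash"

cfg = &config.Config{Provider: config.ProviderGemini, ModelName: "openai/gpt-4o"}
_ = cfg.FullModelName() // already-qualified names pass through unchanged: "openai/gpt-4o"
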
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
index 4ce02b0..8be18fd 100644
--- a/internal/config/config_test.go
+++ b/internal/config/config_test.go
@@ -21,13 +21,13 @@ func TestLoadDefaults(t *testing.T) {
originalHome := os.Getenv("HOME")
defer func() {
if err := os.Setenv("HOME", originalHome); err != nil {
- t.Errorf("Failed to restore HOME: %v", err)
+ t.Errorf("restoring HOME: %v", err)
}
}()
// Set HOME to temp directory (no existing config.yaml)
if err := os.Setenv("HOME", tmpDir); err != nil {
- t.Fatalf("Failed to set HOME: %v", err)
+ t.Fatalf("setting HOME: %v", err)
}
// Save and restore original environment
@@ -35,11 +35,11 @@ func TestLoadDefaults(t *testing.T) {
defer func() {
if originalAPIKey != "" {
if err := os.Setenv("GEMINI_API_KEY", originalAPIKey); err != nil {
- t.Errorf("Failed to restore GEMINI_API_KEY: %v", err)
+ t.Errorf("restoring GEMINI_API_KEY: %v", err)
}
} else {
if err := os.Unsetenv("GEMINI_API_KEY"); err != nil {
- t.Errorf("Failed to unset GEMINI_API_KEY: %v", err)
+ t.Errorf("unsetting GEMINI_API_KEY: %v", err)
}
}
}()
@@ -55,58 +55,51 @@ func TestLoadDefaults(t *testing.T) {
// Set API key for validation
if err := os.Setenv("GEMINI_API_KEY", "test-api-key"); err != nil {
- t.Fatalf("Failed to set GEMINI_API_KEY: %v", err)
+ t.Fatalf("setting GEMINI_API_KEY: %v", err)
}
cfg, err := Load()
if err != nil {
- t.Fatalf("Load() failed: %v", err)
+ t.Fatalf("Load() unexpected error: %v", err)
}
// Verify default values
if cfg.ModelName != "gemini-2.5-flash" {
- t.Errorf("expected default ModelName 'gemini-2.5-flash', got %q", cfg.ModelName)
+ t.Errorf("Load().ModelName = %q, want %q", cfg.ModelName, "gemini-2.5-flash")
}
if cfg.Temperature != 0.7 {
- t.Errorf("expected default Temperature 0.7, got %f", cfg.Temperature)
+ t.Errorf("Load().Temperature = %f, want %f", cfg.Temperature, 0.7)
}
if cfg.MaxTokens != 2048 {
- t.Errorf("expected default MaxTokens 2048, got %d", cfg.MaxTokens)
+ t.Errorf("Load().MaxTokens = %d, want %d", cfg.MaxTokens, 2048)
}
if cfg.MaxHistoryMessages != DefaultMaxHistoryMessages {
- t.Errorf("expected default MaxHistoryMessages %d, got %d", DefaultMaxHistoryMessages, cfg.MaxHistoryMessages)
+ t.Errorf("Load().MaxHistoryMessages = %d, want %d", cfg.MaxHistoryMessages, DefaultMaxHistoryMessages)
}
if cfg.PostgresHost != "localhost" {
- t.Errorf("expected default PostgresHost 'localhost', got %q", cfg.PostgresHost)
+ t.Errorf("Load().PostgresHost = %q, want %q", cfg.PostgresHost, "localhost")
}
if cfg.PostgresPort != 5432 {
- t.Errorf("expected default PostgresPort 5432, got %d", cfg.PostgresPort)
+ t.Errorf("Load().PostgresPort = %d, want %d", cfg.PostgresPort, 5432)
}
if cfg.PostgresUser != "koopa" {
- t.Errorf("expected default PostgresUser 'koopa', got %q", cfg.PostgresUser)
+ t.Errorf("Load().PostgresUser = %q, want %q", cfg.PostgresUser, "koopa")
}
if cfg.PostgresDBName != "koopa" {
- t.Errorf("expected default PostgresDBName 'koopa', got %q", cfg.PostgresDBName)
+ t.Errorf("Load().PostgresDBName = %q, want %q", cfg.PostgresDBName, "koopa")
}
- if cfg.RAGTopK != 3 {
- t.Errorf("expected default RAGTopK 3, got %d", cfg.RAGTopK)
+ if cfg.EmbedderModel != "gemini-embedding-001" {
+ t.Errorf("Load().EmbedderModel = %q, want %q", cfg.EmbedderModel, "gemini-embedding-001")
}
- if cfg.EmbedderModel != "text-embedding-004" {
- t.Errorf("expected default EmbedderModel 'text-embedding-004', got %q", cfg.EmbedderModel)
- }
-
- if cfg.MCP.Timeout != 5 {
- t.Errorf("expected default MCP timeout 5, got %d", cfg.MCP.Timeout)
- }
}
// TestLoadConfigFile tests loading configuration from a file
@@ -119,20 +112,20 @@ func TestLoadConfigFile(t *testing.T) {
originalHome := os.Getenv("HOME")
defer func() {
if err := os.Setenv("HOME", originalHome); err != nil {
- t.Errorf("Failed to restore HOME: %v", err)
+ t.Errorf("restoring HOME: %v", err)
}
}()
// Set HOME to temp directory
if err := os.Setenv("HOME", tmpDir); err != nil {
- t.Fatalf("Failed to set HOME: %v", err)
+ t.Fatalf("setting HOME: %v", err)
}
if err := os.Setenv("GEMINI_API_KEY", "test-api-key"); err != nil {
- t.Fatalf("Failed to set GEMINI_API_KEY: %v", err)
+ t.Fatalf("setting GEMINI_API_KEY: %v", err)
}
defer func() {
if err := os.Unsetenv("GEMINI_API_KEY"); err != nil {
- t.Errorf("Failed to unset GEMINI_API_KEY: %v", err)
+ t.Errorf("unsetting GEMINI_API_KEY: %v", err)
}
}()
@@ -148,55 +141,50 @@ func TestLoadConfigFile(t *testing.T) {
// Create .koopa directory
koopaDir := filepath.Join(tmpDir, ".koopa")
if err := os.MkdirAll(koopaDir, 0o750); err != nil {
- t.Fatalf("failed to create koopa dir: %v", err)
+ t.Fatalf("creating koopa dir: %v", err)
}
// Create config file
configContent := `model_name: gemini-2.5-pro
temperature: 0.9
max_tokens: 4096
-rag_top_k: 5
postgres_host: test-host
postgres_port: 5433
postgres_db_name: test_db
`
configPath := filepath.Join(koopaDir, "config.yaml")
if err := os.WriteFile(configPath, []byte(configContent), 0o600); err != nil {
- t.Fatalf("failed to write config file: %v", err)
+ t.Fatalf("writing config file: %v", err)
}
cfg, err := Load()
if err != nil {
- t.Fatalf("Load() failed: %v", err)
+ t.Fatalf("Load() unexpected error: %v", err)
}
// Verify values from config file
if cfg.ModelName != "gemini-2.5-pro" {
- t.Errorf("expected ModelName 'gemini-2.5-pro', got %q", cfg.ModelName)
+ t.Errorf("Load().ModelName = %q, want %q", cfg.ModelName, "gemini-2.5-pro")
}
if cfg.Temperature != 0.9 {
- t.Errorf("expected Temperature 0.9, got %f", cfg.Temperature)
+ t.Errorf("Load().Temperature = %f, want %f", cfg.Temperature, 0.9)
}
if cfg.MaxTokens != 4096 {
- t.Errorf("expected MaxTokens 4096, got %d", cfg.MaxTokens)
- }
-
- if cfg.RAGTopK != 5 {
- t.Errorf("expected RAGTopK 5, got %d", cfg.RAGTopK)
+ t.Errorf("Load().MaxTokens = %d, want %d", cfg.MaxTokens, 4096)
}
if cfg.PostgresHost != "test-host" {
- t.Errorf("expected PostgresHost 'test-host', got %q", cfg.PostgresHost)
+ t.Errorf("Load().PostgresHost = %q, want %q", cfg.PostgresHost, "test-host")
}
if cfg.PostgresPort != 5433 {
- t.Errorf("expected PostgresPort 5433, got %d", cfg.PostgresPort)
+ t.Errorf("Load().PostgresPort = %d, want %d", cfg.PostgresPort, 5433)
}
if cfg.PostgresDBName != "test_db" {
- t.Errorf("expected PostgresDBName 'test_db', got %q", cfg.PostgresDBName)
+ t.Errorf("Load().PostgresDBName = %q, want %q", cfg.PostgresDBName, "test_db")
}
}
@@ -228,25 +216,25 @@ func TestConfigDirectoryCreation(t *testing.T) {
originalHome := os.Getenv("HOME")
defer func() {
if err := os.Setenv("HOME", originalHome); err != nil {
- t.Errorf("Failed to restore HOME: %v", err)
+ t.Errorf("restoring HOME: %v", err)
}
}()
if err := os.Setenv("HOME", tmpDir); err != nil {
- t.Fatalf("Failed to set HOME: %v", err)
+ t.Fatalf("setting HOME: %v", err)
}
if err := os.Setenv("GEMINI_API_KEY", "test-key"); err != nil {
- t.Fatalf("Failed to set GEMINI_API_KEY: %v", err)
+ t.Fatalf("setting GEMINI_API_KEY: %v", err)
}
defer func() {
if err := os.Unsetenv("GEMINI_API_KEY"); err != nil {
- t.Errorf("Failed to unset GEMINI_API_KEY: %v", err)
+ t.Errorf("unsetting GEMINI_API_KEY: %v", err)
}
}()
_, err := Load()
if err != nil {
- t.Fatalf("Load() failed: %v", err)
+ t.Fatalf("Load() unexpected error: %v", err)
}
// Check that .koopa directory was created
@@ -257,14 +245,13 @@ func TestConfigDirectoryCreation(t *testing.T) {
}
if !info.IsDir() {
- t.Error("expected .koopa to be a directory")
+ t.Error("Load() created .koopa as file, want directory")
}
// Check permissions (0750 = drwxr-x---)
perm := info.Mode().Perm()
- expectedPerm := os.FileMode(0o750)
- if perm != expectedPerm {
- t.Errorf("expected permissions %o, got %o", expectedPerm, perm)
+ if perm != 0o750 {
+ t.Errorf("Load() dir permissions = %o, want %o", perm, 0o750)
}
}
@@ -274,26 +261,26 @@ func TestEnvironmentVariableOverride(t *testing.T) {
originalHome := os.Getenv("HOME")
defer func() {
if err := os.Setenv("HOME", originalHome); err != nil {
- t.Errorf("Failed to restore HOME: %v", err)
+ t.Errorf("restoring HOME: %v", err)
}
}()
if err := os.Setenv("HOME", tmpDir); err != nil {
- t.Fatalf("Failed to set HOME: %v", err)
+ t.Fatalf("setting HOME: %v", err)
}
if err := os.Setenv("GEMINI_API_KEY", "test-key"); err != nil {
- t.Fatalf("Failed to set GEMINI_API_KEY: %v", err)
+ t.Fatalf("setting GEMINI_API_KEY: %v", err)
}
defer func() {
if err := os.Unsetenv("GEMINI_API_KEY"); err != nil {
- t.Errorf("Failed to unset GEMINI_API_KEY: %v", err)
+ t.Errorf("unsetting GEMINI_API_KEY: %v", err)
}
}()
// Create .koopa directory and config file
koopaDir := filepath.Join(tmpDir, ".koopa")
if err := os.MkdirAll(koopaDir, 0o750); err != nil {
- t.Fatalf("failed to create koopa dir: %v", err)
+ t.Fatalf("creating koopa dir: %v", err)
}
configContent := `model_name: gemini-2.5-pro
@@ -302,7 +289,7 @@ max_tokens: 1024
`
configPath := filepath.Join(koopaDir, "config.yaml")
if err := os.WriteFile(configPath, []byte(configContent), 0o600); err != nil {
- t.Fatalf("failed to write config file: %v", err)
+ t.Fatalf("writing config file: %v", err)
}
// KOOPA_* env vars NO LONGER supported (removed AutomaticEnv)
@@ -310,10 +297,10 @@ max_tokens: 1024
testHMACSecret := "test-hmac-secret-minimum-32-chars-long"
if err := os.Setenv("DD_API_KEY", testAPIKey); err != nil {
- t.Fatalf("Failed to set DD_API_KEY: %v", err)
+ t.Fatalf("setting DD_API_KEY: %v", err)
}
if err := os.Setenv("HMAC_SECRET", testHMACSecret); err != nil {
- t.Fatalf("Failed to set HMAC_SECRET: %v", err)
+ t.Fatalf("setting HMAC_SECRET: %v", err)
}
defer func() {
_ = os.Unsetenv("DD_API_KEY")
@@ -322,29 +309,29 @@ max_tokens: 1024
cfg, err := Load()
if err != nil {
- t.Fatalf("Load() failed: %v", err)
+ t.Fatalf("Load() unexpected error: %v", err)
}
// Config values should come from config.yaml (NOT env vars)
if cfg.ModelName != "gemini-2.5-pro" {
- t.Errorf("expected ModelName from config 'gemini-2.5-pro', got %q", cfg.ModelName)
+ t.Errorf("Load().ModelName = %q, want %q", cfg.ModelName, "gemini-2.5-pro")
}
if cfg.Temperature != 0.5 {
- t.Errorf("expected Temperature from config 0.5, got %f", cfg.Temperature)
+ t.Errorf("Load().Temperature = %f, want %f", cfg.Temperature, 0.5)
}
if cfg.MaxTokens != 1024 {
- t.Errorf("expected MaxTokens from config 1024, got %d", cfg.MaxTokens)
+ t.Errorf("Load().MaxTokens = %d, want %d", cfg.MaxTokens, 1024)
}
// Sensitive env vars should be bound
if cfg.Datadog.APIKey != testAPIKey {
- t.Errorf("expected Datadog.APIKey from env %q, got %q", testAPIKey, cfg.Datadog.APIKey)
+ t.Errorf("Load().Datadog.APIKey = %q, want %q", cfg.Datadog.APIKey, testAPIKey)
}
if cfg.HMACSecret != testHMACSecret {
- t.Errorf("expected HMACSecret from env %q, got %q", testHMACSecret, cfg.HMACSecret)
+ t.Errorf("Load().HMACSecret = %q, want %q", cfg.HMACSecret, testHMACSecret)
}
}
@@ -354,26 +341,26 @@ func TestLoadInvalidYAML(t *testing.T) {
originalHome := os.Getenv("HOME")
defer func() {
if err := os.Setenv("HOME", originalHome); err != nil {
- t.Errorf("Failed to restore HOME: %v", err)
+ t.Errorf("restoring HOME: %v", err)
}
}()
if err := os.Setenv("HOME", tmpDir); err != nil {
- t.Fatalf("Failed to set HOME: %v", err)
+ t.Fatalf("setting HOME: %v", err)
}
if err := os.Setenv("GEMINI_API_KEY", "test-key"); err != nil {
- t.Fatalf("Failed to set GEMINI_API_KEY: %v", err)
+ t.Fatalf("setting GEMINI_API_KEY: %v", err)
}
defer func() {
if err := os.Unsetenv("GEMINI_API_KEY"); err != nil {
- t.Errorf("Failed to unset GEMINI_API_KEY: %v", err)
+ t.Errorf("unsetting GEMINI_API_KEY: %v", err)
}
}()
// Create .koopa directory
koopaDir := filepath.Join(tmpDir, ".koopa")
if err := os.MkdirAll(koopaDir, 0o750); err != nil {
- t.Fatalf("failed to create koopa dir: %v", err)
+ t.Fatalf("creating koopa dir: %v", err)
}
// Create invalid YAML config file
@@ -384,12 +371,12 @@ max_tokens: not_a_number
`
configPath := filepath.Join(koopaDir, "config.yaml")
if err := os.WriteFile(configPath, []byte(invalidYAML), 0o600); err != nil {
- t.Fatalf("failed to write invalid config file: %v", err)
+ t.Fatalf("writing invalid config file: %v", err)
}
_, err := Load()
if err == nil {
- t.Error("expected error for invalid YAML, got none")
+ t.Error("Load() error = nil, want error for invalid YAML")
}
}
@@ -399,26 +386,26 @@ func TestLoadUnmarshalError(t *testing.T) {
originalHome := os.Getenv("HOME")
defer func() {
if err := os.Setenv("HOME", originalHome); err != nil {
- t.Errorf("Failed to restore HOME: %v", err)
+ t.Errorf("restoring HOME: %v", err)
}
}()
if err := os.Setenv("HOME", tmpDir); err != nil {
- t.Fatalf("Failed to set HOME: %v", err)
+ t.Fatalf("setting HOME: %v", err)
}
if err := os.Setenv("GEMINI_API_KEY", "test-key"); err != nil {
- t.Fatalf("Failed to set GEMINI_API_KEY: %v", err)
+ t.Fatalf("setting GEMINI_API_KEY: %v", err)
}
defer func() {
if err := os.Unsetenv("GEMINI_API_KEY"); err != nil {
- t.Errorf("Failed to unset GEMINI_API_KEY: %v", err)
+ t.Errorf("unsetting GEMINI_API_KEY: %v", err)
}
}()
// Create .koopa directory
koopaDir := filepath.Join(tmpDir, ".koopa")
if err := os.MkdirAll(koopaDir, 0o750); err != nil {
- t.Fatalf("failed to create koopa dir: %v", err)
+ t.Fatalf("creating koopa dir: %v", err)
}
// Create config with type mismatch
@@ -428,30 +415,28 @@ max_tokens: "this should also be a number"
`
configPath := filepath.Join(koopaDir, "config.yaml")
if err := os.WriteFile(configPath, []byte(invalidTypeYAML), 0o600); err != nil {
- t.Fatalf("failed to write config file: %v", err)
+ t.Fatalf("writing config file: %v", err)
}
- // This will succeed because viper is flexible with type conversion
- // but we document this test to show that invalid types are handled
- _, err := Load()
- // Note: viper may successfully parse string "0.7" as float, so we don't assert error here
- _ = err
+ // Viper is flexible with type conversion so this may succeed or fail.
+ // The test verifies Load() doesn't panic on type-mismatched YAML.
+ _, _ = Load()
}
// BenchmarkLoad benchmarks configuration loading
func BenchmarkLoad(b *testing.B) {
if err := os.Setenv("GEMINI_API_KEY", "test-key"); err != nil {
- b.Fatalf("Failed to set GEMINI_API_KEY: %v", err)
+ b.Fatalf("setting GEMINI_API_KEY: %v", err)
}
defer func() {
if err := os.Unsetenv("GEMINI_API_KEY"); err != nil {
- b.Errorf("Failed to unset GEMINI_API_KEY: %v", err)
+ b.Errorf("unsetting GEMINI_API_KEY: %v", err)
}
}()
// Verify Load() works before starting benchmark
if _, err := Load(); err != nil {
- b.Fatalf("Load() failed: %v", err)
+ b.Fatalf("Load() unexpected error: %v", err)
}
b.ResetTimer()
@@ -473,7 +458,7 @@ func TestConfig_MarshalJSON_MasksSensitiveFields(t *testing.T) {
data, err := json.Marshal(cfg)
if err != nil {
- t.Fatalf("MarshalJSON failed: %v", err)
+ t.Fatalf("MarshalJSON() unexpected error: %v", err)
}
jsonStr := string(data)
@@ -488,9 +473,9 @@ func TestConfig_MarshalJSON_MasksSensitiveFields(t *testing.T) {
// 1. Not be the original password
// 2. Contain masking characters (****)
// 3. Be present in the JSON output
- var result map[string]interface{}
+ var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
- t.Fatalf("failed to unmarshal result: %v", err)
+ t.Fatalf("unmarshaling result: %v", err)
}
maskedPwd, ok := result["postgres_password"].(string)
@@ -522,17 +507,17 @@ func TestConfig_MarshalJSON_EmptyPassword(t *testing.T) {
data, err := json.Marshal(cfg)
if err != nil {
- t.Fatalf("MarshalJSON failed: %v", err)
+ t.Fatalf("MarshalJSON() unexpected error: %v", err)
}
// Empty password should remain empty, not cause panic
- var result map[string]interface{}
+ var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
- t.Fatalf("failed to unmarshal result: %v", err)
+ t.Fatalf("unmarshaling result: %v", err)
}
if result["postgres_password"] != "" {
- t.Errorf("expected empty password to remain empty, got %v", result["postgres_password"])
+ t.Errorf("MarshalJSON() postgres_password = %v, want empty string", result["postgres_password"])
}
}
@@ -544,7 +529,7 @@ func TestConfig_MarshalJSON_ShortPassword(t *testing.T) {
data, err := json.Marshal(cfg)
if err != nil {
- t.Fatalf("MarshalJSON failed: %v", err)
+ t.Fatalf("MarshalJSON() unexpected error: %v", err)
}
jsonStr := string(data)
@@ -555,7 +540,7 @@ func TestConfig_MarshalJSON_ShortPassword(t *testing.T) {
}
if !strings.Contains(jsonStr, `"postgres_password":"████████"`) {
- t.Errorf("expected fully masked password '████████', got: %s", jsonStr)
+ t.Errorf("MarshalJSON() short password not fully masked, got: %s", jsonStr)
}
}
@@ -582,23 +567,16 @@ func TestConfig_SensitiveFieldsAreMasked(t *testing.T) {
Datadog: DatadogConfig{
APIKey: "datadogapikey789",
},
- MCPServers: map[string]MCPServer{
- "test": {
- Env: map[string]string{
- "API_KEY": "envvarkey123",
- },
- },
- },
}
data, err := json.Marshal(cfg)
if err != nil {
- t.Fatalf("MarshalJSON failed: %v", err)
+ t.Fatalf("MarshalJSON() unexpected error: %v", err)
}
jsonStr := string(data)
// Verify no raw secrets appear in output
- secrets := []string{"secretpassword123", "hmacsecret456", "datadogapikey789", "envvarkey123"}
+ secrets := []string{"secretpassword123", "hmacsecret456", "datadogapikey789"}
for _, secret := range secrets {
if strings.Contains(jsonStr, secret) {
t.Errorf("sensitive value %q should be masked in JSON output", secret)
@@ -612,17 +590,6 @@ func TestConfig_MarshalJSON_NestedStructs(t *testing.T) {
cfg := Config{
ModelName: "test-model",
PostgresPassword: "secretpassword",
- MCP: MCPConfig{
- Timeout: 10,
- Allowed: []string{"server1", "server2"},
- },
- MCPServers: map[string]MCPServer{
- "test-server": {
- Command: "npx",
- Args: []string{"-y", "test-mcp"},
- Timeout: 30,
- },
- },
SearXNG: SearXNGConfig{
BaseURL: "http://localhost:8080",
},
@@ -635,48 +602,30 @@ func TestConfig_MarshalJSON_NestedStructs(t *testing.T) {
data, err := json.Marshal(cfg)
if err != nil {
- t.Fatalf("MarshalJSON failed: %v", err)
+ t.Fatalf("MarshalJSON() unexpected error: %v", err)
}
- var result map[string]interface{}
+ var result map[string]any
if err := json.Unmarshal(data, &result); err != nil {
- t.Fatalf("failed to unmarshal result: %v", err)
- }
-
- // Verify nested MCP config is present
- mcp, ok := result["mcp"].(map[string]interface{})
- if !ok {
- t.Fatal("mcp should be a nested object in JSON output")
- }
- if mcp["timeout"] != float64(10) {
- t.Errorf("expected mcp.timeout = 10, got %v", mcp["timeout"])
- }
-
- // Verify MCPServers map is present
- servers, ok := result["mcp_servers"].(map[string]interface{})
- if !ok {
- t.Fatal("mcp_servers should be a map in JSON output")
- }
- if _, exists := servers["test-server"]; !exists {
- t.Error("expected test-server in mcp_servers")
+ t.Fatalf("unmarshaling result: %v", err)
}
// Verify SearXNG config
- searxng, ok := result["searxng"].(map[string]interface{})
+ searxng, ok := result["searxng"].(map[string]any)
if !ok {
t.Fatal("searxng should be a nested object")
}
if searxng["base_url"] != "http://localhost:8080" {
- t.Errorf("expected searxng.base_url = 'http://localhost:8080', got %v", searxng["base_url"])
+ t.Errorf("MarshalJSON() searxng.base_url = %v, want %q", searxng["base_url"], "http://localhost:8080")
}
// Verify Datadog config
- datadog, ok := result["datadog"].(map[string]interface{})
+ datadog, ok := result["datadog"].(map[string]any)
if !ok {
t.Fatal("datadog should be a nested object")
}
if datadog["environment"] != "test" {
- t.Errorf("expected datadog.environment = 'test', got %v", datadog["environment"])
+ t.Errorf("MarshalJSON() datadog.environment = %v, want %q", datadog["environment"], "test")
}
// CRITICAL: Verify sensitive field is still masked
@@ -686,95 +635,6 @@ func TestConfig_MarshalJSON_NestedStructs(t *testing.T) {
}
}
-// TestConfig_MarshalJSON_MCPServerEnvMasked verifies that MCPServer.Env (sensitive map) is masked
-// SECURITY: MCPServer.Env commonly contains API keys, tokens, and secrets
-func TestConfig_MarshalJSON_MCPServerEnvMasked(t *testing.T) {
- cfg := Config{
- MCPServers: map[string]MCPServer{
- "github-mcp": {
- Command: "npx",
- Args: []string{"-y", "@modelcontextprotocol/server-github"},
- Env: map[string]string{
- "GITHUB_TOKEN": "ghp_supersecrettoken12345678",
- "API_KEY": "sk-proj-secretapikey67890",
- "OPENAI_API_KEY": "sk-openai-verysecretkey",
- "ANTHROPIC_KEY": "anthropic-secret-key-xxx",
- "DATABASE_PASSWORD": "dbpassword123",
- },
- Timeout: 30,
- },
- "another-server": {
- Command: "node",
- Env: map[string]string{
- "SECRET_TOKEN": "another_secret_value",
- },
- },
- },
- }
-
- data, err := json.Marshal(cfg)
- if err != nil {
- t.Fatalf("MarshalJSON failed: %v", err)
- }
-
- jsonStr := string(data)
-
- // CRITICAL: All secret values in Env must be masked
- secrets := []string{
- "ghp_supersecrettoken12345678",
- "sk-proj-secretapikey67890",
- "sk-openai-verysecretkey",
- "anthropic-secret-key-xxx",
- "dbpassword123",
- "another_secret_value",
- }
-
- for _, secret := range secrets {
- if strings.Contains(jsonStr, secret) {
- t.Errorf("SECURITY: MCPServer.Env secret leaked in JSON output: %s", secret)
- }
- }
-
- // Verify the Env field is present but masked
- var result map[string]interface{}
- if err := json.Unmarshal(data, &result); err != nil {
- t.Fatalf("failed to unmarshal result: %v", err)
- }
-
- servers, ok := result["mcp_servers"].(map[string]interface{})
- if !ok {
- t.Fatal("mcp_servers should be present in JSON output")
- }
-
- githubServer, ok := servers["github-mcp"].(map[string]interface{})
- if !ok {
- t.Fatal("github-mcp server should be present")
- }
-
- // Env map values should be masked individually (keys visible, values masked)
- env, ok := githubServer["env"].(map[string]interface{})
- if !ok {
- t.Fatal("MCPServer.Env should be a map")
- }
- // Check that original secrets are not present
- for _, v := range env {
- strVal, ok := v.(string)
- if !ok {
- t.Error("Env values should be strings")
- continue
- }
- // Masked values should contain maskedValue (████████)
- if !strings.Contains(strVal, "████████") && strVal != "" {
- t.Errorf("Env value should be masked, got: %s", strVal)
- }
- }
-
- // Verify non-sensitive fields are NOT masked
- if githubServer["command"] != "npx" {
- t.Error("non-sensitive field Command should not be masked")
- }
-}
-
// NOTE: TestConfig_MarshalJSON_AllSensitiveFields was removed because:
// 1. We no longer use `sensitive:"true"` tags (replaced with explicit MarshalJSON)
// 2. TestConfig_SensitiveFieldsAreMasked provides equivalent coverage
@@ -821,7 +681,7 @@ func TestMaskSecret_Unicode(t *testing.T) {
// Verify masking pattern is present (when expected)
if tt.wantContains != "" && !strings.Contains(masked, tt.wantContains) {
- t.Errorf("expected masked output to contain %q, got: %q", tt.wantContains, masked)
+ t.Errorf("maskSecret(%q) = %q, want contains %q", tt.input, masked, tt.wantContains)
}
// CRITICAL: Original value must NEVER appear in masked output
@@ -866,7 +726,7 @@ func TestConfig_MarshalJSON_UnicodePasswords(t *testing.T) {
data, err := json.Marshal(cfg)
if err != nil {
- t.Fatalf("MarshalJSON failed: %v", err)
+ t.Fatalf("MarshalJSON() unexpected error: %v", err)
}
jsonStr := string(data)
@@ -878,7 +738,7 @@ func TestConfig_MarshalJSON_UnicodePasswords(t *testing.T) {
// Verify masking was applied
if !strings.Contains(jsonStr, "████████") {
- t.Errorf("expected masked output to contain '████████', got: %s", jsonStr)
+ t.Errorf("MarshalJSON() output missing mask chars, got: %s", jsonStr)
}
})
}
@@ -1082,19 +942,6 @@ func BenchmarkConfig_MarshalJSON(b *testing.B) {
PostgresPort: 5432,
PostgresUser: "koopa",
PostgresDBName: "koopa",
- MCP: MCPConfig{
- Timeout: 5,
- Allowed: []string{"server1", "server2"},
- },
- MCPServers: map[string]MCPServer{
- "github": {
- Command: "npx",
- Args: []string{"-y", "@modelcontextprotocol/server-github"},
- Env: map[string]string{
- "GITHUB_TOKEN": "ghp_secrettoken12345",
- },
- },
- },
}
b.ResetTimer()
@@ -1104,16 +951,40 @@ func BenchmarkConfig_MarshalJSON(b *testing.B) {
}
}
+// TestFullModelName tests that FullModelName derives correct provider-qualified names.
+func TestFullModelName(t *testing.T) {
+ tests := []struct {
+ name string
+ provider string
+ modelName string
+ want string
+ }{
+ {name: "gemini default", provider: "", modelName: "gemini-2.5-flash", want: "googleai/gemini-2.5-flash"},
+ {name: "gemini explicit", provider: "gemini", modelName: "gemini-2.5-pro", want: "googleai/gemini-2.5-pro"},
+ {name: "ollama", provider: "ollama", modelName: "llama3.3", want: "ollama/llama3.3"},
+ {name: "openai", provider: "openai", modelName: "gpt-4o", want: "openai/gpt-4o"},
+ {name: "already qualified", provider: "ollama", modelName: "ollama/llama3.3", want: "ollama/llama3.3"},
+ {name: "cross-qualified", provider: "gemini", modelName: "openai/gpt-4o", want: "openai/gpt-4o"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &Config{
+ Provider: tt.provider,
+ ModelName: tt.modelName,
+ }
+ got := cfg.FullModelName()
+ if got != tt.want {
+ t.Errorf("FullModelName() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
+
// BenchmarkConfig_MarshalJSON_Parallel benchmarks concurrent Config marshaling
func BenchmarkConfig_MarshalJSON_Parallel(b *testing.B) {
cfg := Config{
PostgresPassword: "supersecretpassword123",
- MCPServers: map[string]MCPServer{
- "test": {
- Command: "npx",
- Env: map[string]string{"SECRET": "value"},
- },
- },
}
b.ResetTimer()
diff --git a/internal/config/observability.go b/internal/config/observability.go
index 8ecf78a..c542795 100644
--- a/internal/config/observability.go
+++ b/internal/config/observability.go
@@ -8,7 +8,7 @@ import (
// DatadogConfig holds Datadog APM tracing configuration.
//
// Tracing uses the local Datadog Agent for OTLP ingestion.
-// See internal/observability/datadog.go for detailed setup instructions.
+// Setup is inlined in internal/app/setup.go (provideOtelShutdown).
type DatadogConfig struct {
// APIKey is the Datadog API key (optional, for observability)
APIKey string `mapstructure:"api_key" json:"api_key"`
diff --git a/internal/config/storage.go b/internal/config/storage.go
index 5771070..f2230be 100644
--- a/internal/config/storage.go
+++ b/internal/config/storage.go
@@ -8,43 +8,39 @@ import (
"strings"
)
-// StorageConfig documentation.
-// Fields are embedded in the main Config struct for backward compatibility.
-//
-// PostgreSQL (for pgvector):
-// - PostgresHost: Database host (default: localhost)
-// - PostgresPort: Database port (default: 5432)
-// - PostgresUser: Database user (default: koopa)
-// - PostgresPassword: Database password
-// - PostgresDBName: Database name (default: koopa)
-// - PostgresSSLMode: SSL mode (default: disable)
-//
-// RAG:
-// - RAGTopK: Number of documents to retrieve (1-10, default: 3)
-// - EmbedderModel: Embedding model name (default: text-embedding-004)
+// quoteDSNValue quotes a value for PostgreSQL key=value DSN format.
+// Within single quotes, backslashes and single quotes are escaped.
+// This prevents parsing errors when values contain spaces or special characters.
+func quoteDSNValue(s string) string {
+ s = strings.ReplaceAll(s, `\`, `\\`)
+ s = strings.ReplaceAll(s, `'`, `\'`)
+ return "'" + s + "'"
+}
// PostgresConnectionString returns the PostgreSQL DSN for pgx driver.
+// Password is single-quoted to handle special characters (spaces, =, quotes).
func (c *Config) PostgresConnectionString() string {
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
c.PostgresHost,
c.PostgresPort,
c.PostgresUser,
- c.PostgresPassword,
+ quoteDSNValue(c.PostgresPassword),
c.PostgresDBName,
c.PostgresSSLMode,
)
}
// PostgresURL returns the PostgreSQL URL for golang-migrate.
+// Uses url.URL for proper encoding of special characters in credentials.
func (c *Config) PostgresURL() string {
- return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s",
- c.PostgresUser,
- c.PostgresPassword,
- c.PostgresHost,
- c.PostgresPort,
- c.PostgresDBName,
- c.PostgresSSLMode,
- )
+ u := &url.URL{
+ Scheme: "postgres",
+ User: url.UserPassword(c.PostgresUser, c.PostgresPassword),
+ Host: fmt.Sprintf("%s:%d", c.PostgresHost, c.PostgresPort),
+ Path: c.PostgresDBName,
+ RawQuery: fmt.Sprintf("sslmode=%s", c.PostgresSSLMode),
+ }
+ return u.String()
}
// parseDatabaseURL parses DATABASE_URL environment variable and sets PostgreSQL config.
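
Taken together, the two builders handle awkward credentials in different ways: the key/value DSN single-quotes the password with backslash escaping, while the URL form percent-encodes credentials through url.UserPassword. A short illustration using the same values as the tests below; assumes import "github.com/koopa0/koopa/internal/config".

cfg := &config.Config{
	PostgresHost:     "localhost",
	PostgresPort:     5432,
	PostgresUser:     "koopa",
	PostgresPassword: "p@ss/word#123",
	PostgresDBName:   "koopa",
	PostgresSSLMode:  "disable",
}

// Key/value DSN: the password stays one quoted token, so values containing
// spaces or quotes cannot smuggle in extra key=value pairs:
//   host=localhost port=5432 user=koopa password='p@ss/word#123' dbname=koopa sslmode=disable
dsn := cfg.PostgresConnectionString()

// URL for golang-migrate: credentials are percent-encoded:
//   postgres://koopa:p%40ss%2Fword%23123@localhost:5432/koopa?sslmode=disable
migrateURL := cfg.PostgresURL()
_, _ = dsn, migrateURL
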
diff --git a/internal/config/storage_test.go b/internal/config/storage_test.go
index 22ddad4..8f43b45 100644
--- a/internal/config/storage_test.go
+++ b/internal/config/storage_test.go
@@ -23,7 +23,7 @@ func TestPostgresConnectionString(t *testing.T) {
"host=test-host",
"port=5433",
"user=test-user",
- "password=test-password",
+ "password='test-password'",
"dbname=test-db",
"sslmode=require",
}
@@ -35,22 +35,130 @@ func TestPostgresConnectionString(t *testing.T) {
}
}
+// TestPostgresConnectionStringSpecialChars tests DSN quoting handles special characters.
+// The password is single-quoted in DSN format, so injection payloads like
+// "' host=evil.com" become part of the quoted value, not separate DSN keys.
+func TestPostgresConnectionStringSpecialChars(t *testing.T) {
+ tests := []struct {
+ name string
+ password string
+ wantQuoted string // expected quoted form in DSN
+ }{
+ {
+ name: "password with spaces",
+ password: "my secret pass",
+ wantQuoted: "'my secret pass'",
+ },
+ {
+ name: "password with single quotes",
+ password: "pass'word",
+ wantQuoted: `'pass\'word'`,
+ },
+ {
+ name: "DSN injection attempt",
+ password: "' host=evil.com user=hacker password='",
+ wantQuoted: `'\' host=evil.com user=hacker password=\''`,
+ },
+ {
+ name: "password with backslash",
+ password: `pass\word`,
+ wantQuoted: `'pass\\word'`,
+ },
+ {
+ name: "password with equals",
+ password: "pass=word",
+ wantQuoted: "'pass=word'",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &Config{
+ PostgresHost: "localhost",
+ PostgresPort: 5432,
+ PostgresUser: "user",
+ PostgresPassword: tt.password,
+ PostgresDBName: "db",
+ PostgresSSLMode: "disable",
+ }
+ dsn := cfg.PostgresConnectionString()
+
+ // Verify password is correctly quoted in DSN
+ wantPart := "password=" + tt.wantQuoted
+ if !strings.Contains(dsn, wantPart) {
+ t.Errorf("PostgresConnectionString() missing %q\ngot: %s", wantPart, dsn)
+ }
+
+ // Verify DSN ends with expected structure after password
+ if !strings.HasSuffix(dsn, "dbname=db sslmode=disable") {
+ t.Errorf("PostgresConnectionString() has corrupted suffix\ngot: %s", dsn)
+ }
+ })
+ }
+}
+
+// TestQuoteDSNValue tests the DSN value quoting helper.
+func TestQuoteDSNValue(t *testing.T) {
+ tests := []struct {
+ input string
+ want string
+ }{
+ {input: "simple", want: "'simple'"},
+ {input: "with space", want: "'with space'"},
+ {input: "with'quote", want: `'with\'quote'`},
+ {input: `with\backslash`, want: `'with\\backslash'`},
+ {input: "", want: "''"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ got := quoteDSNValue(tt.input)
+ if got != tt.want {
+ t.Errorf("quoteDSNValue(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
// TestPostgresURL tests PostgreSQL URL generation for golang-migrate
func TestPostgresURL(t *testing.T) {
- cfg := &Config{
- PostgresHost: "test-host",
- PostgresPort: 5433,
- PostgresUser: "test-user",
- PostgresPassword: "test-password",
- PostgresDBName: "test-db",
- PostgresSSLMode: "require",
+ tests := []struct {
+ name string
+ cfg *Config
+ want string
+ }{
+ {
+ name: "simple credentials",
+ cfg: &Config{
+ PostgresHost: "test-host",
+ PostgresPort: 5433,
+ PostgresUser: "test-user",
+ PostgresPassword: "test-password",
+ PostgresDBName: "test-db",
+ PostgresSSLMode: "require",
+ },
+ want: "postgres://test-user:test-password@test-host:5433/test-db?sslmode=require",
+ },
+ {
+ name: "password with special characters",
+ cfg: &Config{
+ PostgresHost: "localhost",
+ PostgresPort: 5432,
+ PostgresUser: "koopa",
+ PostgresPassword: "p@ss/word#123",
+ PostgresDBName: "koopa",
+ PostgresSSLMode: "disable",
+ },
+ want: "postgres://koopa:p%40ss%2Fword%23123@localhost:5432/koopa?sslmode=disable",
+ },
}
- url := cfg.PostgresURL()
-
- expected := "postgres://test-user:test-password@test-host:5433/test-db?sslmode=require"
- if url != expected {
- t.Errorf("PostgresURL() = %q, want %q", url, expected)
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.cfg.PostgresURL()
+ if got != tt.want {
+ t.Errorf("PostgresURL() = %q, want %q", got, tt.want)
+ }
+ })
}
}
diff --git a/internal/config/testdata/fuzz/FuzzConfigMarshalJSON/771e938e4458e983 b/internal/config/testdata/fuzz/FuzzConfigMarshalJSON/771e938e4458e983
deleted file mode 100644
index ee3f339..0000000
--- a/internal/config/testdata/fuzz/FuzzConfigMarshalJSON/771e938e4458e983
+++ /dev/null
@@ -1,2 +0,0 @@
-go test fuzz v1
-string("0")
diff --git a/internal/config/tools.go b/internal/config/tools.go
index efb3a0e..4ceb5c8 100644
--- a/internal/config/tools.go
+++ b/internal/config/tools.go
@@ -1,46 +1,5 @@
package config
-import (
- "encoding/json"
- "fmt"
-)
-
-// MCPConfig controls global MCP (Model Context Protocol) behavior.
-type MCPConfig struct {
- Allowed []string `mapstructure:"allowed" json:"allowed"` // Whitelist of server names (empty = all configured servers)
- Excluded []string `mapstructure:"excluded" json:"excluded"` // Blacklist of server names (higher priority than Allowed)
- Timeout int `mapstructure:"timeout" json:"timeout"` // Connection timeout in seconds (default: 5)
-}
-
-// MCPServer defines a single MCP server configuration.
-type MCPServer struct {
- Command string `mapstructure:"command" json:"command"` // Required: executable path (e.g., "npx")
- Args []string `mapstructure:"args" json:"args"` // Optional: command arguments
- Env map[string]string `mapstructure:"env" json:"env"` // Optional: environment variables - SECURITY: May contain API keys/tokens
- Timeout int `mapstructure:"timeout" json:"timeout"` // Optional: per-server timeout (overrides global)
- IncludeTools []string `mapstructure:"include_tools" json:"include_tools"` // Optional: tool whitelist
- ExcludeTools []string `mapstructure:"exclude_tools" json:"exclude_tools"` // Optional: tool blacklist
-}
-
-// MarshalJSON implements json.Marshaler with sensitive field masking.
-// Masks all values in the Env map as they may contain API keys/tokens.
-func (m MCPServer) MarshalJSON() ([]byte, error) {
- type alias MCPServer
- a := alias(m)
- if a.Env != nil {
- maskedEnv := make(map[string]string, len(a.Env))
- for k, v := range a.Env {
- maskedEnv[k] = maskSecret(v)
- }
- a.Env = maskedEnv
- }
- data, err := json.Marshal(a)
- if err != nil {
- return nil, fmt.Errorf("marshal mcp server: %w", err)
- }
- return data, nil
-}
-
// SearXNGConfig holds SearXNG service configuration for web search.
type SearXNGConfig struct {
// BaseURL is the SearXNG instance URL (e.g., http://searxng:8080)
diff --git a/internal/config/validation.go b/internal/config/validation.go
index dd88a05..43d7e2a 100644
--- a/internal/config/validation.go
+++ b/internal/config/validation.go
@@ -8,7 +8,7 @@ import (
)
// supportedProviders lists all valid AI provider values.
-var supportedProviders = []string{"gemini", "ollama", "openai"}
+var supportedProviders = []string{ProviderGemini, ProviderOllama, ProviderOpenAI}
// Validate validates configuration values.
// Returns sentinel errors that can be checked with errors.Is().
@@ -54,17 +54,17 @@ func (c *Config) validateAI() error {
}
// Ollama host
- if c.resolvedProvider() == "ollama" && c.OllamaHost == "" {
+ if c.resolvedProvider() == ProviderOllama && c.OllamaHost == "" {
return fmt.Errorf("%w: ollama_host cannot be empty when provider is ollama", ErrInvalidOllamaHost)
}
- // RAG
- if c.RAGTopK <= 0 || c.RAGTopK > 10 {
- return fmt.Errorf("%w: must be between 1 and 10, got %d", ErrInvalidRAGTopK, c.RAGTopK)
- }
+ // RAG embedder
if c.EmbedderModel == "" {
return fmt.Errorf("%w: embedder_model cannot be empty", ErrInvalidEmbedderModel)
}
+ if err := c.validateEmbedder(); err != nil {
+ return err
+ }
return nil
}
@@ -104,37 +104,98 @@ func (c *Config) validatePostgresSSL() error {
ErrInvalidPostgresSSLMode)
}
if !slices.Contains(validSSLModes, c.PostgresSSLMode) {
- return fmt.Errorf("%w: %q is not valid, must be one of: %v\n"+
- "Note: 'allow' and 'prefer' modes are deprecated (vulnerable to MITM attacks)",
+ return fmt.Errorf("%w: %q is not valid, must be one of: %v (allow/prefer excluded: MITM vulnerable)",
ErrInvalidPostgresSSLMode, c.PostgresSSLMode, validSSLModes)
}
return nil
}
+// knownEmbedderDimensions maps provider → model → native output dimension.
+// Used to catch dimension mismatches at startup before hitting pgvector errors.
+var knownEmbedderDimensions = map[string]map[string]int{
+ ProviderGemini: {
+ "gemini-embedding-001": 3072,
+ "text-embedding-004": 768,
+ },
+ ProviderOpenAI: {
+ "text-embedding-3-small": 1536,
+ "text-embedding-3-large": 3072,
+ },
+}
+
+// requiredVectorDimension must match the pgvector schema: embedding vector(768).
+const requiredVectorDimension = 768
+
+// validateEmbedder checks that the configured embedder model produces vectors
+// compatible with the database schema. For known models whose native dimension
+// differs from requiredVectorDimension, it logs a warning so operators know the
+// output is truncated to 768 via OutputDimensionality (handled by rag.NewDocStoreConfig).
+//
+// Unknown providers or models pass validation silently — the operator may
+// know what they are doing (e.g., a custom Ollama embedder producing 768-dim).
+func (c *Config) validateEmbedder() error {
+ models, ok := knownEmbedderDimensions[c.resolvedProvider()]
+ if !ok {
+ return nil // unknown provider (e.g., ollama) — skip
+ }
+ dim, ok := models[c.EmbedderModel]
+ if !ok {
+ return nil // unknown model — skip
+ }
+ if dim != requiredVectorDimension {
+ slog.Warn("embedder native dimension differs from schema",
+ "model", c.EmbedderModel,
+ "native_dim", dim,
+ "schema_dim", requiredVectorDimension,
+ "note", "rag.NewDocStoreConfig truncates output via OutputDimensionality")
+ }
+ return nil
+}
+
// resolvedProvider returns the effective provider, defaulting to "gemini".
func (c *Config) resolvedProvider() string {
if c.Provider == "" {
- return "gemini"
+ return ProviderGemini
}
return c.Provider
}
+// ValidateServe validates configuration specific to serve mode.
+// HMAC_SECRET is required for CSRF protection in HTTP mode.
+func (c *Config) ValidateServe() error {
+ if err := c.Validate(); err != nil {
+ return err
+ }
+ if c.HMACSecret == "" {
+ return fmt.Errorf("%w: HMAC_SECRET environment variable is required for serve mode (min 32 characters)",
+ ErrMissingHMACSecret)
+ }
+ if len(c.HMACSecret) < 32 {
+ return fmt.Errorf("%w: must be at least 32 characters, got %d",
+ ErrInvalidHMACSecret, len(c.HMACSecret))
+ }
+ if c.TrustProxy {
+ slog.Warn("trust_proxy is enabled — ensure this server is behind a reverse proxy")
+ }
+ return nil
+}
+
// validateProviderAPIKey checks that the required API key is set for the configured provider.
func (c *Config) validateProviderAPIKey() error {
switch c.resolvedProvider() {
- case "gemini":
+ case ProviderGemini:
if os.Getenv("GEMINI_API_KEY") == "" {
return fmt.Errorf("%w: GEMINI_API_KEY environment variable is required for provider %q\n"+
"Get your API key at: https://ai.google.dev/gemini-api/docs/api-key",
ErrMissingAPIKey, c.resolvedProvider())
}
- case "openai":
+ case ProviderOpenAI:
if os.Getenv("OPENAI_API_KEY") == "" {
return fmt.Errorf("%w: OPENAI_API_KEY environment variable is required for provider %q\n"+
"Get your API key at: https://platform.openai.com/api-keys",
ErrMissingAPIKey, c.resolvedProvider())
}
- case "ollama":
+ case ProviderOllama:
// Ollama runs locally, no API key required
}
return nil
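
Two behavioral notes on the additions above, as a short illustration. It uses validBaseConfig, the helper defined in validation_test.go (next file in this diff), and assumes GEMINI_API_KEY is set in the environment.

cfg := validBaseConfig("gemini")

// gemini-embedding-001 has a native dimension of 3072, not 768, so
// validateEmbedder logs a warning but validation still succeeds
// (truncation is requested at embed time).
cfg.EmbedderModel = "gemini-embedding-001"
_ = cfg.Validate() // nil

// Serve mode is stricter: HMAC_SECRET must be present and at least 32 characters.
cfg.HMACSecret = ""
_ = cfg.ValidateServe() // wraps ErrMissingHMACSecret
cfg.HMACSecret = "0123456789abcdef0123456789abcdef" // 32 characters, the minimum accepted
_ = cfg.ValidateServe() // nil
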
diff --git a/internal/config/validation_test.go b/internal/config/validation_test.go
index 7d0d826..05a2799 100644
--- a/internal/config/validation_test.go
+++ b/internal/config/validation_test.go
@@ -14,8 +14,7 @@ func validBaseConfig(provider string) *Config {
ModelName: "gemini-2.5-flash",
Temperature: 0.7,
MaxTokens: 2048,
- RAGTopK: 3,
- EmbedderModel: "text-embedding-004",
+ EmbedderModel: "gemini-embedding-001",
PostgresHost: "localhost",
PostgresPort: 5432,
PostgresPassword: "test_password",
@@ -39,12 +38,12 @@ func setEnvForProvider(t *testing.T, provider string) func() {
switch provider {
case "gemini", "":
if err := os.Setenv("GEMINI_API_KEY", "test-api-key"); err != nil {
- t.Fatalf("Failed to set GEMINI_API_KEY: %v", err)
+ t.Fatalf("setting GEMINI_API_KEY: %v", err)
}
return func() { os.Unsetenv("GEMINI_API_KEY") }
case "openai":
if err := os.Setenv("OPENAI_API_KEY", "test-openai-key"); err != nil {
- t.Fatalf("Failed to set OPENAI_API_KEY: %v", err)
+ t.Fatalf("setting OPENAI_API_KEY: %v", err)
}
return func() { os.Unsetenv("OPENAI_API_KEY") }
case "ollama":
@@ -69,7 +68,7 @@ func TestValidateSuccess(t *testing.T) {
cfg := validBaseConfig(provider)
if err := cfg.Validate(); err != nil {
- t.Errorf("Validate() failed with valid config (provider %q): %v", provider, err)
+ t.Errorf("Validate() unexpected error with valid config (provider %q): %v", provider, err)
}
})
}
@@ -230,43 +229,6 @@ func TestValidateOllamaHost(t *testing.T) {
}
}
-// TestValidateRAGTopK tests RAG top K validation.
-func TestValidateRAGTopK(t *testing.T) {
- cleanup := setEnvForProvider(t, "gemini")
- defer cleanup()
-
- tests := []struct {
- name string
- ragTopK int
- wantErr bool
- }{
- {name: "valid min", ragTopK: 1},
- {name: "valid mid", ragTopK: 5},
- {name: "valid max", ragTopK: 10},
- {name: "invalid zero", ragTopK: 0, wantErr: true},
- {name: "invalid negative", ragTopK: -1, wantErr: true},
- {name: "invalid too high", ragTopK: 11, wantErr: true},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- cfg := validBaseConfig("gemini")
- cfg.RAGTopK = tt.ragTopK
-
- err := cfg.Validate()
- if tt.wantErr && err == nil {
- t.Errorf("expected error for rag_top_k %d, got nil", tt.ragTopK)
- }
- if !tt.wantErr && err != nil {
- t.Errorf("unexpected error for rag_top_k %d: %v", tt.ragTopK, err)
- }
- if tt.wantErr && err != nil && !errors.Is(err, ErrInvalidRAGTopK) {
- t.Errorf("error should be ErrInvalidRAGTopK, got: %v", err)
- }
- })
- }
-}
-
// TestValidateEmbedderModel tests embedder model validation.
func TestValidateEmbedderModel(t *testing.T) {
cleanup := setEnvForProvider(t, "gemini")
@@ -279,8 +241,8 @@ func TestValidateEmbedderModel(t *testing.T) {
if err == nil {
t.Fatal("expected error for empty embedder_model, got nil")
}
- if !strings.Contains(err.Error(), "embedder_model") {
- t.Errorf("error should mention embedder_model, got: %v", err)
+ if !errors.Is(err, ErrInvalidEmbedderModel) {
+ t.Errorf("Validate() error = %v, want ErrInvalidEmbedderModel", err)
}
}
@@ -442,14 +404,14 @@ func TestValidatePostgresSSLMode(t *testing.T) {
// BenchmarkValidate benchmarks configuration validation.
func BenchmarkValidate(b *testing.B) {
if err := os.Setenv("GEMINI_API_KEY", "test-key"); err != nil {
- b.Fatalf("Failed to set GEMINI_API_KEY: %v", err)
+ b.Fatalf("setting GEMINI_API_KEY: %v", err)
}
defer os.Unsetenv("GEMINI_API_KEY")
cfg := validBaseConfig("gemini")
if err := cfg.Validate(); err != nil {
- b.Fatalf("Validate() failed: %v", err)
+ b.Fatalf("Validate() unexpected error: %v", err)
}
b.ResetTimer()
diff --git a/internal/log/log.go b/internal/log/log.go
deleted file mode 100644
index 099922b..0000000
--- a/internal/log/log.go
+++ /dev/null
@@ -1,109 +0,0 @@
-// Package log provides a unified logging infrastructure for the koopa application.
-//
-// This package provides:
-// - A type alias for *slog.Logger to use as DI dependency
-// - Factory functions to create configured loggers
-// - A Nop logger for testing
-//
-// Design Philosophy:
-// - Use Dependency Injection (DI) for loggers, not globals
-// - Each component receives a logger via constructor
-// - Components can add context via logger.With()
-//
-// Usage:
-//
-// // Create a logger at application startup
-// logger := log.New(log.Config{Level: slog.LevelDebug})
-//
-// // Inject into components with context
-// fileToolset := tools.NewFileToolset(pathVal, logger.With("component", "file"))
-// agent := agent.New(logger.With("component", "agent"), ...)
-//
-// // In tests, use Nop logger or capture to buffer
-// testLogger := log.NewNop()
-// // or
-// var buf bytes.Buffer
-// testLogger := log.NewWithWriter(&buf, log.Config{})
-package log
-
-import (
- "io"
- "log/slog"
- "os"
-)
-
-// Logger is a type alias for *slog.Logger.
-// Using the standard library type directly provides:
-// - Full compatibility with slog ecosystem
-// - Access to With() for adding context
-// - No need for custom interface definitions
-//
-// Components should accept log.Logger as a dependency.
-type Logger = *slog.Logger
-
-// Config defines logger configuration options.
-type Config struct {
- // Level sets the minimum log level. Default: slog.LevelInfo
- Level slog.Level
-
- // JSON enables JSON format output. Default: false (text format)
- JSON bool
-
- // AddSource adds source file information to log entries. Default: false
- AddSource bool
-}
-
-// New creates a new logger with the given configuration.
-// Output is written to os.Stderr by default.
-//
-// Example:
-//
-// logger := log.New(log.Config{
-// Level: slog.LevelDebug,
-// JSON: true,
-// })
-func New(cfg Config) Logger {
- return NewWithWriter(os.Stderr, cfg)
-}
-
-// NewWithWriter creates a new logger that writes to the specified writer.
-// Useful for testing or custom output destinations.
-//
-// Example:
-//
-// var buf bytes.Buffer
-// logger := log.NewWithWriter(&buf, log.Config{})
-// // ... use logger
-// fmt.Println(buf.String()) // inspect log output
-func NewWithWriter(w io.Writer, cfg Config) Logger {
- opts := &slog.HandlerOptions{
- Level: cfg.Level,
- AddSource: cfg.AddSource,
- }
-
- var handler slog.Handler
- if cfg.JSON {
- handler = slog.NewJSONHandler(w, opts)
- } else {
- handler = slog.NewTextHandler(w, opts)
- }
-
- return slog.New(handler)
-}
-
-// NewNop creates a logger that discards all output.
-//
-// WARNING: This should ONLY be used in tests. Never use NewNop() in production
-// code as it will silently discard all logs, making debugging impossible.
-// Production code should always use New() or NewWithWriter() with proper configuration.
-//
-// Example:
-//
-// func TestSomething(t *testing.T) {
-// logger := log.NewNop()
-// sut := NewMyComponent(logger)
-// // ... test without log noise
-// }
-func NewNop() Logger {
- return slog.New(slog.DiscardHandler)
-}
diff --git a/internal/log/log_test.go b/internal/log/log_test.go
deleted file mode 100644
index 0a79fd6..0000000
--- a/internal/log/log_test.go
+++ /dev/null
@@ -1,128 +0,0 @@
-package log
-
-import (
- "bytes"
- "log/slog"
- "strings"
- "testing"
-)
-
-func TestNew(t *testing.T) {
- logger := New(Config{})
- if logger == nil {
- t.Fatal("New() returned nil")
- }
-}
-
-func TestNewWithWriter(t *testing.T) {
- var buf bytes.Buffer
-
- logger := NewWithWriter(&buf, Config{
- Level: slog.LevelDebug,
- })
-
- logger.Info("test message", "key", "value")
-
- output := buf.String()
- if !strings.Contains(output, "test message") {
- t.Errorf("expected output to contain 'test message', got: %s", output)
- }
- if !strings.Contains(output, "key=value") {
- t.Errorf("expected output to contain 'key=value', got: %s", output)
- }
-}
-
-func TestNewWithWriter_JSON(t *testing.T) {
- var buf bytes.Buffer
-
- logger := NewWithWriter(&buf, Config{
- Level: slog.LevelInfo,
- JSON: true,
- })
-
- logger.Info("json test", "foo", "bar")
-
- output := buf.String()
- if !strings.Contains(output, `"msg":"json test"`) {
- t.Errorf("expected JSON output with msg field, got: %s", output)
- }
-}
-
-func TestNewNop(t *testing.T) {
- logger := NewNop()
- if logger == nil {
- t.Fatal("NewNop() returned nil")
- }
-
- // Should not panic
- logger.Info("this should be discarded")
- logger.Error("this too")
-}
-
-func TestLogger_With(t *testing.T) {
- var buf bytes.Buffer
-
- logger := NewWithWriter(&buf, Config{
- Level: slog.LevelInfo,
- })
-
- // Add component context
- componentLogger := logger.With("component", "test")
- componentLogger.Info("component log")
-
- output := buf.String()
- if !strings.Contains(output, "component=test") {
- t.Errorf("expected output to contain 'component=test', got: %s", output)
- }
-}
-
-func TestLogger_Levels(t *testing.T) {
- var buf bytes.Buffer
-
- logger := NewWithWriter(&buf, Config{
- Level: slog.LevelDebug,
- })
-
- logger.Debug("debug msg")
- logger.Info("info msg")
- logger.Warn("warn msg")
- logger.Error("error msg")
-
- output := buf.String()
-
- levels := []string{"DEBUG", "INFO", "WARN", "ERROR"}
- for _, level := range levels {
- if !strings.Contains(output, level) {
- t.Errorf("expected output to contain %s level", level)
- }
- }
-}
-
-func TestLogger_LevelFiltering(t *testing.T) {
- var buf bytes.Buffer
-
- // Only INFO and above
- logger := NewWithWriter(&buf, Config{
- Level: slog.LevelInfo,
- })
-
- logger.Debug("debug should not appear")
- logger.Info("info should appear")
-
- output := buf.String()
-
- if strings.Contains(output, "debug should not appear") {
- t.Error("DEBUG message should be filtered out")
- }
- if !strings.Contains(output, "info should appear") {
- t.Error("INFO message should appear")
- }
-}
-
-func TestLoggerTypeAlias(t *testing.T) {
- // Verify that Logger is compatible with *slog.Logger
- logger := slog.Default()
- if logger == nil {
- t.Fatal("Logger type alias should be compatible with *slog.Logger")
- }
-}
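
With the wrapper package and its tests removed, call sites use log/slog from the standard library directly; the DiscardHandler pattern that replaces NewNop appears later in this diff (server_test.go). A minimal sketch of the direct equivalents, with the exact call sites assumed:

    package main

    import (
        "log/slog"
        "os"
    )

    func main() {
        // Roughly what the deleted log.New(Config{Level: slog.LevelDebug}) returned.
        logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
            Level: slog.LevelDebug,
        }))
        logger.With("component", "file").Info("starting")

        // Test-side replacement for the deleted log.NewNop() (Go 1.24+).
        nop := slog.New(slog.DiscardHandler)
        nop.Info("silently dropped")
    }
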
diff --git a/internal/mcp/benchmark_test.go b/internal/mcp/benchmark_test.go
index afec82a..cf0791d 100644
--- a/internal/mcp/benchmark_test.go
+++ b/internal/mcp/benchmark_test.go
@@ -16,46 +16,48 @@ import (
// BenchmarkServer_Creation benchmarks MCP server creation.
// Run with: go test -bench=BenchmarkServer_Creation -benchmem ./internal/mcp/...
func BenchmarkServer_Creation(b *testing.B) {
- for b.Loop() {
- tmpDir := b.TempDir()
- pathVal, err := security.NewPath([]string{tmpDir})
- if err != nil {
- b.Fatalf("Failed to create path validator: %v", err)
- }
+ // Set up toolsets once; benchmark only NewServer + tool registration.
+ tmpDir := b.TempDir()
+ pathVal, err := security.NewPath([]string{tmpDir})
+ if err != nil {
+ b.Fatalf("creating path validator: %v", err)
+ }
- fileTools, err := tools.NewFileTools(pathVal, slog.Default())
- if err != nil {
- b.Fatalf("Failed to create file tools: %v", err)
- }
+ fileTools, err := tools.NewFile(pathVal, slog.Default())
+ if err != nil {
+ b.Fatalf("creating file tools: %v", err)
+ }
- cmdVal := security.NewCommand()
- envVal := security.NewEnv()
- systemTools, err := tools.NewSystemTools(cmdVal, envVal, slog.Default())
- if err != nil {
- b.Fatalf("Failed to create system tools: %v", err)
- }
+ cmdVal := security.NewCommand()
+ envVal := security.NewEnv()
+ systemTools, err := tools.NewSystem(cmdVal, envVal, slog.Default())
+ if err != nil {
+ b.Fatalf("creating system tools: %v", err)
+ }
- networkTools, err := tools.NewNetworkTools(tools.NetworkConfig{
- SearchBaseURL: "http://localhost:8080",
- FetchParallelism: 2,
- FetchDelay: 100 * time.Millisecond,
- FetchTimeout: 30 * time.Second,
- }, slog.Default())
- if err != nil {
- b.Fatalf("Failed to create network tools: %v", err)
- }
+ networkTools, err := tools.NewNetwork(tools.NetConfig{
+ SearchBaseURL: "http://localhost:8080",
+ FetchParallelism: 2,
+ FetchDelay: 100 * time.Millisecond,
+ FetchTimeout: 30 * time.Second,
+ }, slog.Default())
+ if err != nil {
+ b.Fatalf("creating network tools: %v", err)
+ }
- cfg := Config{
- Name: "benchmark-server",
- Version: "1.0.0",
- FileTools: fileTools,
- SystemTools: systemTools,
- NetworkTools: networkTools,
- }
+ cfg := Config{
+ Name: "benchmark-server",
+ Version: "1.0.0",
+ File: fileTools,
+ System: systemTools,
+ Network: networkTools,
+ }
+ b.ResetTimer()
+ for b.Loop() {
_, err = NewServer(cfg)
if err != nil {
- b.Fatalf("NewServer failed: %v", err)
+ b.Fatalf("NewServer(): %v", err)
}
}
}
@@ -70,7 +72,7 @@ func BenchmarkJSONRPC_Parse(b *testing.B) {
for b.Loop() {
var request map[string]any
if err := json.Unmarshal([]byte(requestJSON), &request); err != nil {
- b.Fatalf("JSON unmarshal failed: %v", err)
+ b.Fatalf("unmarshaling JSON: %v", err)
}
}
}
@@ -96,7 +98,7 @@ func BenchmarkJSONRPC_Parse_LargePayload(b *testing.B) {
for b.Loop() {
var parsed map[string]any
if err := json.Unmarshal(responseJSON, &parsed); err != nil {
- b.Fatalf("JSON unmarshal failed: %v", err)
+ b.Fatalf("unmarshaling JSON: %v", err)
}
}
}
@@ -120,7 +122,7 @@ func BenchmarkJSONRPC_Serialize(b *testing.B) {
for b.Loop() {
_, err := json.Marshal(response)
if err != nil {
- b.Fatalf("JSON marshal failed: %v", err)
+ b.Fatalf("marshaling JSON: %v", err)
}
}
}
@@ -144,7 +146,7 @@ func BenchmarkJSONRPC_Serialize_LargePayload(b *testing.B) {
for b.Loop() {
_, err := json.Marshal(response)
if err != nil {
- b.Fatalf("JSON marshal failed: %v", err)
+ b.Fatalf("marshaling JSON: %v", err)
}
}
}
@@ -157,7 +159,7 @@ func BenchmarkReadFileInput_Parse(b *testing.B) {
for b.Loop() {
var input tools.ReadFileInput
if err := json.Unmarshal([]byte(inputJSON), &input); err != nil {
- b.Fatalf("JSON unmarshal failed: %v", err)
+ b.Fatalf("unmarshaling JSON: %v", err)
}
}
}
@@ -167,37 +169,37 @@ func BenchmarkConfig_Validation(b *testing.B) {
tmpDir := b.TempDir()
pathVal, err := security.NewPath([]string{tmpDir})
if err != nil {
- b.Fatalf("Failed to create path validator: %v", err)
+ b.Fatalf("creating path validator: %v", err)
}
- fileTools, err := tools.NewFileTools(pathVal, slog.Default())
+ fileTools, err := tools.NewFile(pathVal, slog.Default())
if err != nil {
- b.Fatalf("Failed to create file tools: %v", err)
+ b.Fatalf("creating file tools: %v", err)
}
cmdVal := security.NewCommand()
envVal := security.NewEnv()
- systemTools, err := tools.NewSystemTools(cmdVal, envVal, slog.Default())
+ systemTools, err := tools.NewSystem(cmdVal, envVal, slog.Default())
if err != nil {
- b.Fatalf("Failed to create system tools: %v", err)
+ b.Fatalf("creating system tools: %v", err)
}
- networkTools, err := tools.NewNetworkTools(tools.NetworkConfig{
+ networkTools, err := tools.NewNetwork(tools.NetConfig{
SearchBaseURL: "http://localhost:8080",
FetchParallelism: 2,
FetchDelay: 100 * time.Millisecond,
FetchTimeout: 30 * time.Second,
}, slog.Default())
if err != nil {
- b.Fatalf("Failed to create network tools: %v", err)
+ b.Fatalf("creating network tools: %v", err)
}
cfg := Config{
- Name: "validation-test",
- Version: "1.0.0",
- FileTools: fileTools,
- SystemTools: systemTools,
- NetworkTools: networkTools,
+ Name: "validation-test",
+ Version: "1.0.0",
+ File: fileTools,
+ System: systemTools,
+ Network: networkTools,
}
b.ResetTimer()
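
The benchmarks above were restructured so toolset construction happens once, outside the measured loop, with b.Loop (Go 1.24) driving iterations. A stripped-down sketch of that shape; the setup helper is hypothetical, while NewServer and Config are from this package:

    package mcp

    import "testing"

    func BenchmarkServerCreationPattern(b *testing.B) {
        cfg := benchmarkConfig(b) // hypothetical helper that builds a valid Config once
        b.ResetTimer()            // b.Loop excludes pre-loop setup anyway; ResetTimer is belt-and-braces
        for b.Loop() {
            if _, err := NewServer(cfg); err != nil {
                b.Fatalf("NewServer(): %v", err)
            }
        }
    }
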
diff --git a/internal/mcp/doc.go b/internal/mcp/doc.go
index 5534034..7247cc8 100644
--- a/internal/mcp/doc.go
+++ b/internal/mcp/doc.go
@@ -21,7 +21,7 @@
// +-- Handler Methods (ReadFile, WriteFile, ExecuteCommand, ...)
// |
// v
-// Toolsets (FileTools, SystemTools, NetworkTools, KnowledgeTools)
+// Toolsets (File, System, Network, Knowledge)
// |
// v
// Execution Results
@@ -34,7 +34,7 @@
//
// Network tools (2): web_search, web_fetch
//
-// Knowledge tools (3, optional): search_history, search_documents, search_system_knowledge
+// Knowledge tools (3-4, optional): search_history, search_documents, search_system_knowledge, knowledge_store
//
// # Tool Handler Pattern
//
diff --git a/internal/mcp/file.go b/internal/mcp/file.go
index b13024b..a4b35bc 100644
--- a/internal/mcp/file.go
+++ b/internal/mcp/file.go
@@ -10,61 +10,73 @@ import (
"github.com/modelcontextprotocol/go-sdk/mcp"
)
-// registerFileTools registers all file operation tools to the MCP server.
+// registerFile registers all file operation tools to the MCP server.
// Tools: read_file, write_file, list_files, delete_file, get_file_info
-func (s *Server) registerFileTools() error {
+func (s *Server) registerFile() error {
// read_file
readFileSchema, err := jsonschema.For[tools.ReadFileInput](nil)
if err != nil {
- return fmt.Errorf("schema for %s: %w", tools.ToolReadFile, err)
+ return fmt.Errorf("schema for %s: %w", tools.ReadFileName, err)
}
mcp.AddTool(s.mcpServer, &mcp.Tool{
- Name: tools.ToolReadFile,
- Description: "Read the complete content of any text-based file.",
+ Name: tools.ReadFileName,
+ Description: "Read the complete content of a text-based file. " +
+ "Use this to examine source code, configuration files, logs, or documentation. " +
+ "Supports files up to 10MB. Binary files are not supported and will return an error. " +
+ "Returns: file path, content (UTF-8), size in bytes, and line count.",
InputSchema: readFileSchema,
}, s.ReadFile)
// write_file
writeFileSchema, err := jsonschema.For[tools.WriteFileInput](nil)
if err != nil {
- return fmt.Errorf("schema for %s: %w", tools.ToolWriteFile, err)
+ return fmt.Errorf("schema for %s: %w", tools.WriteFileName, err)
}
mcp.AddTool(s.mcpServer, &mcp.Tool{
- Name: tools.ToolWriteFile,
- Description: "Write or create any text-based file.",
+ Name: tools.WriteFileName,
+ Description: "Write or create a text-based file with the specified content. " +
+ "Creates parent directories automatically if they don't exist. " +
+ "Overwrites existing files without confirmation. " +
+ "Returns: file path, bytes written, whether file was created or updated.",
InputSchema: writeFileSchema,
}, s.WriteFile)
// list_files
listFilesSchema, err := jsonschema.For[tools.ListFilesInput](nil)
if err != nil {
- return fmt.Errorf("schema for %s: %w", tools.ToolListFiles, err)
+ return fmt.Errorf("schema for %s: %w", tools.ListFilesName, err)
}
mcp.AddTool(s.mcpServer, &mcp.Tool{
- Name: tools.ToolListFiles,
- Description: "List all files and subdirectories in a directory.",
+ Name: tools.ListFilesName,
+ Description: "List files and subdirectories in a directory. " +
+ "Returns file names, sizes, types (file/directory), and modification times. " +
+ "Does not recurse into subdirectories.",
InputSchema: listFilesSchema,
}, s.ListFiles)
// delete_file
deleteFileSchema, err := jsonschema.For[tools.DeleteFileInput](nil)
if err != nil {
- return fmt.Errorf("schema for %s: %w", tools.ToolDeleteFile, err)
+ return fmt.Errorf("schema for %s: %w", tools.DeleteFileName, err)
}
mcp.AddTool(s.mcpServer, &mcp.Tool{
- Name: tools.ToolDeleteFile,
- Description: "Delete a file permanently.",
+ Name: tools.DeleteFileName,
+ Description: "Permanently delete a file or empty directory. " +
+ "WARNING: This action cannot be undone. " +
+ "Only deletes empty directories.",
InputSchema: deleteFileSchema,
}, s.DeleteFile)
// get_file_info
getFileInfoSchema, err := jsonschema.For[tools.GetFileInfoInput](nil)
if err != nil {
- return fmt.Errorf("schema for %s: %w", tools.ToolGetFileInfo, err)
+ return fmt.Errorf("schema for %s: %w", tools.FileInfoName, err)
}
mcp.AddTool(s.mcpServer, &mcp.Tool{
- Name: tools.ToolGetFileInfo,
- Description: "Get detailed metadata about a file.",
+ Name: tools.FileInfoName,
+ Description: "Get detailed metadata about a file without reading its contents. " +
+ "Returns: file size, modification time, permissions, and type (file/directory). " +
+ "More efficient than read_file when you only need metadata.",
InputSchema: getFileInfoSchema,
}, s.GetFileInfo)
@@ -74,54 +86,54 @@ func (s *Server) registerFileTools() error {
// ReadFile handles the readFile MCP tool call.
func (s *Server) ReadFile(ctx context.Context, _ *mcp.CallToolRequest, input tools.ReadFileInput) (*mcp.CallToolResult, any, error) {
toolCtx := &ai.ToolContext{Context: ctx}
- result, err := s.fileTools.ReadFile(toolCtx, input)
+ result, err := s.file.ReadFile(toolCtx, input)
if err != nil {
- return nil, nil, fmt.Errorf("readFile failed: %w", err)
+ return nil, nil, fmt.Errorf("reading file: %w", err)
}
- return resultToMCP(result), nil, nil
+ return resultToMCP(result, s.logger), nil, nil
}
// WriteFile handles the writeFile MCP tool call.
func (s *Server) WriteFile(ctx context.Context, _ *mcp.CallToolRequest, input tools.WriteFileInput) (*mcp.CallToolResult, any, error) {
toolCtx := &ai.ToolContext{Context: ctx}
- result, err := s.fileTools.WriteFile(toolCtx, input)
+ result, err := s.file.WriteFile(toolCtx, input)
if err != nil {
- return nil, nil, fmt.Errorf("writeFile failed: %w", err)
+ return nil, nil, fmt.Errorf("writing file: %w", err)
}
- return resultToMCP(result), nil, nil
+ return resultToMCP(result, s.logger), nil, nil
}
// ListFiles handles the listFiles MCP tool call.
func (s *Server) ListFiles(ctx context.Context, _ *mcp.CallToolRequest, input tools.ListFilesInput) (*mcp.CallToolResult, any, error) {
toolCtx := &ai.ToolContext{Context: ctx}
- result, err := s.fileTools.ListFiles(toolCtx, input)
+ result, err := s.file.ListFiles(toolCtx, input)
if err != nil {
- return nil, nil, fmt.Errorf("listFiles failed: %w", err)
+ return nil, nil, fmt.Errorf("listing files: %w", err)
}
- return resultToMCP(result), nil, nil
+ return resultToMCP(result, s.logger), nil, nil
}
// DeleteFile handles the deleteFile MCP tool call.
func (s *Server) DeleteFile(ctx context.Context, _ *mcp.CallToolRequest, input tools.DeleteFileInput) (*mcp.CallToolResult, any, error) {
toolCtx := &ai.ToolContext{Context: ctx}
- result, err := s.fileTools.DeleteFile(toolCtx, input)
+ result, err := s.file.DeleteFile(toolCtx, input)
if err != nil {
- return nil, nil, fmt.Errorf("deleteFile failed: %w", err)
+ return nil, nil, fmt.Errorf("deleting file: %w", err)
}
- return resultToMCP(result), nil, nil
+ return resultToMCP(result, s.logger), nil, nil
}
// GetFileInfo handles the getFileInfo MCP tool call.
func (s *Server) GetFileInfo(ctx context.Context, _ *mcp.CallToolRequest, input tools.GetFileInfoInput) (*mcp.CallToolResult, any, error) {
toolCtx := &ai.ToolContext{Context: ctx}
- result, err := s.fileTools.GetFileInfo(toolCtx, input)
+ result, err := s.file.GetFileInfo(toolCtx, input)
if err != nil {
- return nil, nil, fmt.Errorf("getFileInfo failed: %w", err)
+ return nil, nil, fmt.Errorf("getting file info: %w", err)
}
- return resultToMCP(result), nil, nil
+ return resultToMCP(result, s.logger), nil, nil
}
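
Every handler in this file follows the same three-step shape: wrap the context in ai.ToolContext, delegate to the toolset, and convert the result with resultToMCP plus the server logger. A hypothetical extra handler to make the pattern explicit; CopyFile, tools.CopyFileInput, and s.file.CopyFile are invented for illustration and do not exist in this diff:

    // CopyFile sketches the handler pattern used above (illustrative only).
    func (s *Server) CopyFile(ctx context.Context, _ *mcp.CallToolRequest, input tools.CopyFileInput) (*mcp.CallToolResult, any, error) {
        toolCtx := &ai.ToolContext{Context: ctx}
        result, err := s.file.CopyFile(toolCtx, input)
        if err != nil {
            return nil, nil, fmt.Errorf("copying file: %w", err)
        }
        return resultToMCP(result, s.logger), nil, nil
    }
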
diff --git a/internal/mcp/file_test.go b/internal/mcp/file_test.go
index deff7f6..c8b4efc 100644
--- a/internal/mcp/file_test.go
+++ b/internal/mcp/file_test.go
@@ -16,14 +16,14 @@ func TestReadFile_Success(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
// Create test file
testFile := filepath.Join(h.tempDir, "test.txt")
testContent := "hello world"
if err := os.WriteFile(testFile, []byte(testContent), 0o600); err != nil {
- t.Fatalf("failed to create test file: %v", err)
+ t.Fatalf("creating test file: %v", err)
}
// Call ReadFile handler
@@ -32,7 +32,7 @@ func TestReadFile_Success(t *testing.T) {
})
if err != nil {
- t.Fatalf("ReadFile failed: %v", err)
+ t.Fatalf("ReadFile(): %v", err)
}
if result.IsError {
@@ -51,7 +51,7 @@ func TestReadFile_FileNotFound(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
// Call ReadFile with non-existent file
@@ -74,7 +74,7 @@ func TestReadFile_SecurityViolation(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
// Try to read file outside allowed directory
@@ -97,7 +97,7 @@ func TestWriteFile_Success(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
testFile := filepath.Join(h.tempDir, "write_test.txt")
@@ -109,7 +109,7 @@ func TestWriteFile_Success(t *testing.T) {
})
if err != nil {
- t.Fatalf("WriteFile failed: %v", err)
+ t.Fatalf("WriteFile(): %v", err)
}
if result.IsError {
@@ -119,7 +119,7 @@ func TestWriteFile_Success(t *testing.T) {
// Verify file was created
content, err := os.ReadFile(testFile)
if err != nil {
- t.Fatalf("failed to read written file: %v", err)
+ t.Fatalf("reading written file: %v", err)
}
if string(content) != testContent {
@@ -133,7 +133,7 @@ func TestWriteFile_CreatesDirectory(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
// Write to a file in a non-existent subdirectory
@@ -145,7 +145,7 @@ func TestWriteFile_CreatesDirectory(t *testing.T) {
})
if err != nil {
- t.Fatalf("WriteFile failed: %v", err)
+ t.Fatalf("WriteFile(): %v", err)
}
if result.IsError {
@@ -164,19 +164,19 @@ func TestListFiles_Success(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
// Create some test files
for _, name := range []string{"file1.txt", "file2.txt"} {
if err := os.WriteFile(filepath.Join(h.tempDir, name), []byte("test"), 0o600); err != nil {
- t.Fatalf("failed to create test file: %v", err)
+ t.Fatalf("creating test file: %v", err)
}
}
// Create a subdirectory
if err := os.Mkdir(filepath.Join(h.tempDir, "subdir"), 0o750); err != nil {
- t.Fatalf("failed to create subdirectory: %v", err)
+ t.Fatalf("creating subdirectory: %v", err)
}
result, _, err := server.ListFiles(context.Background(), &mcp.CallToolRequest{}, tools.ListFilesInput{
@@ -184,7 +184,7 @@ func TestListFiles_Success(t *testing.T) {
})
if err != nil {
- t.Fatalf("ListFiles failed: %v", err)
+ t.Fatalf("ListFiles(): %v", err)
}
if result.IsError {
@@ -198,7 +198,7 @@ func TestListFiles_DirectoryNotFound(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
result, _, err := server.ListFiles(context.Background(), &mcp.CallToolRequest{}, tools.ListFilesInput{
@@ -220,13 +220,13 @@ func TestDeleteFile_Success(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
// Create test file
testFile := filepath.Join(h.tempDir, "to_delete.txt")
if err := os.WriteFile(testFile, []byte("delete me"), 0o600); err != nil {
- t.Fatalf("failed to create test file: %v", err)
+ t.Fatalf("creating test file: %v", err)
}
result, _, err := server.DeleteFile(context.Background(), &mcp.CallToolRequest{}, tools.DeleteFileInput{
@@ -234,7 +234,7 @@ func TestDeleteFile_Success(t *testing.T) {
})
if err != nil {
- t.Fatalf("DeleteFile failed: %v", err)
+ t.Fatalf("DeleteFile(): %v", err)
}
if result.IsError {
@@ -253,7 +253,7 @@ func TestDeleteFile_FileNotFound(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
result, _, err := server.DeleteFile(context.Background(), &mcp.CallToolRequest{}, tools.DeleteFileInput{
@@ -275,14 +275,14 @@ func TestGetFileInfo_Success(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
// Create test file
testFile := filepath.Join(h.tempDir, "info_test.txt")
testContent := "test content for info"
if err := os.WriteFile(testFile, []byte(testContent), 0o600); err != nil {
- t.Fatalf("failed to create test file: %v", err)
+ t.Fatalf("creating test file: %v", err)
}
result, _, err := server.GetFileInfo(context.Background(), &mcp.CallToolRequest{}, tools.GetFileInfoInput{
@@ -290,7 +290,7 @@ func TestGetFileInfo_Success(t *testing.T) {
})
if err != nil {
- t.Fatalf("GetFileInfo failed: %v", err)
+ t.Fatalf("GetFileInfo(): %v", err)
}
if result.IsError {
@@ -304,7 +304,7 @@ func TestGetFileInfo_FileNotFound(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
result, _, err := server.GetFileInfo(context.Background(), &mcp.CallToolRequest{}, tools.GetFileInfoInput{
diff --git a/internal/mcp/integration_test.go b/internal/mcp/integration_test.go
index 1e0a42b..3d64418 100644
--- a/internal/mcp/integration_test.go
+++ b/internal/mcp/integration_test.go
@@ -31,35 +31,35 @@ func createIntegrationTestConfig(t *testing.T, name string) Config {
t.Fatalf("security.NewPath(%q) unexpected error: %v", realTmpDir, err)
}
- fileTools, err := tools.NewFileTools(pathVal, slog.Default())
+ file, err := tools.NewFile(pathVal, slog.Default())
if err != nil {
- t.Fatalf("tools.NewFileTools() unexpected error: %v", err)
+ t.Fatalf("tools.NewFile() unexpected error: %v", err)
}
cmdVal := security.NewCommand()
envVal := security.NewEnv()
- systemTools, err := tools.NewSystemTools(cmdVal, envVal, slog.Default())
+ system, err := tools.NewSystem(cmdVal, envVal, slog.Default())
if err != nil {
- t.Fatalf("tools.NewSystemTools() unexpected error: %v", err)
+ t.Fatalf("tools.NewSystem() unexpected error: %v", err)
}
- networkCfg := tools.NetworkConfig{
+ networkCfg := tools.NetConfig{
SearchBaseURL: "http://localhost:8080",
FetchParallelism: 2,
FetchDelay: 100 * time.Millisecond,
FetchTimeout: 30 * time.Second,
}
- networkTools, err := tools.NewNetworkTools(networkCfg, slog.Default())
+ network, err := tools.NewNetwork(networkCfg, slog.Default())
if err != nil {
- t.Fatalf("tools.NewNetworkTools() unexpected error: %v", err)
+ t.Fatalf("tools.NewNetwork() unexpected error: %v", err)
}
return Config{
- Name: name,
- Version: "1.0.0",
- FileTools: fileTools,
- SystemTools: systemTools,
- NetworkTools: networkTools,
+ Name: name,
+ Version: "1.0.0",
+ File: file,
+ System: system,
+ Network: network,
}
}
@@ -69,24 +69,29 @@ func createIntegrationTestConfig(t *testing.T, name string) Config {
// Run with: go test -race ./internal/mcp/...
func TestServer_ConcurrentCreation(t *testing.T) {
const numGoroutines = 10
+
+ // Pre-create configs outside goroutines — createIntegrationTestConfig
+ // calls t.Fatalf, which must only be called from the test goroutine.
+ configs := make([]Config, numGoroutines)
+ for i := range configs {
+ configs[i] = createIntegrationTestConfig(t, "race-test-server")
+ }
+
var wg sync.WaitGroup
errors := make(chan error, numGoroutines)
servers := make(chan *Server, numGoroutines)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
- go func(id int) {
+ go func(cfg Config) {
defer wg.Done()
-
- cfg := createIntegrationTestConfig(t, "race-test-server")
-
server, err := NewServer(cfg)
if err != nil {
errors <- err
return
}
servers <- server
- }(i)
+ }(configs[i])
}
wg.Wait()
@@ -143,11 +148,9 @@ func TestServer_ConcurrentToolsetAccess(t *testing.T) {
go func() {
defer wg.Done()
// Access toolset fields (read-only)
- _ = server.fileTools
- _ = server.systemTools
- _ = server.networkTools
- _ = server.name
- _ = server.version
+ _ = server.file
+ _ = server.system
+ _ = server.network
}()
}
@@ -171,36 +174,24 @@ func TestServer_RaceDetector(t *testing.T) {
// Concurrent field access (read-only operations)
for i := 0; i < numOps; i++ {
- wg.Add(5)
-
- // Read name
- go func() {
- defer wg.Done()
- _ = server.name
- }()
-
- // Read version
- go func() {
- defer wg.Done()
- _ = server.version
- }()
+ wg.Add(3)
// Read file tools
go func() {
defer wg.Done()
- _ = server.fileTools
+ _ = server.file
}()
// Read system tools
go func() {
defer wg.Done()
- _ = server.systemTools
+ _ = server.system
}()
// Read network tools
go func() {
defer wg.Done()
- _ = server.networkTools
+ _ = server.network
}()
}
@@ -215,14 +206,14 @@ func TestConfig_ConcurrentValidation(t *testing.T) {
validCfg := createIntegrationTestConfig(t, "valid")
configs := []Config{
- {Name: "server1", Version: "1.0.0", FileTools: validCfg.FileTools, SystemTools: validCfg.SystemTools, NetworkTools: validCfg.NetworkTools},
- {Name: "server2", Version: "2.0.0", FileTools: validCfg.FileTools, SystemTools: validCfg.SystemTools, NetworkTools: validCfg.NetworkTools},
- {Name: "", Version: "1.0.0", FileTools: validCfg.FileTools, SystemTools: validCfg.SystemTools, NetworkTools: validCfg.NetworkTools}, // Invalid: no name
- {Name: "server3", Version: "", FileTools: validCfg.FileTools, SystemTools: validCfg.SystemTools, NetworkTools: validCfg.NetworkTools}, // Invalid: no version
- {Name: "server4", Version: "1.0.0", FileTools: nil, SystemTools: validCfg.SystemTools, NetworkTools: validCfg.NetworkTools}, // Invalid: no file tools
- {Name: "server5", Version: "1.0.0", FileTools: validCfg.FileTools, SystemTools: nil, NetworkTools: validCfg.NetworkTools}, // Invalid: no system tools
- {Name: "server6", Version: "1.0.0", FileTools: validCfg.FileTools, SystemTools: validCfg.SystemTools, NetworkTools: nil}, // Invalid: no network tools
- {Name: "server7", Version: "3.0.0", FileTools: validCfg.FileTools, SystemTools: validCfg.SystemTools, NetworkTools: validCfg.NetworkTools},
+ {Name: "server1", Version: "1.0.0", File: validCfg.File, System: validCfg.System, Network: validCfg.Network},
+ {Name: "server2", Version: "2.0.0", File: validCfg.File, System: validCfg.System, Network: validCfg.Network},
+ {Name: "", Version: "1.0.0", File: validCfg.File, System: validCfg.System, Network: validCfg.Network}, // Invalid: no name
+ {Name: "server3", Version: "", File: validCfg.File, System: validCfg.System, Network: validCfg.Network}, // Invalid: no version
+ {Name: "server4", Version: "1.0.0", File: nil, System: validCfg.System, Network: validCfg.Network}, // Invalid: no file tools
+ {Name: "server5", Version: "1.0.0", File: validCfg.File, System: nil, Network: validCfg.Network}, // Invalid: no system tools
+ {Name: "server6", Version: "1.0.0", File: validCfg.File, System: validCfg.System, Network: nil}, // Invalid: no network tools
+ {Name: "server7", Version: "3.0.0", File: validCfg.File, System: validCfg.System, Network: validCfg.Network},
}
var wg sync.WaitGroup
diff --git a/internal/mcp/knowledge.go b/internal/mcp/knowledge.go
index 853f7c6..f1e1428 100644
--- a/internal/mcp/knowledge.go
+++ b/internal/mcp/knowledge.go
@@ -10,90 +10,104 @@ import (
"github.com/modelcontextprotocol/go-sdk/mcp"
)
-// registerKnowledgeTools registers all knowledge tools to the MCP server.
+// registerKnowledge registers all knowledge tools to the MCP server.
// Tools: search_history, search_documents, search_system_knowledge, knowledge_store
-func (s *Server) registerKnowledgeTools() error {
+func (s *Server) registerKnowledge() error {
searchSchema, err := jsonschema.For[tools.KnowledgeSearchInput](nil)
if err != nil {
return fmt.Errorf("schema for knowledge search tools: %w", err)
}
mcp.AddTool(s.mcpServer, &mcp.Tool{
- Name: tools.ToolSearchHistory,
+ Name: tools.SearchHistoryName,
Description: "Search conversation history using semantic similarity. " +
- "Finds past exchanges related to the query.",
+ "Finds past exchanges that are conceptually related to the query. " +
+ "Returns: matched conversation turns with timestamps and similarity scores. " +
+ "Use this to: recall past discussions, find context from earlier conversations. " +
+ "Default topK: 3. Maximum topK: 10.",
InputSchema: searchSchema,
}, s.SearchHistory)
mcp.AddTool(s.mcpServer, &mcp.Tool{
- Name: tools.ToolSearchDocuments,
+ Name: tools.SearchDocumentsName,
Description: "Search indexed documents (PDFs, code files, notes) using semantic similarity. " +
- "Finds document sections related to the query.",
+ "Finds document sections that are conceptually related to the query. " +
+ "Returns: document titles, content excerpts, and similarity scores. " +
+ "Use this to: find relevant documentation, locate code examples, research topics. " +
+ "Default topK: 5. Maximum topK: 10.",
InputSchema: searchSchema,
}, s.SearchDocuments)
mcp.AddTool(s.mcpServer, &mcp.Tool{
- Name: tools.ToolSearchSystemKnowledge,
+ Name: tools.SearchSystemKnowledgeName,
Description: "Search system knowledge base (tool usage, commands, patterns) using semantic similarity. " +
- "Finds internal system documentation and usage patterns.",
+ "Finds internal system documentation and usage patterns. " +
+ "Returns: knowledge entries with descriptions and examples. " +
+ "Use this to: understand tool capabilities, find command syntax, learn system patterns. " +
+ "Default topK: 3. Maximum topK: 10.",
InputSchema: searchSchema,
}, s.SearchSystemKnowledge)
- storeSchema, err := jsonschema.For[tools.KnowledgeStoreInput](nil)
- if err != nil {
- return fmt.Errorf("schema for knowledge store tool: %w", err)
+ // Register knowledge_store only when DocStore is available.
+ if s.knowledge.HasDocStore() {
+ storeSchema, err := jsonschema.For[tools.KnowledgeStoreInput](nil)
+ if err != nil {
+ return fmt.Errorf("schema for knowledge store tool: %w", err)
+ }
+
+ mcp.AddTool(s.mcpServer, &mcp.Tool{
+ Name: tools.StoreKnowledgeName,
+ Description: "Store a knowledge entry for later retrieval via search_documents. " +
+ "Use this to save important information, notes, or learnings " +
+ "that the user wants to remember across sessions. " +
+ "Each entry gets a unique ID and is indexed for semantic search.",
+ InputSchema: storeSchema,
+ }, s.StoreKnowledge)
}
- mcp.AddTool(s.mcpServer, &mcp.Tool{
- Name: tools.ToolStoreKnowledge,
- Description: "Store a knowledge entry for later retrieval via search_documents. " +
- "Saves important information, notes, or learnings across sessions.",
- InputSchema: storeSchema,
- }, s.StoreKnowledge)
-
return nil
}
// SearchHistory handles the search_history MCP tool call.
func (s *Server) SearchHistory(ctx context.Context, _ *mcp.CallToolRequest, input tools.KnowledgeSearchInput) (*mcp.CallToolResult, any, error) {
toolCtx := &ai.ToolContext{Context: ctx}
- result, err := s.knowledgeTools.SearchHistory(toolCtx, input)
+ result, err := s.knowledge.SearchHistory(toolCtx, input)
if err != nil {
- return nil, nil, fmt.Errorf("searchHistory failed: %w", err)
+ return nil, nil, fmt.Errorf("searching history: %w", err)
}
- return resultToMCP(result), nil, nil
+ return resultToMCP(result, s.logger), nil, nil
}
// SearchDocuments handles the search_documents MCP tool call.
func (s *Server) SearchDocuments(ctx context.Context, _ *mcp.CallToolRequest, input tools.KnowledgeSearchInput) (*mcp.CallToolResult, any, error) {
toolCtx := &ai.ToolContext{Context: ctx}
- result, err := s.knowledgeTools.SearchDocuments(toolCtx, input)
+ result, err := s.knowledge.SearchDocuments(toolCtx, input)
if err != nil {
- return nil, nil, fmt.Errorf("searchDocuments failed: %w", err)
+ return nil, nil, fmt.Errorf("searching documents: %w", err)
}
- return resultToMCP(result), nil, nil
+ return resultToMCP(result, s.logger), nil, nil
}
// SearchSystemKnowledge handles the search_system_knowledge MCP tool call.
func (s *Server) SearchSystemKnowledge(ctx context.Context, _ *mcp.CallToolRequest, input tools.KnowledgeSearchInput) (*mcp.CallToolResult, any, error) {
toolCtx := &ai.ToolContext{Context: ctx}
- result, err := s.knowledgeTools.SearchSystemKnowledge(toolCtx, input)
+ result, err := s.knowledge.SearchSystemKnowledge(toolCtx, input)
if err != nil {
- return nil, nil, fmt.Errorf("searchSystemKnowledge failed: %w", err)
+ return nil, nil, fmt.Errorf("searching system knowledge: %w", err)
}
- return resultToMCP(result), nil, nil
+ return resultToMCP(result, s.logger), nil, nil
}
// StoreKnowledge handles the knowledge_store MCP tool call.
func (s *Server) StoreKnowledge(ctx context.Context, _ *mcp.CallToolRequest, input tools.KnowledgeStoreInput) (*mcp.CallToolResult, any, error) {
toolCtx := &ai.ToolContext{Context: ctx}
- result, err := s.knowledgeTools.StoreKnowledge(toolCtx, input)
+ result, err := s.knowledge.StoreKnowledge(toolCtx, input)
if err != nil {
- return nil, nil, fmt.Errorf("storeKnowledge failed: %w", err)
+ return nil, nil, fmt.Errorf("storing knowledge: %w", err)
}
- return resultToMCP(result), nil, nil
+ return resultToMCP(result, s.logger), nil, nil
}
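
registerKnowledge now gates knowledge_store on s.knowledge.HasDocStore(), so deployments without a document store never advertise a tool they cannot serve. The method itself is not shown in this diff; a plausible shape, assuming a docStore field on tools.Knowledge (the test comments later in this diff refer to docStore=nil):

    // In package tools; assumed shape only, not part of this diff.
    func (k *Knowledge) HasDocStore() bool {
        return k.docStore != nil
    }
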
diff --git a/internal/mcp/network.go b/internal/mcp/network.go
index b475248..dcd8276 100644
--- a/internal/mcp/network.go
+++ b/internal/mcp/network.go
@@ -10,28 +10,34 @@ import (
"github.com/modelcontextprotocol/go-sdk/mcp"
)
-// registerNetworkTools registers all network operation tools to the MCP server.
+// registerNetwork registers all network operation tools to the MCP server.
// Tools: web_search, web_fetch
-func (s *Server) registerNetworkTools() error {
+func (s *Server) registerNetwork() error {
// web_search
searchSchema, err := jsonschema.For[tools.SearchInput](nil)
if err != nil {
- return fmt.Errorf("schema for %s: %w", tools.ToolWebSearch, err)
+ return fmt.Errorf("schema for %s: %w", tools.WebSearchName, err)
}
mcp.AddTool(s.mcpServer, &mcp.Tool{
- Name: tools.ToolWebSearch,
- Description: "Search the web for information. Returns relevant results with titles, URLs, and content snippets.",
+ Name: tools.WebSearchName,
+ Description: "Search the web for information. Returns relevant results with titles, URLs, and content snippets. " +
+ "Use this to find current information, news, or facts from the internet.",
InputSchema: searchSchema,
}, s.WebSearch)
// web_fetch
fetchSchema, err := jsonschema.For[tools.FetchInput](nil)
if err != nil {
- return fmt.Errorf("schema for %s: %w", tools.ToolWebFetch, err)
+ return fmt.Errorf("schema for %s: %w", tools.WebFetchName, err)
}
mcp.AddTool(s.mcpServer, &mcp.Tool{
- Name: tools.ToolWebFetch,
- Description: "Fetch and extract content from one or more URLs (max 10). Supports HTML, JSON, and plain text.",
+ Name: tools.WebFetchName,
+ Description: "Fetch and extract content from one or more URLs (max 10). " +
+ "Supports HTML pages, JSON APIs, and plain text. " +
+ "For HTML: uses Readability algorithm to extract main content. " +
+ "Supports parallel fetching with rate limiting. " +
+ "Returns extracted content (max 50KB per URL). " +
+ "Note: Does not render JavaScript - for SPA pages, content may be incomplete.",
InputSchema: fetchSchema,
}, s.WebFetch)
@@ -39,28 +45,22 @@ func (s *Server) registerNetworkTools() error {
}
// WebSearch handles the web_search MCP tool call.
-// Architecture: Direct method call (consistent with file.go and system.go).
func (s *Server) WebSearch(ctx context.Context, _ *mcp.CallToolRequest, input tools.SearchInput) (*mcp.CallToolResult, any, error) {
toolCtx := &ai.ToolContext{Context: ctx}
-
- // Direct method call - O(1), consistent with FileToolset.ReadFile() pattern
- result, err := s.networkTools.Search(toolCtx, input)
+ result, err := s.network.Search(toolCtx, input)
if err != nil {
- return nil, nil, fmt.Errorf("web_search failed: %w", err)
+ return nil, nil, fmt.Errorf("searching web: %w", err)
}
return dataToMCP(result), nil, nil
}
// WebFetch handles the web_fetch MCP tool call.
-// Architecture: Direct method call (consistent with file.go and system.go).
func (s *Server) WebFetch(ctx context.Context, _ *mcp.CallToolRequest, input tools.FetchInput) (*mcp.CallToolResult, any, error) {
toolCtx := &ai.ToolContext{Context: ctx}
-
- // Direct method call - O(1), consistent with FileToolset.ReadFile() pattern
- result, err := s.networkTools.Fetch(toolCtx, input)
+ result, err := s.network.Fetch(toolCtx, input)
if err != nil {
- return nil, nil, fmt.Errorf("web_fetch failed: %w", err)
+ return nil, nil, fmt.Errorf("fetching web: %w", err)
}
return dataToMCP(result), nil, nil
diff --git a/internal/mcp/protocol_test.go b/internal/mcp/protocol_test.go
index c03885e..667fcdd 100644
--- a/internal/mcp/protocol_test.go
+++ b/internal/mcp/protocol_test.go
@@ -10,15 +10,12 @@ import (
"github.com/modelcontextprotocol/go-sdk/mcp"
)
-// connectTestServer creates a Koopa MCP server and an SDK client connected
-// via in-memory transports. Returns the client session for making protocol calls.
-// The server session and client session are cleaned up via t.Cleanup.
-func connectTestServer(t *testing.T) *mcp.ClientSession {
+// connectServer creates a Koopa MCP server from the given config and an SDK
+// client connected via in-memory transports. Returns the client session for
+// making protocol calls. Both sessions are cleaned up via t.Cleanup.
+func connectServer(t *testing.T, cfg Config) *mcp.ClientSession {
t.Helper()
- h := newTestHelper(t)
- cfg := h.createValidConfig()
-
server, err := NewServer(cfg)
if err != nil {
t.Fatalf("NewServer() unexpected error: %v", err)
@@ -27,14 +24,12 @@ func connectTestServer(t *testing.T) *mcp.ClientSession {
ctx := context.Background()
serverTransport, clientTransport := mcp.NewInMemoryTransports()
- // Connect server side
serverSession, err := server.mcpServer.Connect(ctx, serverTransport, nil)
if err != nil {
t.Fatalf("server.Connect() unexpected error: %v", err)
}
t.Cleanup(func() { _ = serverSession.Close() })
- // Connect client side
client := mcp.NewClient(&mcp.Implementation{
Name: "test-client",
Version: "1.0.0",
@@ -49,6 +44,22 @@ func connectTestServer(t *testing.T) *mcp.ClientSession {
return clientSession
}
+// connectTestServer creates a Koopa MCP server without knowledge tools
+// and an SDK client connected via in-memory transports.
+func connectTestServer(t *testing.T) *mcp.ClientSession {
+ t.Helper()
+ h := newTestHelper(t)
+ return connectServer(t, h.createValidConfig())
+}
+
+// connectTestServerWithKnowledge creates a Koopa MCP server including
+// knowledge tools (backed by a mock retriever) and an SDK client.
+func connectTestServerWithKnowledge(t *testing.T) *mcp.ClientSession {
+ t.Helper()
+ h := newTestHelper(t)
+ return connectServer(t, h.createConfigWithKnowledge())
+}
+
// TestProtocol_ListTools verifies that the MCP JSON-RPC tools/list
// endpoint returns all registered tools with correct names.
func TestProtocol_ListTools(t *testing.T) {
@@ -137,7 +148,7 @@ func TestProtocol_CallTool_CurrentTime(t *testing.T) {
// Parse the JSON text result (contains mixed types: strings and numbers)
var timeResult map[string]any
if err := json.Unmarshal([]byte(textContent.Text), &timeResult); err != nil {
- t.Fatalf("CallTool(current_time) failed to parse JSON: %v\ntext: %s", err, textContent.Text)
+ t.Fatalf("CallTool(current_time) parsing JSON: %v\ntext: %s", err, textContent.Text)
}
// Should contain time fields
@@ -164,3 +175,127 @@ func TestProtocol_CallTool_UnknownTool(t *testing.T) {
t.Errorf("CallTool(nonexistent_tool) error = %q, want to contain tool name", err.Error())
}
}
+
+// TestProtocol_ListTools_WithKnowledge verifies that knowledge search tools
+// are registered when Knowledge is configured. knowledge_store is excluded
+// because the test helper creates Knowledge with docStore=nil.
+func TestProtocol_ListTools_WithKnowledge(t *testing.T) {
+ session := connectTestServerWithKnowledge(t)
+
+ result, err := session.ListTools(context.Background(), nil)
+ if err != nil {
+ t.Fatalf("ListTools() unexpected error: %v", err)
+ }
+
+ var names []string
+ for _, tool := range result.Tools {
+ names = append(names, tool.Name)
+ }
+ sort.Strings(names)
+
+ // 10 base + 3 knowledge search tools (knowledge_store excluded: docStore=nil)
+ wantNames := []string{
+ "current_time",
+ "delete_file",
+ "execute_command",
+ "get_env",
+ "get_file_info",
+ "list_files",
+ "read_file",
+ "search_documents",
+ "search_history",
+ "search_system_knowledge",
+ "web_fetch",
+ "web_search",
+ "write_file",
+ }
+
+ if len(names) != len(wantNames) {
+ t.Fatalf("ListTools() returned %d tools, want %d\ngot: %v\nwant: %v", len(names), len(wantNames), names, wantNames)
+ }
+
+ for i, got := range names {
+ if got != wantNames[i] {
+ t.Errorf("ListTools() tool[%d] = %q, want %q", i, got, wantNames[i])
+ }
+ }
+}
+
+// TestProtocol_CallTool_KnowledgeSearch verifies that each knowledge search
+// tool can be called through the MCP JSON-RPC layer and returns results.
+func TestProtocol_CallTool_KnowledgeSearch(t *testing.T) {
+ session := connectTestServerWithKnowledge(t)
+
+ tests := []struct {
+ name string
+ toolName string
+ }{
+ {name: "search_history", toolName: "search_history"},
+ {name: "search_documents", toolName: "search_documents"},
+ {name: "search_system_knowledge", toolName: "search_system_knowledge"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := session.CallTool(context.Background(), &mcp.CallToolParams{
+ Name: tt.toolName,
+ Arguments: map[string]any{
+ "query": "test query",
+ "topK": 3,
+ },
+ })
+ if err != nil {
+ t.Fatalf("CallTool(%q) unexpected error: %v", tt.toolName, err)
+ }
+
+ if result.IsError {
+ t.Fatalf("CallTool(%q) returned error result", tt.toolName)
+ }
+
+ if len(result.Content) == 0 {
+ t.Fatalf("CallTool(%q) returned empty content", tt.toolName)
+ }
+
+ textContent, ok := result.Content[0].(*mcp.TextContent)
+ if !ok {
+ t.Fatalf("CallTool(%q) content[0] type = %T, want *mcp.TextContent", tt.toolName, result.Content[0])
+ }
+
+ // Parse JSON and verify result structure
+ var parsed map[string]any
+ if err := json.Unmarshal([]byte(textContent.Text), &parsed); err != nil {
+ t.Fatalf("CallTool(%q) parsing JSON: %v\ntext: %s", tt.toolName, err, textContent.Text)
+ }
+
+ if parsed["query"] != "test query" {
+ t.Errorf("CallTool(%q) query = %v, want %q", tt.toolName, parsed["query"], "test query")
+ }
+
+ // mock retriever returns 1 document
+ if count, ok := parsed["result_count"].(float64); !ok || count != 1 {
+ t.Errorf("CallTool(%q) result_count = %v, want 1", tt.toolName, parsed["result_count"])
+ }
+ })
+ }
+}
+
+// TestProtocol_CallTool_KnowledgeStore_NoDocStore verifies that
+// knowledge_store is not registered when docStore is nil.
+func TestProtocol_CallTool_KnowledgeStore_NoDocStore(t *testing.T) {
+ session := connectTestServerWithKnowledge(t)
+
+ _, err := session.CallTool(context.Background(), &mcp.CallToolParams{
+ Name: "knowledge_store",
+ Arguments: map[string]any{
+ "title": "test",
+ "content": "test content",
+ },
+ })
+ if err == nil {
+ t.Fatal("CallTool(knowledge_store) expected error for unregistered tool, got nil")
+ }
+
+ if !strings.Contains(err.Error(), "knowledge_store") {
+ t.Errorf("CallTool(knowledge_store) error = %q, want to contain tool name", err.Error())
+ }
+}
diff --git a/internal/mcp/server.go b/internal/mcp/server.go
index 164015c..8e5dd7a 100644
--- a/internal/mcp/server.go
+++ b/internal/mcp/server.go
@@ -2,7 +2,9 @@ package mcp
import (
"context"
+ "errors"
"fmt"
+ "log/slog"
"github.com/koopa0/koopa/internal/tools"
"github.com/modelcontextprotocol/go-sdk/mcp"
@@ -11,42 +13,42 @@ import (
// Server wraps the MCP SDK server and Koopa's tool handlers.
// It exposes Koopa's tools via the Model Context Protocol.
type Server struct {
- mcpServer *mcp.Server
- fileTools *tools.FileTools
- systemTools *tools.SystemTools
- networkTools *tools.NetworkTools
- knowledgeTools *tools.KnowledgeTools // nil when knowledge search is unavailable
- name string
- version string
+ mcpServer *mcp.Server
+ logger *slog.Logger
+ file *tools.File
+ system *tools.System
+ network *tools.Network
+ knowledge *tools.Knowledge // nil when knowledge search is unavailable
}
// Config holds MCP server configuration.
type Config struct {
- Name string
- Version string
- FileTools *tools.FileTools
- SystemTools *tools.SystemTools
- NetworkTools *tools.NetworkTools
- KnowledgeTools *tools.KnowledgeTools // Optional: nil disables knowledge search tools
+ Name string
+ Version string
+ Logger *slog.Logger // Optional: nil uses slog.Default()
+ File *tools.File
+ System *tools.System
+ Network *tools.Network
+ Knowledge *tools.Knowledge // Optional: nil disables knowledge search tools
}
// NewServer creates a new MCP server with the given configuration.
func NewServer(cfg Config) (*Server, error) {
// Validate required config
if cfg.Name == "" {
- return nil, fmt.Errorf("server name is required")
+ return nil, errors.New("server name is required")
}
if cfg.Version == "" {
- return nil, fmt.Errorf("server version is required")
+ return nil, errors.New("server version is required")
}
- if cfg.FileTools == nil {
- return nil, fmt.Errorf("file tools is required")
+ if cfg.File == nil {
+ return nil, errors.New("file tools is required")
}
- if cfg.SystemTools == nil {
- return nil, fmt.Errorf("system tools is required")
+ if cfg.System == nil {
+ return nil, errors.New("system tools is required")
}
- if cfg.NetworkTools == nil {
- return nil, fmt.Errorf("network tools is required")
+ if cfg.Network == nil {
+ return nil, errors.New("network tools is required")
}
// Create MCP server (using official SDK)
@@ -55,19 +57,23 @@ func NewServer(cfg Config) (*Server, error) {
Version: cfg.Version,
}, nil)
+ logger := cfg.Logger
+ if logger == nil {
+ logger = slog.Default()
+ }
+
s := &Server{
- mcpServer: mcpServer,
- fileTools: cfg.FileTools,
- systemTools: cfg.SystemTools,
- networkTools: cfg.NetworkTools,
- knowledgeTools: cfg.KnowledgeTools,
- name: cfg.Name,
- version: cfg.Version,
+ mcpServer: mcpServer,
+ logger: logger,
+ file: cfg.File,
+ system: cfg.System,
+ network: cfg.Network,
+ knowledge: cfg.Knowledge,
}
// Register all tools
if err := s.registerTools(); err != nil {
- return nil, fmt.Errorf("failed to register tools: %w", err)
+ return nil, fmt.Errorf("registering tools: %w", err)
}
return s, nil
@@ -77,28 +83,28 @@ func NewServer(cfg Config) (*Server, error) {
// This is a blocking call that handles all MCP protocol communication.
func (s *Server) Run(ctx context.Context, transport mcp.Transport) error {
if err := s.mcpServer.Run(ctx, transport); err != nil {
- return fmt.Errorf("MCP server run failed: %w", err)
+ return fmt.Errorf("running mcp server: %w", err)
}
return nil
}
// registerTools registers all Toolset tools to the MCP server.
func (s *Server) registerTools() error {
- if err := s.registerFileTools(); err != nil {
+ if err := s.registerFile(); err != nil {
return fmt.Errorf("register file tools: %w", err)
}
- if err := s.registerSystemTools(); err != nil {
+ if err := s.registerSystem(); err != nil {
return fmt.Errorf("register system tools: %w", err)
}
- if err := s.registerNetworkTools(); err != nil {
+ if err := s.registerNetwork(); err != nil {
return fmt.Errorf("register network tools: %w", err)
}
// Knowledge tools are optional (require DB + embedder)
- if s.knowledgeTools != nil {
- if err := s.registerKnowledgeTools(); err != nil {
+ if s.knowledge != nil {
+ if err := s.registerKnowledge(); err != nil {
return fmt.Errorf("register knowledge tools: %w", err)
}
}
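
From the caller's side, the renamed Config fields and the optional Logger wire up as below. This is a sketch only: the package path and function name are assumed, while Config, NewServer, and the tools types come from this diff.

    package app // assumed wiring location

    import (
        "fmt"
        "log/slog"

        "github.com/koopa0/koopa/internal/mcp"
        "github.com/koopa0/koopa/internal/tools"
    )

    func newMCPServer(file *tools.File, system *tools.System, network *tools.Network, logger *slog.Logger) (*mcp.Server, error) {
        srv, err := mcp.NewServer(mcp.Config{
            Name:    "koopa",
            Version: "1.0.0",
            Logger:  logger, // optional: nil falls back to slog.Default()
            File:    file,
            System:  system,
            Network: network,
            // Knowledge omitted: knowledge tools stay unregistered.
        })
        if err != nil {
            return nil, fmt.Errorf("creating mcp server: %w", err)
        }
        return srv, nil
    }
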
diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go
index 7a77b96..d066f45 100644
--- a/internal/mcp/server_test.go
+++ b/internal/mcp/server_test.go
@@ -1,17 +1,36 @@
package mcp
import (
+ "context"
"log/slog"
"net/http"
"net/http/httptest"
"path/filepath"
+ "strings"
"testing"
"time"
+ "github.com/firebase/genkit/go/ai"
+ "github.com/firebase/genkit/go/core/api"
+
"github.com/koopa0/koopa/internal/security"
"github.com/koopa0/koopa/internal/tools"
)
+// mcpTestRetriever implements ai.Retriever for MCP protocol tests.
+// Returns a single mock document so tests can verify result structure.
+type mcpTestRetriever struct{}
+
+func (*mcpTestRetriever) Name() string { return "mock-retriever" }
+func (*mcpTestRetriever) Retrieve(_ context.Context, _ *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
+ return &ai.RetrieverResponse{
+ Documents: []*ai.Document{
+ ai.DocumentFromText("mock result for protocol test", nil),
+ },
+ }, nil
+}
+func (*mcpTestRetriever) Register(_ api.Registry) {}
+
// testHelper provides common test utilities.
type testHelper struct {
t *testing.T
@@ -24,7 +43,7 @@ func newTestHelper(t *testing.T) *testHelper {
tempDir := t.TempDir()
realTempDir, err := filepath.EvalSymlinks(tempDir)
if err != nil {
- t.Fatalf("failed to resolve temp dir symlinks: %v", err)
+ t.Fatalf("resolving temp dir symlinks: %v", err)
}
return &testHelper{
t: t,
@@ -32,33 +51,33 @@ func newTestHelper(t *testing.T) *testHelper {
}
}
-func (h *testHelper) createFileTools() *tools.FileTools {
+func (h *testHelper) createFile() *tools.File {
h.t.Helper()
pathVal, err := security.NewPath([]string{h.tempDir})
if err != nil {
- h.t.Fatalf("failed to create path validator: %v", err)
+ h.t.Fatalf("creating path validator: %v", err)
}
- ft, err := tools.NewFileTools(pathVal, slog.Default())
+ ft, err := tools.NewFile(pathVal, slog.Default())
if err != nil {
- h.t.Fatalf("failed to create file tools: %v", err)
+ h.t.Fatalf("creating file tools: %v", err)
}
return ft
}
-func (h *testHelper) createSystemTools() *tools.SystemTools {
+func (h *testHelper) createSystem() *tools.System {
h.t.Helper()
cmdVal := security.NewCommand()
envVal := security.NewEnv()
- st, err := tools.NewSystemTools(cmdVal, envVal, slog.Default())
+ st, err := tools.NewSystem(cmdVal, envVal, slog.Default())
if err != nil {
- h.t.Fatalf("failed to create system tools: %v", err)
+ h.t.Fatalf("creating system tools: %v", err)
}
return st
}
-func (h *testHelper) createNetworkTools() *tools.NetworkTools {
+func (h *testHelper) createNetwork() *tools.Network {
h.t.Helper()
// Use httptest.NewServer instead of hardcoded localhost URL
@@ -72,8 +91,8 @@ func (h *testHelper) createNetworkTools() *tools.NetworkTools {
// SSRF protection is only checked during Fetch(), not at construction.
// These tests only construct tools for Config, they don't execute network operations.
- nt, err := tools.NewNetworkTools(
- tools.NetworkConfig{
+ nt, err := tools.NewNetwork(
+ tools.NetConfig{
SearchBaseURL: mockServer.URL,
FetchParallelism: 2,
FetchDelay: 100 * time.Millisecond,
@@ -82,22 +101,38 @@ func (h *testHelper) createNetworkTools() *tools.NetworkTools {
slog.Default(),
)
if err != nil {
- h.t.Fatalf("failed to create network tools: %v", err)
+ h.t.Fatalf("creating network tools: %v", err)
}
return nt
}
+func (h *testHelper) createKnowledge() *tools.Knowledge {
+ h.t.Helper()
+ kt, err := tools.NewKnowledge(&mcpTestRetriever{}, nil, slog.New(slog.DiscardHandler))
+ if err != nil {
+ h.t.Fatalf("creating knowledge tools: %v", err)
+ }
+ return kt
+}
+
func (h *testHelper) createValidConfig() Config {
h.t.Helper()
return Config{
- Name: "test-server",
- Version: "1.0.0",
- FileTools: h.createFileTools(),
- SystemTools: h.createSystemTools(),
- NetworkTools: h.createNetworkTools(),
+ Name: "test-server",
+ Version: "1.0.0",
+ File: h.createFile(),
+ System: h.createSystem(),
+ Network: h.createNetwork(),
}
}
+func (h *testHelper) createConfigWithKnowledge() Config {
+ h.t.Helper()
+ cfg := h.createValidConfig()
+ cfg.Knowledge = h.createKnowledge()
+ return cfg
+}
+
// TestNewServer_Success tests successful server creation with all tools.
func TestNewServer_Success(t *testing.T) {
h := newTestHelper(t)
@@ -105,41 +140,32 @@ func TestNewServer_Success(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
- }
-
- // Verify server fields are correctly set
- if server.name != "test-server" {
- t.Errorf("server.name = %q, want %q", server.name, "test-server")
- }
-
- if server.version != "1.0.0" {
- t.Errorf("server.version = %q, want %q", server.version, "1.0.0")
+ t.Fatalf("NewServer(): %v", err)
}
if server.mcpServer == nil {
t.Error("server.mcpServer is nil")
}
- if server.fileTools == nil {
- t.Error("server.fileTools is nil")
+ if server.file == nil {
+ t.Error("server.file is nil")
}
- if server.systemTools == nil {
- t.Error("server.systemTools is nil")
+ if server.system == nil {
+ t.Error("server.system is nil")
}
- if server.networkTools == nil {
- t.Error("server.networkTools is nil")
+ if server.network == nil {
+ t.Error("server.network is nil")
}
}
// TestNewServer_ValidationErrors tests config validation.
func TestNewServer_ValidationErrors(t *testing.T) {
h := newTestHelper(t)
- validFile := h.createFileTools()
- validSystem := h.createSystemTools()
- validNetwork := h.createNetworkTools()
+ validFile := h.createFile()
+ validSystem := h.createSystem()
+ validNetwork := h.createNetwork()
tests := []struct {
name string
@@ -149,50 +175,50 @@ func TestNewServer_ValidationErrors(t *testing.T) {
{
name: "missing name",
config: Config{
- Version: "1.0.0",
- FileTools: validFile,
- SystemTools: validSystem,
- NetworkTools: validNetwork,
+ Version: "1.0.0",
+ File: validFile,
+ System: validSystem,
+ Network: validNetwork,
},
wantErr: "server name is required",
},
{
name: "missing version",
config: Config{
- Name: "test",
- FileTools: validFile,
- SystemTools: validSystem,
- NetworkTools: validNetwork,
+ Name: "test",
+ File: validFile,
+ System: validSystem,
+ Network: validNetwork,
},
wantErr: "server version is required",
},
{
name: "missing file tools",
config: Config{
- Name: "test",
- Version: "1.0.0",
- SystemTools: validSystem,
- NetworkTools: validNetwork,
+ Name: "test",
+ Version: "1.0.0",
+ System: validSystem,
+ Network: validNetwork,
},
wantErr: "file tools is required",
},
{
name: "missing system tools",
config: Config{
- Name: "test",
- Version: "1.0.0",
- FileTools: validFile,
- NetworkTools: validNetwork,
+ Name: "test",
+ Version: "1.0.0",
+ File: validFile,
+ Network: validNetwork,
},
wantErr: "system tools is required",
},
{
name: "missing network tools",
config: Config{
- Name: "test",
- Version: "1.0.0",
- FileTools: validFile,
- SystemTools: validSystem,
+ Name: "test",
+ Version: "1.0.0",
+ File: validFile,
+ System: validSystem,
},
wantErr: "network tools is required",
},
@@ -204,42 +230,9 @@ func TestNewServer_ValidationErrors(t *testing.T) {
if err == nil {
t.Fatal("NewServer succeeded, want error")
}
- if !contains(err.Error(), tt.wantErr) {
+ if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("NewServer error = %q, want to contain %q", err.Error(), tt.wantErr)
}
})
}
}
-
-// TestRegisterTools_AllToolsRegistered verifies all 10 tools are registered.
-func TestRegisterTools_AllToolsRegistered(t *testing.T) {
- h := newTestHelper(t)
- cfg := h.createValidConfig()
-
- server, err := NewServer(cfg)
- if err != nil {
- t.Fatalf("NewServer failed: %v", err)
- }
-
- // Verify server was created successfully (tools registered in constructor)
- if server.mcpServer == nil {
- t.Fatal("mcpServer is nil")
- }
-
- // Note: We can't directly verify tool registration without accessing
- // internal MCP server state. The fact that NewServer succeeded without
- // error means registerTools() completed successfully for all 10 tools:
- // - File: read_file, write_file, list_files, delete_file, get_file_info
- // - System: current_time, execute_command, get_env
- // - Network: web_search, web_fetch
-}
-
-// contains checks if s contains substr.
-func contains(s, substr string) bool {
- for i := 0; i <= len(s)-len(substr); i++ {
- if s[i:i+len(substr)] == substr {
- return true
- }
- }
- return false
-}
diff --git a/internal/mcp/system.go b/internal/mcp/system.go
index ac6bdd3..c74a450 100644
--- a/internal/mcp/system.go
+++ b/internal/mcp/system.go
@@ -10,39 +10,48 @@ import (
"github.com/modelcontextprotocol/go-sdk/mcp"
)
-// registerSystemTools registers all system operation tools to the MCP server.
+// registerSystem registers all system operation tools to the MCP server.
// Tools: current_time, execute_command, get_env
-func (s *Server) registerSystemTools() error {
+func (s *Server) registerSystem() error {
// current_time
currentTimeSchema, err := jsonschema.For[tools.CurrentTimeInput](nil)
if err != nil {
- return fmt.Errorf("schema for %s: %w", tools.ToolCurrentTime, err)
+ return fmt.Errorf("schema for %s: %w", tools.CurrentTimeName, err)
}
mcp.AddTool(s.mcpServer, &mcp.Tool{
- Name: tools.ToolCurrentTime,
- Description: "Get the current system date and time in formatted string.",
+ Name: tools.CurrentTimeName,
+ Description: "Get the current system date and time. " +
+ "Returns: formatted time string, Unix timestamp, and ISO 8601 format. " +
+ "Always returns the server's local time zone.",
InputSchema: currentTimeSchema,
}, s.CurrentTime)
// execute_command
executeCommandSchema, err := jsonschema.For[tools.ExecuteCommandInput](nil)
if err != nil {
- return fmt.Errorf("schema for %s: %w", tools.ToolExecuteCommand, err)
+ return fmt.Errorf("schema for %s: %w", tools.ExecuteCommandName, err)
}
mcp.AddTool(s.mcpServer, &mcp.Tool{
- Name: tools.ToolExecuteCommand,
- Description: "Execute a shell command with security validation. Dangerous commands (rm -rf, sudo, etc.) are blocked.",
+ Name: tools.ExecuteCommandName,
+ Description: "Execute a shell command from the allowed list with security validation. " +
+ "Allowed commands: git, npm, yarn, go, make, docker, kubectl, ls, cat, grep, find, pwd, echo. " +
+ "Commands run with a timeout to prevent hanging. " +
+ "Returns: stdout, stderr, exit code, and execution time. " +
+ "Security: Dangerous commands (rm -rf, sudo, chmod, etc.) are blocked.",
InputSchema: executeCommandSchema,
}, s.ExecuteCommand)
// get_env
getEnvSchema, err := jsonschema.For[tools.GetEnvInput](nil)
if err != nil {
- return fmt.Errorf("schema for %s: %w", tools.ToolGetEnv, err)
+ return fmt.Errorf("schema for %s: %w", tools.GetEnvName, err)
}
mcp.AddTool(s.mcpServer, &mcp.Tool{
- Name: tools.ToolGetEnv,
- Description: "Read an environment variable value. Sensitive variables (*KEY*, *SECRET*, *TOKEN*) are protected.",
+ Name: tools.GetEnvName,
+ Description: "Read an environment variable value from the system. " +
+ "Returns: the variable name and its value. " +
+ "Use this to: check configuration, verify paths, read non-sensitive settings. " +
+ "Security: Sensitive variables containing KEY, SECRET, TOKEN, or PASSWORD in their names are protected and will not be returned.",
InputSchema: getEnvSchema,
}, s.GetEnv)
@@ -52,30 +61,30 @@ func (s *Server) registerSystemTools() error {
// CurrentTime handles the currentTime MCP tool call.
func (s *Server) CurrentTime(ctx context.Context, _ *mcp.CallToolRequest, input tools.CurrentTimeInput) (*mcp.CallToolResult, any, error) {
toolCtx := &ai.ToolContext{Context: ctx}
- result, err := s.systemTools.CurrentTime(toolCtx, input)
+ result, err := s.system.CurrentTime(toolCtx, input)
if err != nil {
- return nil, nil, fmt.Errorf("currentTime failed: %w", err)
+ return nil, nil, fmt.Errorf("getting current time: %w", err)
}
- return resultToMCP(result), nil, nil
+ return resultToMCP(result, s.logger), nil, nil
}
// ExecuteCommand handles the executeCommand MCP tool call.
func (s *Server) ExecuteCommand(ctx context.Context, _ *mcp.CallToolRequest, input tools.ExecuteCommandInput) (*mcp.CallToolResult, any, error) {
toolCtx := &ai.ToolContext{Context: ctx}
- result, err := s.systemTools.ExecuteCommand(toolCtx, input)
+ result, err := s.system.ExecuteCommand(toolCtx, input)
if err != nil {
// Only infrastructure errors (context cancellation) return Go error
- return nil, nil, fmt.Errorf("executeCommand failed: %w", err)
+ return nil, nil, fmt.Errorf("executing command: %w", err)
}
- return resultToMCP(result), nil, nil
+ return resultToMCP(result, s.logger), nil, nil
}
// GetEnv handles the getEnv MCP tool call.
func (s *Server) GetEnv(ctx context.Context, _ *mcp.CallToolRequest, input tools.GetEnvInput) (*mcp.CallToolResult, any, error) {
toolCtx := &ai.ToolContext{Context: ctx}
- result, err := s.systemTools.GetEnv(toolCtx, input)
+ result, err := s.system.GetEnv(toolCtx, input)
if err != nil {
- return nil, nil, fmt.Errorf("getEnv failed: %w", err)
+ return nil, nil, fmt.Errorf("getting env: %w", err)
}
- return resultToMCP(result), nil, nil
+ return resultToMCP(result, s.logger), nil, nil
}
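Every handler in this file follows the same two-step shape: wrap the incoming context in an ai.ToolContext, call the underlying tool, then convert the tools.Result with resultToMCP. A minimal sketch of that shape, using only identifiers that appear in this diff (the handler name itself is illustrative):

    func (s *Server) exampleHandler(ctx context.Context, _ *mcp.CallToolRequest, input tools.GetEnvInput) (*mcp.CallToolResult, any, error) {
        // Genkit tools expect an ai.ToolContext rather than a bare context.Context.
        toolCtx := &ai.ToolContext{Context: ctx}
        result, err := s.system.GetEnv(toolCtx, input)
        if err != nil {
            // Only infrastructure errors (e.g. context cancellation) surface here;
            // tool-level failures come back inside result and become IsError via resultToMCP.
            return nil, nil, fmt.Errorf("getting env: %w", err)
        }
        return resultToMCP(result, s.logger), nil, nil
    }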
diff --git a/internal/mcp/system_test.go b/internal/mcp/system_test.go
index 3db437e..73cd4f5 100644
--- a/internal/mcp/system_test.go
+++ b/internal/mcp/system_test.go
@@ -15,13 +15,13 @@ func TestCurrentTime_Success(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
result, _, err := server.CurrentTime(context.Background(), &mcp.CallToolRequest{}, tools.CurrentTimeInput{})
if err != nil {
- t.Fatalf("CurrentTime failed: %v", err)
+ t.Fatalf("CurrentTime(): %v", err)
}
if result.IsError {
@@ -60,7 +60,7 @@ func TestExecuteCommand_Success(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
result, _, err := server.ExecuteCommand(context.Background(), &mcp.CallToolRequest{}, tools.ExecuteCommandInput{
@@ -69,7 +69,7 @@ func TestExecuteCommand_Success(t *testing.T) {
})
if err != nil {
- t.Fatalf("ExecuteCommand failed: %v", err)
+ t.Fatalf("ExecuteCommand(): %v", err)
}
if result.IsError {
@@ -97,7 +97,7 @@ func TestExecuteCommand_DangerousCommandBlocked(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
// Try to execute a dangerous command
@@ -122,7 +122,7 @@ func TestExecuteCommand_CommandNotFound(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
result, _, err := server.ExecuteCommand(context.Background(), &mcp.CallToolRequest{}, tools.ExecuteCommandInput{
@@ -146,7 +146,7 @@ func TestGetEnv_Success(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
// Set a test environment variable using t.Setenv for automatic cleanup
@@ -159,7 +159,7 @@ func TestGetEnv_Success(t *testing.T) {
})
if err != nil {
- t.Fatalf("GetEnv failed: %v", err)
+ t.Fatalf("GetEnv(): %v", err)
}
if result.IsError {
@@ -187,7 +187,7 @@ func TestGetEnv_NotSet(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
result, _, err := server.GetEnv(context.Background(), &mcp.CallToolRequest{}, tools.GetEnvInput{
@@ -195,7 +195,7 @@ func TestGetEnv_NotSet(t *testing.T) {
})
if err != nil {
- t.Fatalf("GetEnv failed: %v", err)
+ t.Fatalf("GetEnv(): %v", err)
}
if result.IsError {
@@ -224,7 +224,7 @@ func TestGetEnv_SensitiveVariableBlocked(t *testing.T) {
server, err := NewServer(cfg)
if err != nil {
- t.Fatalf("NewServer failed: %v", err)
+ t.Fatalf("NewServer(): %v", err)
}
// Try to access a sensitive variable
diff --git a/internal/mcp/util.go b/internal/mcp/util.go
index e5a54a2..2356ecd 100644
--- a/internal/mcp/util.go
+++ b/internal/mcp/util.go
@@ -26,7 +26,12 @@ import (
// resultToMCP converts a tools.Result to mcp.CallToolResult.
// This follows the Direct Inline Handling principle but extracts the common pattern.
-func resultToMCP(result tools.Result) *mcp.CallToolResult {
+// If logger is nil, falls back to slog.Default().
+func resultToMCP(result tools.Result, logger *slog.Logger) *mcp.CallToolResult {
+ if logger == nil {
+ logger = slog.Default()
+ }
+
if result.Status == tools.StatusError {
errorText := fmt.Sprintf("[%s] %s", result.Error.Code, result.Error.Message)
if result.Error.Details != nil {
@@ -36,7 +41,7 @@ func resultToMCP(result tools.Result) *mcp.CallToolResult {
detailsJSON, err := json.Marshal(sanitized)
if err != nil {
// Log internal error, don't expose to client
- slog.Warn("failed to marshal sanitized error details", "error", err)
+ logger.Warn("marshaling sanitized error details", "error", err)
errorText += "\nDetails: (see server logs)"
} else {
errorText += fmt.Sprintf("\nDetails: %s", string(detailsJSON))
@@ -44,7 +49,7 @@ func resultToMCP(result tools.Result) *mcp.CallToolResult {
}
// Always log full details server-side for debugging
- slog.Debug("MCP error details", "details", result.Error.Details)
+ logger.Debug("MCP error details", "details", result.Error.Details)
}
return &mcp.CallToolResult{
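The logger parameter is new, and a nil logger is a deliberate escape hatch for tests. Both call shapes below are taken from elsewhere in this diff:

    // In MCP handlers, pass the server's logger so failures while marshaling
    // error details are logged through the configured handler:
    return resultToMCP(result, s.logger), nil, nil

    // In tests (or any caller without a logger), nil falls back to slog.Default():
    mcpResult := resultToMCP(result, nil)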
diff --git a/internal/mcp/util_test.go b/internal/mcp/util_test.go
index 5422290..f18be55 100644
--- a/internal/mcp/util_test.go
+++ b/internal/mcp/util_test.go
@@ -1,6 +1,7 @@
package mcp
import (
+ "strings"
"testing"
"github.com/koopa0/koopa/internal/tools"
@@ -13,7 +14,7 @@ func TestResultToMCP_Success(t *testing.T) {
Data: map[string]any{"result": "value", "count": 42},
}
- mcpResult := resultToMCP(result)
+ mcpResult := resultToMCP(result, nil)
if mcpResult.IsError {
t.Error("resultToMCP should not set IsError for success status")
@@ -29,7 +30,7 @@ func TestResultToMCP_Success(t *testing.T) {
}
// Data should be JSON marshaled
- if !contains(textContent.Text, "result") || !contains(textContent.Text, "value") {
+ if !strings.Contains(textContent.Text, "result") || !strings.Contains(textContent.Text, "value") {
t.Errorf("resultToMCP text should contain JSON data: %s", textContent.Text)
}
}
@@ -43,7 +44,7 @@ func TestResultToMCP_Error(t *testing.T) {
},
}
- mcpResult := resultToMCP(result)
+ mcpResult := resultToMCP(result, nil)
if !mcpResult.IsError {
t.Error("resultToMCP should set IsError for error status")
@@ -59,11 +60,11 @@ func TestResultToMCP_Error(t *testing.T) {
}
// Should contain error code and message
- if !contains(textContent.Text, string(tools.ErrCodeNotFound)) {
+ if !strings.Contains(textContent.Text, string(tools.ErrCodeNotFound)) {
t.Errorf("resultToMCP text should contain error code: %s", textContent.Text)
}
- if !contains(textContent.Text, "File not found") {
+ if !strings.Contains(textContent.Text, "File not found") {
t.Errorf("resultToMCP text should contain error message: %s", textContent.Text)
}
}
@@ -79,7 +80,7 @@ func TestResultToMCP_ErrorWithDetails(t *testing.T) {
},
}
- mcpResult := resultToMCP(result)
+ mcpResult := resultToMCP(result, nil)
if !mcpResult.IsError {
t.Error("resultToMCP should set IsError for error status")
@@ -95,15 +96,11 @@ func TestResultToMCP_ErrorWithDetails(t *testing.T) {
}
// Should contain "Details:" with whitelisted fields
- if !contains(textContent.Text, "Details:") {
+ if !strings.Contains(textContent.Text, "Details:") {
t.Errorf("resultToMCP text should contain 'Details:': %s", textContent.Text)
}
}
-// =============================================================================
-// dataToMCP Tests
-// =============================================================================
-
func TestDataToMCP_ValidData(t *testing.T) {
data := map[string]any{"key": "value", "count": 42}
result := dataToMCP(data)
@@ -121,7 +118,7 @@ func TestDataToMCP_ValidData(t *testing.T) {
t.Fatal("dataToMCP content is not TextContent")
}
- if !contains(textContent.Text, "key") || !contains(textContent.Text, "value") {
+ if !strings.Contains(textContent.Text, "key") || !strings.Contains(textContent.Text, "value") {
t.Errorf("dataToMCP should contain JSON data: %s", textContent.Text)
}
}
@@ -161,7 +158,7 @@ func TestDataToMCP_SliceData(t *testing.T) {
t.Fatal("dataToMCP content is not TextContent")
}
- if !contains(textContent.Text, "item1") {
+ if !strings.Contains(textContent.Text, "item1") {
t.Errorf("dataToMCP should contain JSON array: %s", textContent.Text)
}
}
@@ -187,7 +184,7 @@ func TestDataToMCP_NestedStruct(t *testing.T) {
t.Fatal("dataToMCP content is not TextContent")
}
- if !contains(textContent.Text, "test") || !contains(textContent.Text, "42") {
+ if !strings.Contains(textContent.Text, "test") || !strings.Contains(textContent.Text, "42") {
t.Errorf("dataToMCP should contain nested JSON: %s", textContent.Text)
}
}
@@ -222,7 +219,7 @@ func TestDataToMCP_ResultNilData(t *testing.T) {
Data: nil,
}
- mcpResult := resultToMCP(result)
+ mcpResult := resultToMCP(result, nil)
if mcpResult.IsError {
t.Error("resultToMCP should not set IsError for success with nil data")
@@ -238,10 +235,6 @@ func TestDataToMCP_ResultNilData(t *testing.T) {
}
}
-// =============================================================================
-// sanitizeErrorDetails Tests
-// =============================================================================
-
func TestSanitizeErrorDetails(t *testing.T) {
tests := []struct {
name string
diff --git a/internal/observability/datadog.go b/internal/observability/datadog.go
deleted file mode 100644
index 60c0990..0000000
--- a/internal/observability/datadog.go
+++ /dev/null
@@ -1,172 +0,0 @@
-// Package observability provides OpenTelemetry integration for distributed tracing.
-//
-// # Architecture Decision: Datadog Agent Mode
-//
-// We use the Datadog Agent for OTLP ingestion instead of direct API endpoint.
-// This decision was made because:
-//
-// - Direct OTLP Traces API is in Preview status (as of Nov 2025)
-// - Agent provides better reliability with local buffering and retry
-// - Lower latency (localhost vs internet roundtrip)
-// - Agent handles authentication - no need to pass DD_API_KEY in app
-// - Supports all Datadog features (metrics, logs, traces in one agent)
-//
-// # Prerequisites
-//
-// 1. Datadog Account with US5 region (or your region)
-// 2. DD_API_KEY from https://us5.datadoghq.com → Organization Settings → API Keys
-//
-// # macOS Installation
-//
-// Install Datadog Agent:
-//
-// DD_API_KEY="your-key" DD_SITE="us5.datadoghq.com" \
-// bash -c "$(curl -L https://install.datadoghq.com/scripts/install_mac_os.sh)"
-//
-// # Enable OTLP Receiver
-//
-// Add to /opt/datadog-agent/etc/datadog.yaml (at the end of file):
-//
-// otlp_config:
-// receiver:
-// protocols:
-// http:
-// endpoint: "localhost:4318"
-// traces:
-// enabled: true
-// span_name_as_resource_name: true
-//
-// # Restart Agent
-//
-// Option 1 - Using launchctl:
-//
-// sudo launchctl stop com.datadoghq.agent
-// sudo launchctl start com.datadoghq.agent
-//
-// Option 2 - Kill and restart:
-//
-// sudo pkill -9 -f datadog
-// sudo /opt/datadog-agent/bin/agent/agent run &
-//
-// # Option 3 - Use Datadog Agent GUI app
-//
-// # Verify OTLP is Enabled
-//
-// datadog-agent status | grep -A 5 "OTLP"
-//
-// Expected output:
-//
-// OTLP
-// ====
-// Status: Enabled
-// Collector status: Running
-//
-// # View Traces in Datadog
-//
-// After running koopa with tracing enabled:
-// - Go to https://us5.datadoghq.com/apm/traces
-// - Search for service:koopa or your configured service name
-// - Traces appear within 1-2 minutes after app shutdown (flush)
-//
-// # Troubleshooting
-//
-// Agent not running:
-//
-// launchctl list | grep datadog # PID should not be "-"
-//
-// Check Agent logs:
-//
-// sudo tail -50 /var/log/datadog/agent.log
-//
-// Test OTLP endpoint:
-//
-// curl -v http://localhost:4318/v1/traces
-//
-// # Configuration
-//
-// Environment variables (optional):
-// - DD_AGENT_HOST: Override agent host (default: localhost:4318)
-// - DD_ENV: Environment tag (default: dev)
-// - DD_SERVICE: Service name (default: koopa)
-//
-// Config file (~/.koopa/config.yaml):
-//
-// datadog:
-// agent_host: "localhost:4318"
-// environment: "dev"
-// service_name: "koopa"
-package observability
-
-import (
- "context"
- "log/slog"
- "os"
-
- "github.com/firebase/genkit/go/core/tracing"
- "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
- sdktrace "go.opentelemetry.io/otel/sdk/trace"
-)
-
-// Config for Datadog OTEL setup.
-type Config struct {
- // AgentHost is the Datadog Agent OTLP endpoint (default: localhost:4318)
- AgentHost string
- // Environment is the deployment environment (dev, staging, prod)
- Environment string
- // ServiceName is the service name shown in Datadog APM
- ServiceName string
-}
-
-// DefaultAgentHost is the default Datadog Agent OTLP HTTP endpoint.
-const DefaultAgentHost = "localhost:4318"
-
-// SetupDatadog registers a Datadog Agent exporter with Genkit's TracerProvider.
-// Traces are sent to the local Datadog Agent via OTLP HTTP protocol.
-//
-// Returns a shutdown function that flushes pending spans.
-// If AgentHost is empty, uses DefaultAgentHost (localhost:4318).
-func SetupDatadog(ctx context.Context, cfg Config) (shutdown func(context.Context) error, err error) {
- agentHost := cfg.AgentHost
- if agentHost == "" {
- agentHost = DefaultAgentHost
- }
-
- // Set OTEL env vars for Genkit's TracerProvider to pick up.
- // SAFETY: os.Setenv is not concurrent-safe, but this function is called
- // exactly once during startup in InitializeApp, before goroutines are spawned.
- if cfg.ServiceName != "" {
- _ = os.Setenv("OTEL_SERVICE_NAME", cfg.ServiceName)
- }
- if cfg.Environment != "" {
- _ = os.Setenv("OTEL_RESOURCE_ATTRIBUTES", "deployment.environment="+cfg.Environment)
- }
-
- // Create OTLP HTTP exporter pointing to local Datadog Agent
- // Agent handles authentication and forwarding to Datadog backend
- exporter, err := otlptracehttp.New(ctx,
- otlptracehttp.WithEndpoint(agentHost),
- otlptracehttp.WithInsecure(), // localhost doesn't need TLS
- )
- if err != nil {
- slog.Warn("failed to create datadog exporter, tracing disabled", "error", err)
- return func(context.Context) error { return nil }, nil
- }
-
- // Register BatchSpanProcessor with Genkit's TracerProvider
- processor := sdktrace.NewBatchSpanProcessor(exporter)
- tracing.TracerProvider().RegisterSpanProcessor(processor)
-
- slog.Debug("datadog tracing enabled",
- "agent", agentHost,
- "service", cfg.ServiceName,
- "environment", cfg.Environment,
- )
-
- // Create a test span to verify the pipeline works
- tracer := tracing.TracerProvider().Tracer("koopa-init")
- _, span := tracer.Start(ctx, "koopa.init")
- span.End()
- slog.Debug("test span created for datadog verification")
-
- return tracing.TracerProvider().Shutdown, nil
-}
diff --git a/internal/observability/datadog_test.go b/internal/observability/datadog_test.go
deleted file mode 100644
index b57937c..0000000
--- a/internal/observability/datadog_test.go
+++ /dev/null
@@ -1,115 +0,0 @@
-package observability
-
-import (
- "context"
- "testing"
-)
-
-func TestSetupDatadog_DefaultAgentHost(t *testing.T) {
- // NOTE: not parallel — SetupDatadog modifies global state (os.Setenv, TracerProvider)
-
- cfg := Config{
- AgentHost: "", // Empty should use default
- Environment: "test",
- ServiceName: "test-service",
- }
-
- ctx := context.Background()
- shutdown, err := SetupDatadog(ctx, cfg)
-
- // Should not fail even with empty AgentHost
- if err != nil {
- t.Fatalf("SetupDatadog() unexpected error: %v", err)
- }
- if shutdown == nil {
- t.Fatal("SetupDatadog() shutdown = nil, want non-nil")
- }
-
- // Cleanup
- if err := shutdown(ctx); err != nil {
- t.Errorf("shutdown() unexpected error: %v", err)
- }
-}
-
-func TestSetupDatadog_CustomAgentHost(t *testing.T) {
- // NOTE: not parallel — SetupDatadog modifies global state (os.Setenv, TracerProvider)
-
- cfg := Config{
- AgentHost: "custom-host:4318",
- Environment: "staging",
- ServiceName: "custom-service",
- }
-
- ctx := context.Background()
- shutdown, err := SetupDatadog(ctx, cfg)
-
- // Should not fail with custom host
- if err != nil {
- t.Fatalf("SetupDatadog() unexpected error: %v", err)
- }
- if shutdown == nil {
- t.Fatal("SetupDatadog() shutdown = nil, want non-nil")
- }
-
- // Cleanup
- if err := shutdown(ctx); err != nil {
- t.Errorf("shutdown() unexpected error: %v", err)
- }
-}
-
-func TestSetupDatadog_AgentUnavailable_GracefulDegradation(t *testing.T) {
- // NOTE: not parallel — SetupDatadog modifies global state (os.Setenv, TracerProvider)
-
- // Point to a non-existent agent
- cfg := Config{
- AgentHost: "localhost:99999", // Invalid port
- Environment: "test",
- ServiceName: "graceful-test",
- }
-
- ctx := context.Background()
- shutdown, err := SetupDatadog(ctx, cfg)
-
- // Should NOT fail - graceful degradation
- // The exporter creation may succeed but spans will fail to export silently
- if err != nil {
- t.Fatalf("SetupDatadog() unexpected error: %v", err)
- }
- if shutdown == nil {
- t.Fatal("SetupDatadog() shutdown = nil, want non-nil")
- }
-
- // Shutdown should not panic
- if err := shutdown(ctx); err != nil {
- t.Errorf("shutdown() unexpected error: %v", err)
- }
-}
-
-func TestSetupDatadog_EmptyConfig(t *testing.T) {
- // NOTE: not parallel — SetupDatadog modifies global state (os.Setenv, TracerProvider)
-
- // All empty config - should use defaults
- cfg := Config{}
-
- ctx := context.Background()
- shutdown, err := SetupDatadog(ctx, cfg)
-
- if err != nil {
- t.Fatalf("SetupDatadog() unexpected error: %v", err)
- }
- if shutdown == nil {
- t.Fatal("SetupDatadog() shutdown = nil, want non-nil")
- }
-
- if err := shutdown(ctx); err != nil {
- t.Errorf("shutdown() unexpected error: %v", err)
- }
-}
-
-func TestDefaultAgentHost_Value(t *testing.T) {
- t.Parallel()
-
- if got, want := DefaultAgentHost, "localhost:4318"; got != want {
- t.Errorf("DefaultAgentHost = %q, want %q", got, want)
- }
-}
diff --git a/internal/rag/constants.go b/internal/rag/constants.go
index 81980c6..c3a7da0 100644
--- a/internal/rag/constants.go
+++ b/internal/rag/constants.go
@@ -1,14 +1,9 @@
-// Package rag constants.go defines shared constants, types, and configuration for RAG operations.
-//
-// Contents:
-// - Source type constants (SourceTypeConversation, SourceTypeFile, SourceTypeSystem)
-// - Table schema constants for documents table
-// - NewDocStoreConfig factory for consistent DocStore configuration
package rag
import (
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/plugins/postgresql"
+ "google.golang.org/genai"
)
// Source type constants for knowledge documents.
@@ -24,6 +19,12 @@ const (
SourceTypeSystem = "system"
)
+// VectorDimension is the vector dimension used by the pgvector schema.
+// Must match the documents table migration: embedding vector(768).
+// gemini-embedding-001 produces 3072 dimensions by default;
+// we truncate to 768 via OutputDimensionality in EmbedderOptions.
+const VectorDimension int32 = 768
+
// Table schema constants for Genkit PostgreSQL plugin.
// These match the documents table in db/migrations.
const (
@@ -37,7 +38,9 @@ const (
// NewDocStoreConfig creates a postgresql.Config for the documents table.
// This factory ensures consistent configuration across production and tests.
+// EmbedderOptions sets OutputDimensionality to match the pgvector schema.
func NewDocStoreConfig(embedder ai.Embedder) *postgresql.Config {
+ dim := VectorDimension
return &postgresql.Config{
TableName: DocumentsTableName,
SchemaName: DocumentsSchemaName,
@@ -47,5 +50,6 @@ func NewDocStoreConfig(embedder ai.Embedder) *postgresql.Config {
MetadataJSONColumn: DocumentsMetadataCol,
MetadataColumns: []string{"source_type"}, // For filtering by type
Embedder: embedder,
+ EmbedderOptions: &genai.EmbedContentConfig{OutputDimensionality: &dim},
}
}
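The local dim copy exists because OutputDimensionality wants a *int32 and the address of a typed constant cannot be taken directly. A hedged usage sketch — embedder is assumed to be an ai.Embedder obtained from the configured Genkit plugin; the rest uses only identifiers from this file:

    // The returned config pins OutputDimensionality to 768, so embeddings written
    // by the Genkit PostgreSQL plugin fit the documents table's vector(768) column.
    cfg := rag.NewDocStoreConfig(embedder)
    // cfg is then handed to the plugin's DocStore constructor (that call is not
    // part of this diff, so it is not shown here).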
diff --git a/internal/rag/doc.go b/internal/rag/doc.go
index da78850..d324943 100644
--- a/internal/rag/doc.go
+++ b/internal/rag/doc.go
@@ -8,17 +8,11 @@
// RAG enhances LLM responses by augmenting prompts with relevant context from a knowledge base.
// The rag package manages:
//
-// - System knowledge indexing (IndexSystemKnowledge function)
+// - DocStore configuration for vector storage
// - Integration with Genkit's PostgreSQL DocStore
//
// # Architecture
//
-// System Knowledge Docs
-// |
-// v
-// IndexSystemKnowledge()
-// |
-// v
// Genkit PostgreSQL DocStore
// |
// +-- Vector embedding (via AI Embedder)
@@ -35,10 +29,9 @@
//
// # Key Components
//
-// IndexSystemKnowledge: Package-level function that indexes built-in knowledge:
-// - Go best practices and coding standards
-// - Agent capabilities and tool usage
-// - Architecture principles
+// NewDocStoreConfig: Creates configuration for the Genkit PostgreSQL DocStore.
+//
+// DeleteByIDs: Removes documents by ID for UPSERT emulation.
//
// # Source Types
//
@@ -51,5 +44,4 @@
// # Thread Safety
//
// DocStore handles concurrent operations safely.
-// IndexSystemKnowledge is called once at startup.
package rag
diff --git a/internal/rag/fuzz_test.go b/internal/rag/fuzz_test.go
index d5bac97..301026b 100644
--- a/internal/rag/fuzz_test.go
+++ b/internal/rag/fuzz_test.go
@@ -33,8 +33,7 @@ func FuzzDeleteByIDs_SQLInjection(f *testing.F) {
}
// Setup test database
- dbContainer, cleanup := testutil.SetupTestDB(t)
- defer cleanup()
+ dbContainer := testutil.SetupTestDB(t)
ctx := context.Background()
pool := dbContainer.Pool
@@ -70,15 +69,14 @@ func FuzzDeleteByIDs_SQLInjection(f *testing.F) {
func TestDeleteByIDs_EmptySlice(t *testing.T) {
t.Parallel()
- dbContainer, cleanup := testutil.SetupTestDB(t)
- defer cleanup()
+ dbContainer := testutil.SetupTestDB(t)
ctx := context.Background()
// Empty slice should return nil without executing query
err := rag.DeleteByIDs(ctx, dbContainer.Pool, []string{})
if err != nil {
- t.Errorf("deleteByIDs with empty slice should return nil, got: %v", err)
+ t.Errorf("DeleteByIDs(empty slice) unexpected error: %v", err)
}
}
@@ -86,8 +84,7 @@ func TestDeleteByIDs_EmptySlice(t *testing.T) {
func TestDeleteByIDs_ValidUUIDs(t *testing.T) {
t.Parallel()
- dbContainer, cleanup := testutil.SetupTestDB(t)
- defer cleanup()
+ dbContainer := testutil.SetupTestDB(t)
ctx := context.Background()
@@ -99,6 +96,6 @@ func TestDeleteByIDs_ValidUUIDs(t *testing.T) {
err := rag.DeleteByIDs(ctx, dbContainer.Pool, validIDs)
if err != nil {
- t.Errorf("deleteByIDs with valid UUIDs should succeed, got: %v", err)
+ t.Errorf("DeleteByIDs(valid UUIDs) unexpected error: %v", err)
}
}
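The SetupTestDB calls here drop the explicit cleanup return value; the assumption (not visible in this hunk) is that the helper now registers teardown itself, for example via t.Cleanup. Callers then reduce to:

    dbContainer := testutil.SetupTestDB(t) // teardown assumed to be registered inside the helper
    pool := dbContainer.Pool               // safe to use for the rest of the test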
diff --git a/internal/rag/system.go b/internal/rag/system.go
index db0cdbd..3e3d285 100644
--- a/internal/rag/system.go
+++ b/internal/rag/system.go
@@ -1,59 +1,12 @@
-// Package rag provides RAG (Retrieval-Augmented Generation) functionality.
-// This file implements IndexSystemKnowledge for managing built-in knowledge
-// about Agent capabilities, Golang best practices, and architecture principles.
package rag
import (
"context"
"fmt"
- "log/slog"
- "github.com/firebase/genkit/go/ai"
- "github.com/firebase/genkit/go/plugins/postgresql"
"github.com/jackc/pgx/v5/pgxpool"
)
-// IndexSystemKnowledge indexes all built-in system knowledge documents.
-// Called once during application startup.
-//
-// Features:
-// - Uses fixed document IDs (e.g., "system:golang-errors")
-// - UPSERT behavior via delete-then-insert (Genkit DocStore doesn't support UPSERT)
-// - Returns count of successfully indexed documents
-//
-// Returns: (indexedCount, error)
-// Error: returns error if indexing fails
-func IndexSystemKnowledge(ctx context.Context, store *postgresql.DocStore, pool *pgxpool.Pool) (int, error) {
- docs := buildSystemKnowledgeDocs()
-
- // Extract document IDs for deletion (UPSERT emulation)
- ids := make([]string, 0, len(docs))
- for _, doc := range docs {
- if id, ok := doc.Metadata["id"].(string); ok {
- ids = append(ids, id)
- }
- }
-
- // Delete existing documents first (UPSERT emulation).
- // Genkit DocStore.Index() only does INSERT, so we must delete first.
- // NOTE: Not fully atomic — Genkit DocStore manages its own connections,
- // so delete (via pool) and insert (via DocStore) cannot share a transaction.
- // This is acceptable because IndexSystemKnowledge runs only at startup.
- if err := DeleteByIDs(ctx, pool, ids); err != nil {
- // DELETE with non-existent IDs returns 0 rows (not an error).
- // A real error here indicates a connection or SQL problem.
- slog.Warn("failed to delete existing system knowledge", "error", err, "ids", ids)
- }
-
- if err := store.Index(ctx, docs); err != nil {
- return 0, fmt.Errorf("failed to index system knowledge: %w", err)
- }
-
- slog.Debug("system knowledge indexed", "count", len(docs))
-
- return len(docs), nil
-}
-
// DeleteByIDs deletes documents by their IDs.
// Used for UPSERT emulation since Genkit DocStore only supports INSERT.
// Exported for testing (fuzz tests in rag_test package).
@@ -69,235 +22,3 @@ func DeleteByIDs(ctx context.Context, pool *pgxpool.Pool, ids []string) error {
}
return nil
}
-
-// buildSystemKnowledgeDocs constructs all system knowledge documents.
-func buildSystemKnowledgeDocs() []*ai.Document {
- var docs []*ai.Document
-
- // 1. Golang Style Guide
- docs = append(docs, buildGolangStyleDocs()...)
-
- // 2. Agent Capabilities
- docs = append(docs, buildCapabilitiesDocs()...)
-
- // 3. Architecture Principles
- docs = append(docs, buildArchitectureDocs()...)
-
- return docs
-}
-
-// buildGolangStyleDocs creates Golang best practices documents.
-func buildGolangStyleDocs() []*ai.Document {
- return []*ai.Document{
- // Document 1: Error Handling
- ai.DocumentFromText(`# Golang Error Handling Best Practices
-
-## Core Principles
-- Always check errors immediately after function calls
-- Use fmt.Errorf with %w for error wrapping (enables errors.Is/As)
-- Avoid naked returns in error paths
-- Return errors to callers, don't panic unless truly exceptional
-
-## Examples
-Good:
- result, err := doSomething()
- if err != nil {
- return fmt.Errorf("failed to do something: %w", err)
- }
-
-Bad:
- result, _ := doSomething() // Ignoring errors
-
-## Security
-- Never expose internal error details to users
-- Log full errors, return sanitized messages`,
- map[string]any{
- "id": "system:golang-errors",
- "source_type": SourceTypeSystem,
- "category": "golang",
- "topic": "error-handling",
- "version": "1.0",
- }),
-
- // Document 2: Concurrency Patterns
- ai.DocumentFromText(`# Golang Concurrency Best Practices
-
-## Goroutines
-- Always have a way to stop goroutines (context, done channel)
-- Use WaitGroups for coordinating multiple goroutines
-- Avoid goroutine leaks by ensuring all goroutines eventually exit
-
-## Channels
-- Close channels from sender side only
-- Use select with context.Done() for cancellation
-- Buffered channels for non-blocking sends
-
-## Context
-- Pass context as first parameter
-- Use context.WithTimeout for operations with deadlines
-- Never store context in struct fields (exception: short-lived request-scoped structs)
-
-## Mutexes
-- Keep critical sections small
-- Use RWMutex when read-heavy workload
-- Prefer channels over shared memory when possible`,
- map[string]any{
- "id": "system:golang-concurrency",
- "source_type": SourceTypeSystem,
- "category": "golang",
- "topic": "concurrency",
- "version": "1.0",
- }),
-
- // Document 3: Naming Conventions
- ai.DocumentFromText(`# Golang Naming Conventions
-
-## Packages
-- Short, lowercase, no underscores (e.g., httputil, not http_util)
-- Singular form (e.g., encoding, not encodings)
-
-## Interfaces
-- One-method interfaces: name with -er suffix (Reader, Writer, Closer)
-- Avoid "I" prefix (use Reader, not IReader)
-
-## Getters/Setters
-- No "Get" prefix for getters (use Owner(), not GetOwner())
-- Use "Set" prefix for setters (SetOwner())
-
-## Acronyms
-- Keep consistent casing: URL, HTTP, ID (not Url, Http, Id)
-- In names: use URLParser, not UrlParser
-
-## Exported vs Unexported
-- Exported: PascalCase (MyFunction, MyStruct)
-- Unexported: camelCase (myFunction, myStruct)`,
- map[string]any{
- "id": "system:golang-naming",
- "source_type": SourceTypeSystem,
- "category": "golang",
- "topic": "naming",
- "version": "1.0",
- }),
- }
-}
-
-// buildCapabilitiesDocs creates Agent capabilities documents.
-func buildCapabilitiesDocs() []*ai.Document {
- return []*ai.Document{
- // Document 4: Available Tools
- ai.DocumentFromText(`# Agent Available Tools
-
-## File Operations
-- read_file: Read file contents
-- write_file: Create or update file
-- list_files: List directory contents with glob patterns
-- delete_file: Remove file (requires confirmation)
-- get_file_info: Get file metadata
-
-## System Operations
-- current_time: Get current timestamp
-- execute_command: Run shell commands (requires confirmation for destructive ops)
-- get_env: Read environment variables
-
-## Network Operations
-- web_search: Search the web using SearXNG metasearch engine
-- web_fetch: Fetch and extract content from a specific URL
-
-## Knowledge Operations
-- search_history: Search conversation history
-- search_documents: Search user-indexed documents (Notion pages, local files)
-- search_system_knowledge: Search Agent's built-in knowledge
-- knowledge_store: Store new knowledge documents for later retrieval
-
-## Limitations
-- File operations limited to current working directory
-- Commands requiring sudo are blocked
-- Cannot access system files (/etc, /sys, etc.)`,
- map[string]any{
- "id": "system:agent-tools",
- "source_type": SourceTypeSystem,
- "category": "capabilities",
- "topic": "available-tools",
- "version": "1.0",
- }),
-
- // Document 5: Best Practices
- ai.DocumentFromText(`# Agent Best Practices
-
-## When to Use Tools
-- search_history: When user asks "what did I say about X?"
-- search_documents: When user asks about their notes/documents
-- search_system_knowledge: When unsure about Golang conventions or Agent capabilities
-- read_file before write_file: Always read first to understand context
-
-## Communication
-- Be concise but informative
-- Use code blocks for code snippets
-- Explain what you're about to do before using destructive tools
-- Ask for confirmation when ambiguous
-
-## Error Handling
-- If a tool fails, try alternative approaches
-- Explain errors in user-friendly language
-- Don't give up after first failure - retry with different parameters
-
-## Security
-- Never execute user-provided code without understanding it
-- Always validate file paths before operations
-- Sanitize command inputs to prevent injection`,
- map[string]any{
- "id": "system:agent-best-practices",
- "source_type": SourceTypeSystem,
- "category": "capabilities",
- "topic": "best-practices",
- "version": "1.0",
- }),
- }
-}
-
-// buildArchitectureDocs creates architecture principles documents.
-func buildArchitectureDocs() []*ai.Document {
- return []*ai.Document{
- // Document 6: Design Principles
- ai.DocumentFromText(`# Koopa CLI Architecture Principles
-
-## Dependency Injection
-- Use struct-based DI
-- Define interfaces in consumer packages (not provider packages)
-- Accept interfaces, return structs
-
-## Package Structure
-- internal/agent: Core AI interaction logic
-- internal/tools: Tool definitions and implementations
-- internal/rag: Knowledge retrieval and indexing
-- internal/session: Session persistence
-- cmd: CLI commands and user interaction
-
-## Error Handling
-- Errors propagate up, logged at boundaries
-- Use error wrapping (fmt.Errorf with %w)
-- Graceful degradation when non-critical services fail
-
-## Testing
-- Unit tests for business logic
-- Integration tests for cross-package interactions
-- Use interfaces for mocking
-
-## Security
-- Principle of least privilege
-- Input validation at boundaries
-- Explicit user confirmation for destructive operations
-
-## Concurrency
-- Use context for cancellation
-- Protect shared state with mutexes
-- Goroutines must have cleanup mechanism`,
- map[string]any{
- "id": "system:architecture-principles",
- "source_type": SourceTypeSystem,
- "category": "architecture",
- "topic": "design-principles",
- "version": "1.0",
- }),
- }
-}
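IndexSystemKnowledge is removed, but DeleteByIDs still supports the delete-then-insert UPSERT emulation its doc comment describes. A minimal sketch of that pattern, mirroring the removed function's body (store, pool, and docs are assumed to be a *postgresql.DocStore, *pgxpool.Pool, and []*ai.Document provided by the caller):

    ids := make([]string, 0, len(docs))
    for _, doc := range docs {
        if id, ok := doc.Metadata["id"].(string); ok {
            ids = append(ids, id)
        }
    }
    // Delete first: Genkit's DocStore.Index only INSERTs, so re-indexing the same
    // fixed IDs without this step would create duplicates.
    if err := rag.DeleteByIDs(ctx, pool, ids); err != nil {
        return fmt.Errorf("deleting existing documents: %w", err)
    }
    if err := store.Index(ctx, docs); err != nil {
        return fmt.Errorf("indexing documents: %w", err)
    }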
diff --git a/internal/rag/system_test.go b/internal/rag/system_test.go
deleted file mode 100644
index a84a046..0000000
--- a/internal/rag/system_test.go
+++ /dev/null
@@ -1,206 +0,0 @@
-//go:build integration
-
-// Package rag_test provides integration tests for RAG system knowledge functions.
-package rag_test
-
-import (
- "context"
- "testing"
-
- "github.com/koopa0/koopa/internal/rag"
- "github.com/koopa0/koopa/internal/testutil"
-)
-
-// TestIndexSystemKnowledge_FirstTime verifies first-time indexing works correctly.
-func TestIndexSystemKnowledge_FirstTime(t *testing.T) {
- t.Parallel()
-
- dbContainer, cleanup := testutil.SetupTestDB(t)
- defer cleanup()
-
- ragSetup := testutil.SetupRAG(t, dbContainer.Pool)
- ctx := context.Background()
-
- // Index system knowledge for the first time
- count, err := rag.IndexSystemKnowledge(ctx, ragSetup.DocStore, dbContainer.Pool)
- if err != nil {
- t.Fatalf("IndexSystemKnowledge() first-time indexing unexpected error: %v (should succeed)", err)
- }
- if count <= 0 {
- t.Errorf("IndexSystemKnowledge() count = %d, want > 0 (should index at least one document)", count)
- }
-
- t.Logf("Indexed %d system knowledge documents", count)
-}
-
-// TestIndexSystemKnowledge_Reindexing verifies UPSERT behavior (no duplicates on re-index).
-func TestIndexSystemKnowledge_Reindexing(t *testing.T) {
- t.Parallel()
-
- dbContainer, cleanup := testutil.SetupTestDB(t)
- defer cleanup()
-
- ragSetup := testutil.SetupRAG(t, dbContainer.Pool)
- ctx := context.Background()
-
- // First indexing
- count1, err := rag.IndexSystemKnowledge(ctx, ragSetup.DocStore, dbContainer.Pool)
- if err != nil {
- t.Fatalf("IndexSystemKnowledge() first indexing unexpected error: %v (should succeed)", err)
- }
-
- // Second indexing (should not create duplicates)
- count2, err := rag.IndexSystemKnowledge(ctx, ragSetup.DocStore, dbContainer.Pool)
- if err != nil {
- t.Fatalf("IndexSystemKnowledge() re-indexing unexpected error: %v (should succeed)", err)
- }
-
- if count1 != count2 {
- t.Errorf("IndexSystemKnowledge() re-indexing count = %d, want %d (should index same number of documents)", count2, count1)
- }
-
- // Verify no duplicates by counting documents with system source type
- var totalCount int
- err = dbContainer.Pool.QueryRow(ctx,
- `SELECT COUNT(*) FROM documents WHERE metadata->>'source_type' = $1`,
- rag.SourceTypeSystem,
- ).Scan(&totalCount)
- if err != nil {
- t.Fatalf("QueryRow() counting documents unexpected error: %v (should succeed)", err)
- }
-
- if totalCount != count1 {
- t.Errorf("total system documents after re-indexing = %d, want %d (should have no duplicate documents)", totalCount, count1)
- }
- t.Logf("Total system documents after re-indexing: %d (expected: %d)", totalCount, count1)
-}
-
-// TestIndexSystemKnowledge_DocumentMetadata verifies all documents have required metadata.
-func TestIndexSystemKnowledge_DocumentMetadata(t *testing.T) {
- t.Parallel()
-
- dbContainer, cleanup := testutil.SetupTestDB(t)
- defer cleanup()
-
- ragSetup := testutil.SetupRAG(t, dbContainer.Pool)
- ctx := context.Background()
-
- // Index documents
- if _, err := rag.IndexSystemKnowledge(ctx, ragSetup.DocStore, dbContainer.Pool); err != nil {
- t.Fatalf("IndexSystemKnowledge() unexpected error: %v", err)
- }
-
- // Query documents and verify metadata
- rows, err := dbContainer.Pool.Query(ctx,
- `SELECT id, metadata FROM documents WHERE metadata->>'source_type' = $1`,
- rag.SourceTypeSystem,
- )
- if err != nil {
- t.Fatalf("Query() unexpected error: %v", err)
- }
- defer rows.Close()
-
- docCount := 0
- for rows.Next() {
- var id string
- var metadata map[string]any
- if err := rows.Scan(&id, &metadata); err != nil {
- t.Fatalf("rows.Scan() unexpected error: %v", err)
- }
-
- // Verify required metadata fields
- if metadata["id"] == "" || metadata["id"] == nil {
- t.Errorf("document %q metadata[id] = empty, want non-empty", id)
- }
- if metadata["source_type"] == "" || metadata["source_type"] == nil {
- t.Errorf("document %q metadata[source_type] = empty, want non-empty", id)
- }
- if metadata["category"] == "" || metadata["category"] == nil {
- t.Errorf("document %q metadata[category] = empty, want non-empty", id)
- }
- if metadata["topic"] == "" || metadata["topic"] == nil {
- t.Errorf("document %q metadata[topic] = empty, want non-empty", id)
- }
-
- docCount++
- }
-
- if err := rows.Err(); err != nil {
- t.Fatalf("rows.Err() unexpected error: %v", err)
- }
- if docCount <= 0 {
- t.Errorf("verified documents count = %d, want > 0 (should have at least one system document)", docCount)
- }
- t.Logf("Verified metadata for %d documents", docCount)
-}
-
-// TestIndexSystemKnowledge_UniqueIDs verifies all documents have unique IDs.
-func TestIndexSystemKnowledge_UniqueIDs(t *testing.T) {
- t.Parallel()
-
- dbContainer, cleanup := testutil.SetupTestDB(t)
- defer cleanup()
-
- ragSetup := testutil.SetupRAG(t, dbContainer.Pool)
- ctx := context.Background()
-
- // Index documents
- if _, err := rag.IndexSystemKnowledge(ctx, ragSetup.DocStore, dbContainer.Pool); err != nil {
- t.Fatalf("IndexSystemKnowledge() unexpected error: %v", err)
- }
-
- // Query for duplicate IDs
- rows, err := dbContainer.Pool.Query(ctx,
- `SELECT metadata->>'id' as doc_id, COUNT(*) as cnt
- FROM documents
- WHERE metadata->>'source_type' = $1
- GROUP BY metadata->>'id'
- HAVING COUNT(*) > 1`,
- rag.SourceTypeSystem,
- )
- if err != nil {
- t.Fatalf("Query() unexpected error: %v", err)
- }
- defer rows.Close()
-
- duplicates := 0
- for rows.Next() {
- var docID string
- var count int
- if err := rows.Scan(&docID, &count); err != nil {
- t.Fatalf("rows.Scan() unexpected error: %v", err)
- }
- t.Errorf("duplicate document ID found: %s (count: %d)", docID, count)
- duplicates++
- }
-
- if err := rows.Err(); err != nil {
- t.Fatalf("rows.Err() unexpected error: %v", err)
- }
- if duplicates != 0 {
- t.Errorf("duplicate document IDs = %d, want 0", duplicates)
- }
-}
-
-// TestIndexSystemKnowledge_CanceledContext verifies graceful handling of canceled context.
-func TestIndexSystemKnowledge_CanceledContext(t *testing.T) {
- t.Parallel()
-
- dbContainer, cleanup := testutil.SetupTestDB(t)
- defer cleanup()
-
- ragSetup := testutil.SetupRAG(t, dbContainer.Pool)
-
- // Create already-canceled context
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
-
- // Should fail gracefully with canceled context
- _, err := rag.IndexSystemKnowledge(ctx, ragSetup.DocStore, dbContainer.Pool)
-
- // Error is expected (context canceled)
- if err == nil {
- t.Error("IndexSystemKnowledge(canceled context) error = nil, want non-nil (should fail with canceled context)")
- }
- t.Logf("Expected error: %v", err)
-}
diff --git a/internal/security/command.go b/internal/security/command.go
index 07a59dc..420a4a2 100644
--- a/internal/security/command.go
+++ b/internal/security/command.go
@@ -3,14 +3,14 @@ package security
import (
"fmt"
"log/slog"
+ "slices"
"strings"
)
// Command validates commands to prevent injection attacks.
// Used to prevent command injection attacks (CWE-78).
type Command struct {
- blacklist []string
- whitelist []string // If non-empty, only allow commands in the whitelist
+ whitelist []string // Only allow commands in this list
blockedSubcommands map[string][]string // cmd → blocked first-arg subcommands
blockedArgPatterns map[string][]string // cmd → blocked argument patterns (any position)
}
@@ -31,7 +31,6 @@ type Command struct {
// make and mkdir are NOT whitelisted — make can execute arbitrary Makefile targets.
func NewCommand() *Command {
return &Command{
- blacklist: []string{}, // Whitelist mode doesn't need blacklist
whitelist: []string{
// File listing (metadata only — no content reading)
"ls", "wc", "sort", "uniq",
@@ -68,7 +67,7 @@ func NewCommand() *Command {
}
}
-// ValidateCommand validates whether a command is safe.
+// Validate validates whether a command is safe to execute.
//
// SECURITY NOTE: This validator is designed for use with exec.Command(cmd, args...),
// which does NOT pass arguments through a shell. Therefore:
@@ -79,14 +78,14 @@ func NewCommand() *Command {
// Parameters:
// - cmd: command name (executable)
// - args: command arguments (passed directly to exec.Command, not shell-interpreted)
-func (v *Command) ValidateCommand(cmd string, args []string) error {
+func (v *Command) Validate(cmd string, args []string) error {
// 1. Check for empty command
if strings.TrimSpace(cmd) == "" {
return fmt.Errorf("command cannot be empty")
}
// 2. Validate command name only (no args yet)
- if err := v.validateCommandName(cmd); err != nil {
+ if err := validateCommandName(cmd); err != nil {
return fmt.Errorf("validating command name: %w", err)
}
@@ -111,7 +110,7 @@ func (v *Command) ValidateCommand(cmd string, args []string) error {
// NOTE: We do NOT check for shell metacharacters (|, $, >, etc.) because
// exec.Command treats them as literal strings, not shell operators
for i, arg := range args {
- if err := v.validateArgument(arg); err != nil {
+ if err := validateArgument(arg); err != nil {
slog.Warn("dangerous argument detected",
"command", cmd,
"arg_index", i,
@@ -122,52 +121,27 @@ func (v *Command) ValidateCommand(cmd string, args []string) error {
}
}
- // 6. Check full command string (cmd + args) for dangerous patterns
- // Some dangerous patterns span command and arguments (e.g., "rm -rf /")
- fullCmd := strings.ToLower(cmd + " " + strings.Join(args, " "))
- for _, pattern := range v.blacklist {
- if strings.Contains(fullCmd, strings.ToLower(pattern)) {
- slog.Warn("command+args match dangerous pattern",
- "command", cmd,
- "args", args,
- "full_command", fullCmd,
- "dangerous_pattern", pattern,
- "security_event", "dangerous_command_combination")
- return fmt.Errorf("command contains dangerous pattern: '%s'", pattern)
- }
- }
-
return nil
}
+// shellMetachars lists characters that indicate shell injection in a command name.
+const shellMetachars = ";|&`\n><$()"
+
// validateCommandName validates the command name (executable) only.
-// Checks blacklist patterns and shell injection attempts in the command name itself.
-func (v *Command) validateCommandName(cmd string) error {
+// Checks for shell injection attempts in the command name itself.
+func validateCommandName(cmd string) error {
// Normalize command name
cmd = strings.TrimSpace(strings.ToLower(cmd))
- // Check blacklisted command patterns
- for _, pattern := range v.blacklist {
- if strings.Contains(cmd, strings.ToLower(pattern)) {
- slog.Warn("command matches blacklisted pattern",
- "command", cmd,
- "dangerous_pattern", pattern,
- "security_event", "command_blacklist_violation")
- return fmt.Errorf("command contains dangerous pattern: '%s'", pattern)
- }
- }
-
// Check for shell metacharacters in command name itself
// (These would indicate shell injection attempt)
- shellMetachars := []string{";", "|", "&", "`", "\n", ">", "<", "$", "(", ")"}
- for _, char := range shellMetachars {
- if strings.Contains(cmd, char) {
- slog.Warn("command name contains shell metacharacter",
- "command", cmd,
- "character", char,
- "security_event", "shell_injection_in_command_name")
- return fmt.Errorf("command name contains shell metacharacter: '%s'", char)
- }
+ if i := strings.IndexAny(cmd, shellMetachars); i >= 0 {
+ char := string(cmd[i])
+ slog.Warn("command name contains shell metacharacter",
+ "command", cmd,
+ "character", char,
+ "security_event", "shell_injection_in_command_name")
+ return fmt.Errorf("command name contains shell metacharacter: %q", char)
}
return nil
@@ -193,14 +167,12 @@ func (v *Command) validateSubcommands(cmd string, args []string) error {
// Check blocked subcommands (first argument)
if blocked, ok := v.blockedSubcommands[cmdLower]; ok && len(args) > 0 {
firstArg := strings.ToLower(strings.TrimSpace(args[0]))
- for _, sub := range blocked {
- if firstArg == sub {
- slog.Warn("blocked subcommand",
- "command", cmd,
- "subcommand", args[0],
- "security_event", "blocked_subcommand")
- return fmt.Errorf("subcommand '%s %s' is not allowed (can execute arbitrary code)", cmd, args[0])
- }
+ if slices.Contains(blocked, firstArg) {
+ slog.Warn("blocked subcommand",
+ "command", cmd,
+ "subcommand", args[0],
+ "security_event", "blocked_subcommand")
+ return fmt.Errorf("subcommand '%s %s' is not allowed (can execute arbitrary code)", cmd, args[0])
}
}
@@ -224,6 +196,20 @@ func (v *Command) validateSubcommands(cmd string, args []string) error {
return nil
}
+// dangerousArgPatterns lists embedded command patterns that are dangerous
+// even when passed as arguments via exec.Command.
+var dangerousArgPatterns = []string{
+ "rm -rf /",
+ "rm -rf /*",
+ "rm -rf ~",
+ "mkfs",
+ "dd if=/dev/zero",
+ "dd if=/dev/urandom",
+ "shutdown",
+ "reboot",
+ "sudo su",
+}
+
// validateArgument checks if an argument contains obviously malicious patterns.
//
// IMPORTANT: This function does NOT check for shell metacharacters like $, |, >, <
@@ -232,7 +218,7 @@ func (v *Command) validateSubcommands(cmd string, args []string) error {
// - Embedded dangerous commands (e.g., "rm -rf /")
// - Null bytes
// - Extremely long arguments (possible buffer overflow)
-func (*Command) validateArgument(arg string) error {
+func validateArgument(arg string) error {
// Check for null bytes (often used in injection attacks)
if strings.Contains(arg, "\x00") {
return fmt.Errorf("argument contains null byte")
@@ -246,19 +232,7 @@ func (*Command) validateArgument(arg string) error {
// Check for embedded dangerous command patterns
// These are suspicious even in arguments
argLower := strings.ToLower(arg)
- dangerousPatterns := []string{
- "rm -rf /",
- "rm -rf /*",
- "rm -rf ~",
- "mkfs",
- "dd if=/dev/zero",
- "dd if=/dev/urandom",
- "shutdown",
- "reboot",
- "sudo su",
- }
-
- for _, pattern := range dangerousPatterns {
+ for _, pattern := range dangerousArgPatterns {
if strings.Contains(argLower, pattern) {
return fmt.Errorf("argument contains dangerous pattern: %s", pattern)
}
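As the doc comments above stress, Validate is designed for the exec.Command path, where arguments are never shell-interpreted. A short sketch of the intended call pattern (the exec call itself is illustrative and not part of this diff):

    v := security.NewCommand()
    if err := v.Validate(cmd, args); err != nil {
        return fmt.Errorf("dangerous command: %w", err)
    }
    // Arguments go straight to the process, not through a shell, so metacharacters
    // in args stay literal; Validate still rejects embedded patterns like "rm -rf /".
    out, err := exec.CommandContext(ctx, cmd, args...).CombinedOutput()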
diff --git a/internal/security/command_test.go b/internal/security/command_test.go
index 95efef5..0f0d6da 100644
--- a/internal/security/command_test.go
+++ b/internal/security/command_test.go
@@ -89,12 +89,12 @@ func TestCommandValidation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- err := cmdValidator.ValidateCommand(tt.command, tt.args)
+ err := cmdValidator.Validate(tt.command, tt.args)
if tt.shouldErr && err == nil {
- t.Errorf("expected error for %q, but got none: %s", tt.command, tt.reason)
+ t.Errorf("Validate(%q, %v) = nil, want error: %s", tt.command, tt.args, tt.reason)
}
if !tt.shouldErr && err != nil {
- t.Errorf("unexpected error for %q: %v (%s)", tt.command, err, tt.reason)
+ t.Errorf("Validate(%q, %v) = %v, want nil (%s)", tt.command, tt.args, err, tt.reason)
}
})
}
@@ -133,12 +133,12 @@ func TestStrictCommandValidator(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- err := validator.ValidateCommand(tt.command, tt.args)
+ err := validator.Validate(tt.command, tt.args)
if tt.shouldErr && err == nil {
- t.Errorf("expected error for %q", tt.command)
+ t.Errorf("Validate(%q, %v) = nil, want error", tt.command, tt.args)
}
if !tt.shouldErr && err != nil {
- t.Errorf("unexpected error for %q: %v", tt.command, err)
+ t.Errorf("Validate(%q, %v) = %v, want nil", tt.command, tt.args, err)
}
})
}
@@ -193,18 +193,33 @@ func TestBlockedSubcommands(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- err := v.ValidateCommand(tt.command, tt.args)
+ err := v.Validate(tt.command, tt.args)
if tt.shouldErr && err == nil {
- t.Errorf("ValidateCommand(%q, %v) = nil, want error", tt.command, tt.args)
+ t.Errorf("Validate(%q, %v) = nil, want error", tt.command, tt.args)
}
if !tt.shouldErr && err != nil {
- t.Errorf("ValidateCommand(%q, %v) = %v, want nil", tt.command, tt.args, err)
+ t.Errorf("Validate(%q, %v) = %v, want nil", tt.command, tt.args, err)
}
})
}
}
+// TestAllShellMetacharsBlocked verifies every shell metacharacter in the const
+// is blocked when it appears in a command name. This prevents regressions if
+// shellMetachars is modified.
+func TestAllShellMetacharsBlocked(t *testing.T) {
+ v := NewCommand()
+ metachars := []string{";", "|", "&", "`", "\n", ">", "<", "$", "(", ")"}
+
+ for _, char := range metachars {
+ cmd := "ls" + char + "cat"
+ if err := v.Validate(cmd, nil); err == nil {
+ t.Errorf("Validate(%q, nil) = nil, want error for metachar %q", cmd, char)
+ }
+ }
+}
+
// TestCommandValidationEdgeCases tests edge cases in command validation
func TestCommandValidationEdgeCases(t *testing.T) {
cmdValidator := NewCommand()
@@ -275,12 +290,12 @@ func TestCommandValidationEdgeCases(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- err := cmdValidator.ValidateCommand(tt.command, tt.args)
+ err := cmdValidator.Validate(tt.command, tt.args)
if tt.shouldErr && err == nil {
- t.Errorf("expected error for %q, but got none: %s", tt.name, tt.reason)
+ t.Errorf("Validate(%q, %v) = nil, want error: %s", tt.command, tt.args, tt.reason)
}
if !tt.shouldErr && err != nil {
- t.Errorf("unexpected error for %q: %v (%s)", tt.name, err, tt.reason)
+ t.Errorf("Validate(%q, %v) = %v, want nil (%s)", tt.command, tt.args, err, tt.reason)
}
})
}
diff --git a/internal/security/doc.go b/internal/security/doc.go
index 02d203f..8e17ca3 100644
--- a/internal/security/doc.go
+++ b/internal/security/doc.go
@@ -16,30 +16,30 @@
// Path Validator: Prevents directory traversal and ensures file operations
// stay within allowed boundaries.
//
-// pathValidator := security.NewPath()
-// if err := pathValidator.ValidatePath(userInput); err != nil {
+// pathValidator, err := security.NewPath([]string{"/safe/dir"})
+// if _, err := pathValidator.Validate(userInput); err != nil {
// return fmt.Errorf("invalid path: %w", err)
// }
//
// Command Validator: Blocks dangerous shell commands and prevents command injection.
//
// cmdValidator := security.NewCommand()
-// if err := cmdValidator.ValidateCommand(cmd, args); err != nil {
+// if err := cmdValidator.Validate(cmd, args); err != nil {
// return fmt.Errorf("dangerous command: %w", err)
// }
//
// Dangerous commands blocked include: rm -rf, sudo, shutdown, dd, mkfs,
// format, and other destructive operations.
//
-// HTTP Validator: Prevents SSRF attacks by blocking requests to private networks
+// URL Validator: Prevents SSRF attacks by blocking requests to private networks
// and cloud metadata endpoints.
//
-// httpValidator := security.NewHTTP()
-// if err := httpValidator.ValidateURL(url); err != nil {
+// urlValidator := security.NewURL()
+// if err := urlValidator.Validate(rawURL); err != nil {
// return fmt.Errorf("SSRF attempt blocked: %w", err)
// }
-// // Use the validator's HTTP client for safe requests
-// resp, err := httpValidator.Client().Get(url)
+// // Use SafeTransport for DNS-rebinding protection
+// client := &http.Client{Transport: urlValidator.SafeTransport()}
//
// Blocked targets include:
// - Private IP ranges (127.0.0.1, 192.168.x.x, 10.x.x.x)
@@ -50,7 +50,7 @@
// from unauthorized access.
//
// envValidator := security.NewEnv()
-// if err := envValidator.ValidateEnvAccess(key); err != nil {
+// if err := envValidator.Validate(key); err != nil {
// return fmt.Errorf("access to sensitive variable blocked: %w", err)
// }
//
@@ -67,27 +67,32 @@
// # Integration Example
//
// // Create validators
-// pathVal := security.NewPath()
+// pathVal, _ := security.NewPath([]string{workDir})
// cmdVal := security.NewCommand()
-// httpVal := security.NewHTTP()
+// urlVal := security.NewURL()
// envVal := security.NewEnv()
//
-// // Pass to toolsets during initialization
-// fileToolset := tools.NewFileToolset(pathVal)
-// systemToolset := tools.NewSystemToolset(cmdVal, envVal)
-// networkToolset := tools.NewNetworkToolset(httpVal)
+// // Pass to tool constructors during initialization
+// fileTools, _ := tools.NewFile(pathVal, logger)
+// systemTools, _ := tools.NewSystem(cmdVal, envVal, logger)
+// networkTools, _ := tools.NewNetwork(urlVal, logger)
//
// # Configuration
//
-// The HTTP validator supports configuration for response size limits,
-// timeouts, and redirect limits:
+// The URL validator uses SafeTransport for DNS-resolution-level SSRF protection:
//
-// httpValidator := security.NewHTTP()
-// // Default: 10MB max response, 30s timeout, 10 redirects
-// client := httpValidator.Client()
+// urlValidator := security.NewURL()
+// client := &http.Client{Transport: urlValidator.SafeTransport()}
//
// Other validators use secure defaults and require no configuration.
//
+// # Error Handling
+//
+// Validators intentionally both log and return errors. This is a deliberate
+// exception to the "handle errors once" rule: security events require an
+// audit trail (via logging) AND must propagate the error to callers so they
+// can deny the operation. Removing either side would create a security gap.
+//
// # Testing
//
// Each validator includes comprehensive tests covering:
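
The new Error Handling section describes a log-and-return pattern. A minimal sketch of that shape with hypothetical names (not the package's actual types), showing why both sides are kept: the slog call is the audit trail, and the returned error is what lets the caller deny the operation.

package main

import (
	"errors"
	"fmt"
	"log/slog"
)

var errBlocked = errors.New("access blocked")

// validateSketch records an audit event and still propagates the error,
// mirroring the deliberate exception to "handle errors once" described above.
func validateSketch(name string, blocked map[string]bool) error {
	if blocked[name] {
		slog.Warn("blocked access attempt", "name", name)
		return fmt.Errorf("%q: %w", name, errBlocked)
	}
	return nil
}

func main() {
	fmt.Println(validateSketch("API_KEY", map[string]bool{"API_KEY": true}))
}
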
diff --git a/internal/security/env.go b/internal/security/env.go
index 6582c5f..c869dd9 100644
--- a/internal/security/env.go
+++ b/internal/security/env.go
@@ -22,7 +22,6 @@ func NewEnv() *Env {
"SECRET",
"PASSWORD",
"PASSWD",
- "PWD",
"TOKEN",
"ACCESS_TOKEN",
"REFRESH_TOKEN",
@@ -34,8 +33,8 @@ func NewEnv() *Env {
// Cloud services related
"AWS_SECRET",
"AWS_ACCESS_KEY",
- "AZURE_",
- "GCP_",
+ "AZURE",
+ "GCP",
"GOOGLE_API",
"GOOGLE_APPLICATION_CREDENTIALS",
@@ -88,13 +87,14 @@ func NewEnv() *Env {
}
}
-// ValidateEnvAccess validates whether access to the specified environment variable is allowed
-func (v *Env) ValidateEnvAccess(envName string) error {
+// Validate validates whether access to the specified environment variable is allowed.
+func (v *Env) Validate(envName string) error {
envUpper := strings.ToUpper(envName)
- // Check if it matches sensitive patterns
+ // Check if it matches sensitive patterns using word-boundary matching.
+	// Splits on "_" to avoid false positives like GOPATH matching PATH.
for _, pattern := range v.sensitivePatterns {
- if strings.Contains(envUpper, pattern) {
+ if isSensitivePattern(envUpper, pattern) {
slog.Warn("sensitive environment variable access attempt",
"env_name", envName,
"matched_pattern", pattern,
@@ -105,3 +105,27 @@ func (v *Env) ValidateEnvAccess(envName string) error {
return nil
}
+
+// isSensitivePattern checks if envName matches pattern.
+//
+// For composite patterns (containing "_" like "API_KEY"), it uses substring matching
+// to catch variables like "MY_API_KEY".
+//
+// For single-word patterns (like "SECRET"), it uses word-boundary matching by splitting
+// on "_" to avoid false positives (e.g., "GOPATH" should not match "PATH").
+func isSensitivePattern(envName, pattern string) bool {
+ if envName == pattern {
+ return true
+ }
+ // Composite patterns: substring matching (e.g., "API_KEY" in "MY_API_KEY")
+ if strings.Contains(pattern, "_") {
+ return strings.Contains(envName, pattern)
+ }
+ // Single-word patterns: word-boundary matching (e.g., "SECRET" in "MY_SECRET_VAR")
+ for _, segment := range strings.Split(envName, "_") {
+ if segment == pattern {
+ return true
+ }
+ }
+ return false
+}
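
A short usage sketch of the matching behavior the tests below lock in (illustrative; assumes only the NewEnv and Validate signatures shown in this file):

package main

import (
	"fmt"

	"github.com/koopa0/koopa/internal/security"
)

func main() {
	v := security.NewEnv()

	// GOPATH and PWD pass: single-word patterns match only whole "_"-separated segments.
	// DB_PASSWORD fails: "PASSWORD" matches its second segment.
	// MY_API_KEY fails: the composite pattern "API_KEY" matches as a substring.
	for _, key := range []string{"GOPATH", "PWD", "DB_PASSWORD", "MY_API_KEY"} {
		fmt.Printf("%-12s -> %v\n", key, v.Validate(key))
	}
}
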
diff --git a/internal/security/env_test.go b/internal/security/env_test.go
index 74a2d9a..b34bc3e 100644
--- a/internal/security/env_test.go
+++ b/internal/security/env_test.go
@@ -4,44 +4,95 @@ import (
"testing"
)
-// TestEnvValidator tests environment variable validation
+// TestEnvValidator tests environment variable validation with word-boundary matching.
func TestEnvValidator(t *testing.T) {
envValidator := NewEnv()
tests := []struct {
- name string
- key string
- shouldErr bool
- reason string
+ name string
+ key string
+ wantErr bool
}{
- {
- name: "valid env key",
- key: "MY_VAR",
- shouldErr: false,
- reason: "valid env key should be allowed",
- },
- {
- name: "API_KEY should be blocked",
- key: "API_KEY",
- shouldErr: true,
- reason: "API_KEY is sensitive and should be blocked",
- },
- {
- name: "PASSWORD should be blocked",
- key: "PASSWORD",
- shouldErr: true,
- reason: "PASSWORD is sensitive and should be blocked",
- },
+ // Allowed — no sensitive segment
+ {name: "generic var", key: "MY_VAR", wantErr: false},
+ {name: "HOME", key: "HOME", wantErr: false},
+ {name: "SHELL", key: "SHELL", wantErr: false},
+ {name: "TERM", key: "TERM", wantErr: false},
+
+ // Word-boundary: previously false positives, now allowed
+ {name: "PWD allowed", key: "PWD", wantErr: false},
+ {name: "GOPATH allowed", key: "GOPATH", wantErr: false},
+ {name: "MANPATH allowed", key: "MANPATH", wantErr: false},
+ {name: "PYTHONPATH allowed", key: "PYTHONPATH", wantErr: false},
+
+ // Blocked — exact match
+ {name: "API_KEY blocked", key: "API_KEY", wantErr: true},
+ {name: "PASSWORD blocked", key: "PASSWORD", wantErr: true},
+ {name: "SECRET blocked", key: "SECRET", wantErr: true},
+ {name: "TOKEN blocked", key: "TOKEN", wantErr: true},
+ {name: "DATABASE_URL blocked", key: "DATABASE_URL", wantErr: true},
+
+ // Blocked — segment match
+ {name: "DB_PASSWORD blocked", key: "DB_PASSWORD", wantErr: true},
+ {name: "MY_API_KEY blocked", key: "MY_API_KEY", wantErr: true},
+ {name: "MY_SECRET_KEY blocked", key: "MY_SECRET_KEY", wantErr: true},
+ {name: "APP_TOKEN blocked", key: "APP_TOKEN", wantErr: true},
+
+ // Cloud provider segment matching
+ {name: "AZURE_CLIENT_ID blocked", key: "AZURE_CLIENT_ID", wantErr: true},
+ {name: "AZURE_TENANT_ID blocked", key: "AZURE_TENANT_ID", wantErr: true},
+ {name: "GCP_PROJECT blocked", key: "GCP_PROJECT", wantErr: true},
+ {name: "GCP_REGION blocked", key: "GCP_REGION", wantErr: true},
+
+ // Case insensitive
+ {name: "lowercase password blocked", key: "db_password", wantErr: true},
+ {name: "mixed case api_key blocked", key: "My_Api_Key", wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- err := envValidator.ValidateEnvAccess(tt.key)
- if tt.shouldErr && err == nil {
- t.Errorf("expected error for %q, but got none: %s", tt.key, tt.reason)
+ err := envValidator.Validate(tt.key)
+ if tt.wantErr && err == nil {
+ t.Errorf("Validate(%q) = nil, want error", tt.key)
}
- if !tt.shouldErr && err != nil {
- t.Errorf("unexpected error for %q: %v (%s)", tt.key, err, tt.reason)
+ if !tt.wantErr && err != nil {
+ t.Errorf("Validate(%q) = %v, want nil", tt.key, err)
+ }
+ })
+ }
+}
+
+// TestIsSensitivePattern tests the word-boundary matching logic directly.
+func TestIsSensitivePattern(t *testing.T) {
+ tests := []struct {
+ name string
+ envName string
+ pattern string
+ want bool
+ }{
+ // Single-word patterns: word-boundary matching
+ {name: "exact match", envName: "PASSWORD", pattern: "PASSWORD", want: true},
+ {name: "segment match", envName: "DB_PASSWORD", pattern: "PASSWORD", want: true},
+ {name: "no match substring", envName: "PWD", pattern: "PASSWORD", want: false},
+ {name: "no match compound", envName: "GOPATH", pattern: "PATH", want: false},
+ {name: "segment at start", envName: "SECRET_KEY", pattern: "SECRET", want: true},
+ {name: "segment at end", envName: "MY_TOKEN", pattern: "TOKEN", want: true},
+ {name: "segment in middle", envName: "MY_SECRET_KEY", pattern: "SECRET", want: true},
+ {name: "no partial segment", envName: "MYPASSWORD", pattern: "PASSWORD", want: false},
+
+ // Composite patterns: substring matching
+ {name: "composite exact", envName: "API_KEY", pattern: "API_KEY", want: true},
+ {name: "composite prefix", envName: "MY_API_KEY", pattern: "API_KEY", want: true},
+ {name: "composite suffix", envName: "API_KEY_OLD", pattern: "API_KEY", want: true},
+ {name: "composite in middle", envName: "MY_API_KEY_OLD", pattern: "API_KEY", want: true},
+ {name: "composite no match", envName: "APIKEY", pattern: "API_KEY", want: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := isSensitivePattern(tt.envName, tt.pattern)
+ if got != tt.want {
+ t.Errorf("isSensitivePattern(%q, %q) = %v, want %v", tt.envName, tt.pattern, got, tt.want)
}
})
}
diff --git a/internal/security/fuzz_test.go b/internal/security/fuzz_test.go
index fc278d3..226487f 100644
--- a/internal/security/fuzz_test.go
+++ b/internal/security/fuzz_test.go
@@ -72,7 +72,7 @@ func FuzzPathValidation(f *testing.F) {
tmpDir := f.TempDir()
validator, err := NewPath([]string{tmpDir})
if err != nil {
- f.Fatalf("failed to create validator: %v", err)
+ f.Fatalf("creating validator: %v", err)
}
f.Fuzz(func(t *testing.T, input string) {
@@ -132,14 +132,14 @@ func FuzzPathValidationWithSymlinks(f *testing.F) {
tmpDir := t.TempDir()
validator, err := NewPath([]string{tmpDir})
if err != nil {
- t.Skipf("failed to create validator: %v", err)
+ t.Skipf("creating validator: %v", err)
}
// Create a symlink pointing outside allowed directories
linkPath := filepath.Join(tmpDir, linkName)
err = os.Symlink("/etc/passwd", linkPath)
if err != nil {
- t.Skipf("failed to create symlink: %v", err)
+ t.Skipf("creating symlink: %v", err)
}
// Validation should fail because symlink points outside allowed dirs
@@ -208,7 +208,7 @@ func FuzzCommandValidation(f *testing.F) {
f.Fuzz(func(t *testing.T, cmd, args string) {
argSlice := strings.Fields(args)
- err := validator.ValidateCommand(cmd, argSlice)
+ err := validator.Validate(cmd, argSlice)
// Property 1: Commands with shell metacharacters in name must be rejected
shellMetachars := []string{";", "|", "&", "`", "$", "(", ")", "\n", ">", "<"}
@@ -268,3 +268,68 @@ func FuzzCommandValidation(f *testing.F) {
}
})
}
+
+// =============================================================================
+// URL Fuzzing Tests
+// =============================================================================
+
+// FuzzURLValidation tests URL validation against SSRF bypass attempts.
+// Run with: go test -fuzz=FuzzURLValidation -fuzztime=30s ./internal/security/
+func FuzzURLValidation(f *testing.F) {
+ seeds := []string{
+ // Valid public URLs
+ "https://example.com",
+ "http://example.com/path?q=1",
+
+ // Blocked schemes
+ "ftp://example.com",
+ "file:///etc/passwd",
+ "javascript:alert(1)",
+ "gopher://evil.com",
+
+ // Loopback
+ "http://127.0.0.1",
+ "http://127.0.0.1:8080",
+ "http://[::1]",
+
+ // Private IPs
+ "http://10.0.0.1",
+ "http://172.16.0.1",
+ "http://192.168.1.1",
+
+ // Cloud metadata
+ "http://169.254.169.254/latest/meta-data/",
+ "http://metadata.google.internal",
+
+ // Blocked hosts
+ "http://localhost",
+ "http://localhost:3000",
+
+ // Edge cases
+ "",
+ "://",
+ "http://",
+ "http://0.0.0.0",
+ "http://[::ffff:127.0.0.1]",
+
+ // Encoding tricks
+ "http://0x7f000001", // 127.0.0.1 as hex
+ "http://2130706433", // 127.0.0.1 as decimal
+ "http://017700000001", // 127.0.0.1 as octal
+ "http://[::ffff:7f00:1]", // IPv6-mapped IPv4 loopback
+ "http://127.1", // short form loopback
+ "http://0x7f.0.0.1", // partial hex loopback
+ "http://0177.0.0.1", // octal first octet
+ }
+
+ for _, seed := range seeds {
+ f.Add(seed)
+ }
+
+ validator := NewURL()
+
+ f.Fuzz(func(t *testing.T, rawURL string) {
+ // Must not panic
+ _ = validator.Validate(rawURL)
+ })
+}
diff --git a/internal/security/path_test.go b/internal/security/path_test.go
index efbfa50..be8e75f 100644
--- a/internal/security/path_test.go
+++ b/internal/security/path_test.go
@@ -14,18 +14,18 @@ func TestPathValidation(t *testing.T) {
tmpDir := t.TempDir()
workDir, err := os.Getwd()
if err != nil {
- t.Fatalf("failed to get working directory: %v", err)
+ t.Fatalf("getting working directory: %v", err)
}
// Change to temp directory for testing
if err := os.Chdir(tmpDir); err != nil {
- t.Fatalf("failed to change to temp directory: %v", err)
+ t.Fatalf("changing to temp directory: %v", err)
}
defer func() { _ = os.Chdir(workDir) }() // Restore original directory
validator, err := NewPath([]string{tmpDir})
if err != nil {
- t.Fatalf("failed to create path validator: %v", err)
+ t.Fatalf("creating path validator: %v", err)
}
tests := []struct {
@@ -64,10 +64,10 @@ func TestPathValidation(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
_, err := validator.Validate(tt.path)
if tt.shouldErr && err == nil {
- t.Errorf("expected error for %s, but got none: %s", tt.path, tt.reason)
+ t.Errorf("Validate(%q) = nil, want error: %s", tt.path, tt.reason)
}
if !tt.shouldErr && err != nil {
- t.Errorf("unexpected error for %s: %v (%s)", tt.path, err, tt.reason)
+ t.Errorf("Validate(%q) unexpected error: %v (%s)", tt.path, err, tt.reason)
}
})
}
@@ -77,7 +77,7 @@ func TestPathValidation(t *testing.T) {
func TestPathErrorSanitization(t *testing.T) {
validator, err := NewPath([]string{})
if err != nil {
- t.Fatalf("failed to create path validator: %v", err)
+ t.Fatalf("creating path validator: %v", err)
}
// Try to access a path outside allowed directories
@@ -103,24 +103,24 @@ func TestSymlinkValidation(t *testing.T) {
tmpDir := t.TempDir()
workDir, err := os.Getwd()
if err != nil {
- t.Fatalf("failed to get working directory: %v", err)
+ t.Fatalf("getting working directory: %v", err)
}
// Change to temp directory for testing
if err := os.Chdir(tmpDir); err != nil {
- t.Fatalf("failed to change to temp directory: %v", err)
+ t.Fatalf("changing to temp directory: %v", err)
}
defer func() { _ = os.Chdir(workDir) }() // Restore original directory
validator, err := NewPath([]string{tmpDir})
if err != nil {
- t.Fatalf("failed to create path validator: %v", err)
+ t.Fatalf("creating path validator: %v", err)
}
// Create a file
targetFile := filepath.Join(tmpDir, "target.txt")
if err := os.WriteFile(targetFile, []byte("test"), 0o600); err != nil {
- t.Fatalf("failed to create target file: %v", err)
+ t.Fatalf("creating target file: %v", err)
}
// Create a symlink to the file
@@ -141,7 +141,7 @@ func TestSymlinkValidation(t *testing.T) {
expectedPath = targetFile
}
if resolvedPath != expectedPath {
- t.Errorf("expected resolved path %s, got %s", expectedPath, resolvedPath)
+ t.Errorf("Validate() resolved path = %q, want %q", resolvedPath, expectedPath)
}
}
@@ -150,17 +150,17 @@ func TestPathValidationWithNonExistentFile(t *testing.T) {
tmpDir := t.TempDir()
workDir, err := os.Getwd()
if err != nil {
- t.Fatalf("failed to get working directory: %v", err)
+ t.Fatalf("getting working directory: %v", err)
}
if err := os.Chdir(tmpDir); err != nil {
- t.Fatalf("failed to change to temp directory: %v", err)
+ t.Fatalf("changing to temp directory: %v", err)
}
defer func() { _ = os.Chdir(workDir) }()
validator, err := NewPath([]string{tmpDir})
if err != nil {
- t.Fatalf("failed to create path validator: %v", err)
+ t.Fatalf("creating path validator: %v", err)
}
// Test with non-existent file (should be allowed for creating new files)
@@ -170,7 +170,7 @@ func TestPathValidationWithNonExistentFile(t *testing.T) {
t.Errorf("validation of non-existent file failed: %v", err)
}
if validatedPath != nonExistentPath {
- t.Errorf("expected path %s, got %s", nonExistentPath, validatedPath)
+ t.Errorf("Validate(%q) = %q, want %q", nonExistentPath, validatedPath, nonExistentPath)
}
}
@@ -179,24 +179,24 @@ func TestSymlinkBypassAttempt(t *testing.T) {
tmpDir := t.TempDir()
workDir, err := os.Getwd()
if err != nil {
- t.Fatalf("failed to get working directory: %v", err)
+ t.Fatalf("getting working directory: %v", err)
}
// Create another temp directory outside the allowed directory
outsideDir := t.TempDir()
outsideFile := filepath.Join(outsideDir, "secret.txt")
if err := os.WriteFile(outsideFile, []byte("secret data"), 0o600); err != nil {
- t.Fatalf("failed to create outside file: %v", err)
+ t.Fatalf("creating outside file: %v", err)
}
if err := os.Chdir(tmpDir); err != nil {
- t.Fatalf("failed to change to temp directory: %v", err)
+ t.Fatalf("changing to temp directory: %v", err)
}
defer func() { _ = os.Chdir(workDir) }()
validator, err := NewPath([]string{tmpDir})
if err != nil {
- t.Fatalf("failed to create path validator: %v", err)
+ t.Fatalf("creating path validator: %v", err)
}
// Create symlink inside allowed dir pointing to file outside
@@ -212,7 +212,7 @@ func TestSymlinkBypassAttempt(t *testing.T) {
}
if err != nil && !errors.Is(err, ErrSymlinkOutsideAllowed) {
- t.Errorf("expected ErrSymlinkOutsideAllowed, got: %v", err)
+ t.Errorf("Validate(symlink) error = %v, want ErrSymlinkOutsideAllowed", err)
}
}
@@ -221,31 +221,32 @@ func TestPathValidationErrors(t *testing.T) {
tmpDir := t.TempDir()
workDir, err := os.Getwd()
if err != nil {
- t.Fatalf("failed to get working directory: %v", err)
+ t.Fatalf("getting working directory: %v", err)
}
if err := os.Chdir(tmpDir); err != nil {
- t.Fatalf("failed to change to temp directory: %v", err)
+ t.Fatalf("changing to temp directory: %v", err)
}
defer func() { _ = os.Chdir(workDir) }()
validator, err := NewPath([]string{tmpDir})
if err != nil {
- t.Fatalf("failed to create path validator: %v", err)
+ t.Fatalf("creating path validator: %v", err)
}
// Test with extremely long path (should be handled gracefully)
longPath := filepath.Join(tmpDir, string(make([]byte, 1000)))
_, err = validator.Validate(longPath)
- // Should not panic, error is acceptable
- _ = err
+ if err == nil {
+ t.Error("Validate(longPath) = nil, want error for 1000-byte filename")
+ }
}
// BenchmarkPathValidation benchmarks path validation performance
func BenchmarkPathValidation(b *testing.B) {
validator, err := NewPath([]string{})
if err != nil {
- b.Fatalf("failed to create path validator: %v", err)
+ b.Fatalf("creating path validator: %v", err)
}
b.ResetTimer()
diff --git a/internal/security/url.go b/internal/security/url.go
index 3a3b0dc..aef0122 100644
--- a/internal/security/url.go
+++ b/internal/security/url.go
@@ -1,7 +1,3 @@
-// Package security provides security validators for Koopa.
-//
-// URL validator prevents SSRF (Server-Side Request Forgery) attacks by blocking
-// requests to private networks, cloud metadata endpoints, and other dangerous targets.
package security
import (
@@ -173,7 +169,7 @@ func (v *URL) safeDialContext(ctx context.Context, network, addr string) (net.Co
}
conn, dialErr := (&net.Dialer{}).DialContext(ctx, network, addr)
if dialErr != nil {
- return nil, fmt.Errorf("dial failed: %w", dialErr)
+ return nil, fmt.Errorf("dialing: %w", dialErr)
}
return conn, nil
}
@@ -181,7 +177,7 @@ func (v *URL) safeDialContext(ctx context.Context, network, addr string) (net.Co
// Resolve DNS and check all returned IPs
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
if err != nil {
- return nil, fmt.Errorf("DNS lookup failed: %w", err)
+ return nil, fmt.Errorf("resolving DNS: %w", err)
}
// Check all resolved IPs
@@ -200,7 +196,7 @@ func (v *URL) safeDialContext(ctx context.Context, network, addr string) (net.Co
}
conn, err := (&net.Dialer{}).DialContext(ctx, network, targetAddr)
if err != nil {
- return nil, fmt.Errorf("dial to %s failed: %w", targetAddr, err)
+ return nil, fmt.Errorf("dialing %s: %w", targetAddr, err)
}
return conn, nil
}
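
The safeDialContext hunks above keep the resolve-then-vet-then-dial order that defeats DNS rebinding. A standalone sketch of that pattern; blockedIP here is a simplified stand-in for the validator's checkIP, not its actual rule set:

package main

import (
	"context"
	"fmt"
	"net"
)

// blockedIP is a stand-in for checkIP: reject loopback, private, and link-local targets.
func blockedIP(ip net.IP) bool {
	return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast()
}

func safeDial(ctx context.Context, network, addr string) (net.Conn, error) {
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, fmt.Errorf("splitting host/port: %w", err)
	}

	// Resolve first, vet every returned IP, then dial the vetted IP rather than
	// the hostname, so a second lookup cannot swap in a private address.
	ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
	if err != nil {
		return nil, fmt.Errorf("resolving DNS: %w", err)
	}
	for _, ip := range ips {
		if blockedIP(ip) {
			return nil, fmt.Errorf("blocked IP %s for host %s", ip, host)
		}
	}

	return (&net.Dialer{}).DialContext(ctx, network, net.JoinHostPort(ips[0].String(), port))
}

func main() {
	_, err := safeDial(context.Background(), "tcp", "localhost:80")
	fmt.Println(err) // localhost resolves to loopback, so the dial is refused
}
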
diff --git a/internal/security/url_test.go b/internal/security/url_test.go
index af9e650..20455e6 100644
--- a/internal/security/url_test.go
+++ b/internal/security/url_test.go
@@ -2,6 +2,7 @@ package security
import (
"net"
+ "strings"
"testing"
)
@@ -164,7 +165,7 @@ func TestURL_Validate(t *testing.T) {
t.Errorf("Validate(%q) expected error, got nil", tt.url)
return
}
- if tt.errMsg != "" && !urlContains(err.Error(), tt.errMsg) {
+ if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("Validate(%q) error = %q, want error containing %q", tt.url, err.Error(), tt.errMsg)
}
} else if err != nil {
@@ -205,7 +206,7 @@ func TestURL_checkIP(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
ip := net.ParseIP(tt.ip)
if ip == nil {
- t.Fatalf("failed to parse IP: %s", tt.ip)
+ t.Fatalf("parsing IP: %s", tt.ip)
}
err := v.checkIP(ip)
if tt.wantErr && err == nil {
@@ -229,20 +230,32 @@ func TestURL_SafeTransport(t *testing.T) {
if transport.DialContext == nil {
t.Error("SafeTransport() DialContext is nil")
}
-}
-
-// Helper functions
-func urlContains(s, substr string) bool {
- return len(s) >= len(substr) && (s == substr || substr == "" ||
- (s != "" && substr != "" && urlContainsSubstring(s, substr)))
-}
+ // Verify SafeTransport blocks dangerous IPs at the dial level.
+ // This tests DNS-rebinding protection: even if DNS resolves to a blocked IP,
+ // the custom DialContext must reject the connection.
+ tests := []struct {
+ name string
+ addr string
+ wantSub string // expected substring in error message
+ }{
+ {name: "loopback", addr: "127.0.0.1:80", wantSub: "loopback"},
+ {name: "private 10.x", addr: "10.0.0.1:80", wantSub: "private"},
+ {name: "private 192.168.x", addr: "192.168.1.1:80", wantSub: "private"},
+ {name: "link-local metadata", addr: "169.254.169.254:80", wantSub: "link-local"},
+ {name: "IPv6 loopback", addr: "[::1]:80", wantSub: "loopback"},
+ }
-func urlContainsSubstring(s, substr string) bool {
- for i := 0; i <= len(s)-len(substr); i++ {
- if s[i:i+len(substr)] == substr {
- return true
- }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := transport.DialContext(t.Context(), "tcp", tt.addr)
+ if err == nil {
+ t.Errorf("SafeTransport().DialContext(%q) = nil, want error", tt.addr)
+ return
+ }
+ if !strings.Contains(err.Error(), tt.wantSub) {
+ t.Errorf("SafeTransport().DialContext(%q) error = %q, want error containing %q", tt.addr, err.Error(), tt.wantSub)
+ }
+ })
}
- return false
}
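
These dial-level tests cover the second layer of defense; the first is validating the URL string itself. An illustrative end-to-end sketch assuming only the NewURL, Validate, and SafeTransport signatures exercised above (the 30-second timeout is an arbitrary choice here, not a package default):

package main

import (
	"fmt"
	"net/http"
	"time"

	"github.com/koopa0/koopa/internal/security"
)

func main() {
	v := security.NewURL()
	rawURL := "http://169.254.169.254/latest/meta-data/"

	// Layer 1: reject obviously dangerous URLs before any network I/O.
	if err := v.Validate(rawURL); err != nil {
		fmt.Println("rejected:", err)
		return
	}

	// Layer 2: SafeTransport re-checks resolved IPs at dial time, so a hostname
	// that later resolves to a private address is still blocked.
	client := &http.Client{
		Transport: v.SafeTransport(),
		Timeout:   30 * time.Second,
	}
	resp, err := client.Get(rawURL)
	if err != nil {
		fmt.Println("blocked at dial time:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}
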
diff --git a/internal/session/benchmark_test.go b/internal/session/benchmark_test.go
index 42ebff7..bf7f20b 100644
--- a/internal/session/benchmark_test.go
+++ b/internal/session/benchmark_test.go
@@ -12,6 +12,7 @@ import (
"time"
"github.com/firebase/genkit/go/ai"
+ "github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/koopa0/koopa/internal/sqlc"
)
@@ -34,7 +35,7 @@ func BenchmarkStore_GetHistory(b *testing.B) {
for b.Loop() {
_, err := store.History(ctx, sessionID)
if err != nil {
- b.Fatalf("GetHistory failed: %v", err)
+ b.Fatalf("History(): %v", err)
}
}
}
@@ -51,7 +52,7 @@ func BenchmarkStore_GetHistory_SmallSession(b *testing.B) {
b.ResetTimer()
for b.Loop() {
if _, err := store.History(ctx, sessionID); err != nil {
- b.Fatalf("GetHistory failed: %v", err)
+ b.Fatalf("History(): %v", err)
}
}
}
@@ -68,7 +69,7 @@ func BenchmarkStore_GetHistory_LargeSession(b *testing.B) {
b.ResetTimer()
for b.Loop() {
if _, err := store.History(ctx, sessionID); err != nil {
- b.Fatalf("GetHistory failed: %v", err)
+ b.Fatalf("History(): %v", err)
}
}
}
@@ -83,9 +84,9 @@ func BenchmarkStore_AddMessages(b *testing.B) {
store := New(sqlc.New(pool), pool, logger)
// Create a test session
- session, err := store.CreateSession(ctx, "Benchmark-AddMessages", "", "")
+ session, err := store.CreateSession(ctx, "Benchmark-AddMessages")
if err != nil {
- b.Fatalf("Failed to create session: %v", err)
+ b.Fatalf("creating session: %v", err)
}
defer func() { _ = store.DeleteSession(ctx, session.ID) }()
@@ -124,9 +125,9 @@ func BenchmarkStore_AppendMessages(b *testing.B) {
store := New(sqlc.New(pool), pool, logger)
// Create a test session
- session, err := store.CreateSession(ctx, "Benchmark-AppendMessages", "", "")
+ session, err := store.CreateSession(ctx, "Benchmark-AppendMessages")
if err != nil {
- b.Fatalf("Failed to create session: %v", err)
+ b.Fatalf("creating session: %v", err)
}
defer func() { _ = store.DeleteSession(ctx, session.ID) }()
@@ -158,23 +159,21 @@ func BenchmarkStore_CreateSession(b *testing.B) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
store := New(sqlc.New(pool), pool, logger)
- createdSessionIDs := make([]string, 0, b.N)
+ createdSessionIDs := make([]uuid.UUID, 0, b.N)
defer func() {
for _, id := range createdSessionIDs {
- if parsed, err := parseBenchUUID(id); err == nil {
- _ = store.DeleteSession(context.Background(), parsed)
- }
+ _ = store.DeleteSession(context.Background(), id)
}
}()
b.ReportAllocs()
b.ResetTimer()
for i := range b.N {
- session, err := store.CreateSession(ctx, fmt.Sprintf("Benchmark-Session-%d", i), "test-model", "test-prompt")
+ session, err := store.CreateSession(ctx, fmt.Sprintf("Benchmark-Session-%d", i))
if err != nil {
b.Fatalf("CreateSession failed at iteration %d: %v", i, err)
}
- createdSessionIDs = append(createdSessionIDs, session.ID.String())
+ createdSessionIDs = append(createdSessionIDs, session.ID)
}
}
@@ -188,9 +187,9 @@ func BenchmarkStore_GetSession(b *testing.B) {
store := New(sqlc.New(pool), pool, logger)
// Create a test session
- session, err := store.CreateSession(ctx, "Benchmark-GetSession", "", "")
+ session, err := store.CreateSession(ctx, "Benchmark-GetSession")
if err != nil {
- b.Fatalf("Failed to create session: %v", err)
+ b.Fatalf("creating session: %v", err)
}
defer func() { _ = store.DeleteSession(ctx, session.ID) }()
@@ -199,13 +198,13 @@ func BenchmarkStore_GetSession(b *testing.B) {
for b.Loop() {
_, err := store.Session(ctx, session.ID)
if err != nil {
- b.Fatalf("GetSession failed: %v", err)
+ b.Fatalf("Session(): %v", err)
}
}
}
-// BenchmarkStore_ListSessions benchmarks listing sessions.
-func BenchmarkStore_ListSessions(b *testing.B) {
+// BenchmarkStore_Sessions benchmarks listing sessions.
+func BenchmarkStore_Sessions(b *testing.B) {
ctx := context.Background()
pool, cleanup := setupBenchmarkDB(b, ctx)
defer cleanup()
@@ -215,9 +214,9 @@ func BenchmarkStore_ListSessions(b *testing.B) {
// Create some test sessions
for i := 0; i < 20; i++ {
- session, err := store.CreateSession(ctx, fmt.Sprintf("Benchmark-List-%d", i), "", "")
+ session, err := store.CreateSession(ctx, fmt.Sprintf("Benchmark-List-%d", i))
if err != nil {
- b.Fatalf("Failed to create session: %v", err)
+ b.Fatalf("creating session: %v", err)
}
defer func(s *Session) { _ = store.DeleteSession(context.Background(), s.ID) }(session)
}
@@ -225,9 +224,9 @@ func BenchmarkStore_ListSessions(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for b.Loop() {
- _, err := store.ListSessions(ctx, 100, 0)
+ _, err := store.Sessions(ctx, 100, 0)
if err != nil {
- b.Fatalf("ListSessions failed: %v", err)
+ b.Fatalf("Sessions() unexpected error: %v", err)
}
}
}
@@ -243,7 +242,7 @@ func BenchmarkStore_GetMessages(b *testing.B) {
for b.Loop() {
_, err := store.Messages(ctx, session.ID, 100, 0)
if err != nil {
- b.Fatalf("GetMessages failed: %v", err)
+ b.Fatalf("Messages(): %v", err)
}
}
}
@@ -257,10 +256,10 @@ func setupBenchmarkSession(b *testing.B, ctx context.Context, numMessages int) (
store := New(sqlc.New(pool), pool, logger)
// Create a test session
- session, err := store.CreateSession(ctx, "Benchmark-Session", "", "")
+ session, err := store.CreateSession(ctx, "Benchmark-Session")
if err != nil {
cleanup()
- b.Fatalf("Failed to create session: %v", err)
+ b.Fatalf("creating session: %v", err)
}
// Pre-load messages (in batches for efficiency)
@@ -285,7 +284,7 @@ func setupBenchmarkSession(b *testing.B, ctx context.Context, numMessages int) (
if err := store.AddMessages(ctx, session.ID, messages); err != nil {
cleanup()
- b.Fatalf("Failed to add messages: %v", err)
+ b.Fatalf("adding messages: %v", err)
}
}
@@ -311,12 +310,12 @@ func setupBenchmarkDB(b *testing.B, ctx context.Context) (*pgxpool.Pool, func())
pool, err := pgxpool.New(ctx, dbURL)
if err != nil {
- b.Fatalf("Failed to connect to database: %v", err)
+ b.Fatalf("connecting to database: %v", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
- b.Fatalf("Failed to ping database: %v", err)
+ b.Fatalf("pinging database: %v", err)
}
cleanup := func() {
@@ -327,34 +326,3 @@ func setupBenchmarkDB(b *testing.B, ctx context.Context) (*pgxpool.Pool, func())
return pool, cleanup
}
-
-// parseBenchUUID is a helper to parse UUID strings for benchmarks.
-func parseBenchUUID(s string) ([16]byte, error) {
- var u [16]byte
- if len(s) != 36 {
- return u, fmt.Errorf("invalid UUID length")
- }
- // Simple UUID parsing - parse each segment into temporary variables
- var a, b, c, d, e uint64
- if _, err := fmt.Sscanf(s, "%08x-%04x-%04x-%04x-%012x", &a, &b, &c, &d, &e); err != nil {
- return u, err
- }
- // Pack into [16]byte
- u[0] = byte(a >> 24)
- u[1] = byte(a >> 16)
- u[2] = byte(a >> 8)
- u[3] = byte(a)
- u[4] = byte(b >> 8)
- u[5] = byte(b)
- u[6] = byte(c >> 8)
- u[7] = byte(c)
- u[8] = byte(d >> 8)
- u[9] = byte(d)
- u[10] = byte(e >> 40)
- u[11] = byte(e >> 32)
- u[12] = byte(e >> 24)
- u[13] = byte(e >> 16)
- u[14] = byte(e >> 8)
- u[15] = byte(e)
- return u, nil
-}
diff --git a/internal/session/doc.go b/internal/session/doc.go
index 7c8a3a0..17dfeb2 100644
--- a/internal/session/doc.go
+++ b/internal/session/doc.go
@@ -1,225 +1,30 @@
-// Package session provides conversation history persistence.
+// Package session provides conversation history persistence with PostgreSQL.
//
-// The session package manages conversation sessions and messages with PostgreSQL backend.
-// It provides thread-safe storage for conversation history with concurrent message insertion
-// and transaction-safe operations.
+// A session represents a conversation context containing ordered messages
+// exchanged between user and model. The [Store] handles persistence while
+// the agent handles conversation logic.
//
-// # Overview
+// Key operations:
//
-// A session represents a conversation context. Each session can contain multiple messages
-// exchanged between user and model. The package handles persistence while the agent handles
-// the conversation logic.
-//
-// Key responsibilities:
-//
-// - Session lifecycle (create, retrieve, list, delete)
-// - Message persistence with sequential ordering
-// - Transaction-safe batch message insertion
-// - Concurrent access safety
-//
-// # Architecture
-//
-// Session and Message Organization:
-//
-// Session (conversation context)
-// |
-// +-- Metadata (ID, title, model_name, created_at, updated_at)
-// |
-// v
-// Messages (ordered conversation)
-// |
-// +-- Message 1 (role: "user")
-// +-- Message 2 (role: "model")
-// +-- Message 3 (role: "user")
-// +-- ...
-//
-// Messages are ordered by sequence number within a session, ensuring
-// consistent retrieval and reconstruction of conversation context.
-//
-// # Session Management
-//
-// The Store type provides session operations:
-//
-// CreateSession(ctx, title, modelName, systemPrompt) - Create new session
-// Session(ctx, sessionID) - Retrieve session
-// ListSessions(ctx, limit, offset) - List sessions with pagination
-// DeleteSession(ctx, sessionID) - Delete session and messages
-//
-// Sessions store optional metadata:
-//
-// - Title: User-friendly name
-// - ModelName: LLM model used for this session
-// - SystemPrompt: Custom system instructions
-// - MessageCount: Total number of messages (cached for efficiency)
-//
-// # Message Persistence
-//
-// The Store provides message operations:
-//
-// AddMessages(ctx, sessionID, messages) - Batch insert with transaction safety
-// Messages(ctx, sessionID, limit, offset) - Retrieve messages with pagination
-//
-// Messages are stored with:
-//
-// - Role: "user", "model", or "tool"
-// - Content: Array of ai.Part (serialized as JSONB)
-// - SequenceNumber: Sequential ordering within session
-// - CreatedAt: Timestamp
+// - Session lifecycle: [Store.CreateSession], [Store.Session], [Store.Sessions], [Store.DeleteSession], [Store.ResolveCurrentSession]
+// - Message persistence: [Store.AddMessages], [Store.Messages] (transaction-safe batch insertion)
+// - Agent integration: [Store.History], [Store.AppendMessages]
//
// # Transaction Safety
//
-// AddMessages provides ACID guarantees for batch message insertion:
-//
-// 1. Lock session row (SELECT ... FOR UPDATE)
-// 2. Get current max sequence number
-// 3. Insert messages in batch with next sequence numbers
-// 4. Update session metadata (message_count, updated_at)
-// 5. Commit transaction atomically
-//
-// If any step fails, the entire transaction rolls back, ensuring consistency.
-// Session locking prevents race conditions in concurrent scenarios.
-//
-// # Chat Agent Integration
-//
-// The Store provides methods for integration with the Chat agent:
-//
-// History(ctx, sessionID) - Get conversation history for agent
-// AppendMessages(ctx, sessionID, messages) - Persist conversation messages
-//
-// Following Go standard library conventions (similar to database/sql returning *sql.DB),
-// consumers use *session.Store directly instead of defining separate interfaces.
-// Testability is achieved via the internal Querier interface (for mocking database operations).
-//
-// # Database Backend
-//
-// The session store requires PostgreSQL with the following schema:
-//
-// sessions table:
-// id UUID PRIMARY KEY
-// title TEXT
-// model_name TEXT
-// system_prompt TEXT
-// message_count INT32
-// created_at TIMESTAMPTZ
-// updated_at TIMESTAMPTZ
-//
-// session_messages table:
-// id UUID PRIMARY KEY
-// session_id UUID FOREIGN KEY (CASCADE)
-// role TEXT (user|model|tool)
-// content JSONB (ai.Part array)
-// sequence_number INT32
-// created_at TIMESTAMPTZ
-//
-// # Example Usage
-//
-// package main
-//
-// import (
-// "context"
-// "github.com/firebase/genkit/go/ai"
-// "github.com/jackc/pgx/v5/pgxpool"
-// "github.com/koopa0/koopa/internal/session"
-// "log/slog"
-// )
-//
-// func main() {
-// ctx := context.Background()
-//
-// // Connect to PostgreSQL
-// dbPool, _ := pgxpool.New(ctx, "postgresql://...")
-// defer dbPool.Close()
-//
-// // Create session store
-// store := session.New(sqlc.New(dbPool), dbPool, slog.Default())
-//
-// // Create a new session
-// sess, err := store.CreateSession(ctx, "My Conversation", "gemini-pro", "")
-// if err != nil {
-// panic(err)
-// }
-// println("Session ID:", sess.ID)
-//
-// // Add messages to session
-// messages := []*session.Message{
-// {
-// Role: session.RoleUser,
-// Content: []*ai.Part{ai.NewTextPart("Hello!")},
-// },
-// {
-// Role: session.RoleModel,
-// Content: []*ai.Part{ai.NewTextPart("Hi there!")},
-// },
-// }
-//
-// err = store.AddMessages(ctx, sess.ID, messages)
-// if err != nil {
-// panic(err)
-// }
-//
-// // Retrieve messages
-// retrieved, _ := store.Messages(ctx, sess.ID, 100, 0)
-// println("Retrieved", len(retrieved), "messages")
-//
-// // Load history for agent
-// history, _ := store.History(ctx, sess.ID)
-// println("History messages:", len(history.Messages()))
-//
-// // List all sessions
-// sessions, _ := store.ListSessions(ctx, 10, 0)
-// println("Total sessions:", len(sessions))
-// }
-//
-// # Concurrency and Race Conditions
-//
-// The Store is designed to handle concurrent access safely:
-//
-// - Session locking (SELECT ... FOR UPDATE) prevents concurrent modifications
-// - Transactions ensure atomicity of batch operations
-// - PostgreSQL isolation levels handle concurrent reads
-// - No shared state in Go code (all state in database)
-//
-// However, callers should avoid concurrent modifications to the same session
-// to prevent transaction conflicts and retries.
-//
-// # Pagination
-//
-// All list operations support limit/offset pagination:
-//
-// - ListSessions(ctx, limit=10, offset=0) // First 10 sessions
-// - GetMessages(ctx, sessionID, limit=50, offset=0) // First 50 messages
-//
-// Sessions are ordered by updated_at descending (most recent first).
-// Messages are ordered by sequence_number ascending (earliest first).
-//
-// # Error Handling
-//
-// The Store propagates database errors with context:
-//
-// - "failed to create session: ..." - Creation failures
-// - "failed to get session ...: ..." - Retrieval failures
-// - "failed to lock session: ..." - Concurrency issues
-// - "failed to insert message ...: ..." - Message insertion failures
-// - "failed to unmarshal message content: ..." - Deserialization errors (skipped)
-//
-// Malformed messages are skipped during retrieval, allowing resilience
-// to schema changes.
-//
-// # Testing
-//
-// The session package is designed for testability:
+// [Store.AddMessages] uses SELECT ... FOR UPDATE to lock the session row,
+// preventing race conditions on sequence numbers during concurrent writes.
+// If any step fails, the entire transaction rolls back.
//
-// - Store accepts Querier interface for mock database
-// - New() accepts interface, pass mock querier directly for tests
-// - Integration tests use real PostgreSQL database
-// - Supports non-transactional mode for mock testing
+// # Concurrency
//
-// # Thread Safety
+// Store is safe for concurrent use. All state lives in PostgreSQL;
+// no shared Go-side state exists. Session locking and transaction
+// isolation handle concurrent access.
//
-// The Store is thread-safe for concurrent use:
+// # Local State
//
-// - All database operations use connection pool
-// - PostgreSQL handles concurrent access safely
-// - No shared state in Go code
-// - Transactions provide isolation
+// [SaveCurrentSessionID] and [LoadCurrentSessionID] persist the active session
+// to ~/.koopa/current_session using atomic writes (temp file + rename) with
+// file locking via [github.com/gofrs/flock].
package session
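
The Transaction Safety paragraph compresses the AddMessages flow into two sentences. A sketch of that transaction shape against pgx; illustrative only, since the real implementation goes through the sqlc-generated Querier, and the column names here follow the sessions/session_messages schema referenced elsewhere in the package:

package session

import (
	"context"
	"encoding/json"
	"fmt"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgxpool"
)

// addMessagesSketch shows the lock -> read sequence -> insert -> update -> commit
// flow; any failure triggers the deferred rollback.
func addMessagesSketch(ctx context.Context, pool *pgxpool.Pool, sessionID uuid.UUID, msgs []*Message) error {
	tx, err := pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("beginning transaction: %w", err)
	}
	defer func() { _ = tx.Rollback(ctx) }() // no-op once Commit succeeds

	// 1. Lock the session row so concurrent writers serialize on sequence numbers.
	if _, err := tx.Exec(ctx, `SELECT id FROM sessions WHERE id = $1 FOR UPDATE`, sessionID); err != nil {
		return fmt.Errorf("locking session: %w", err)
	}

	// 2. Read the current max sequence number under the lock.
	var seq int32
	if err := tx.QueryRow(ctx,
		`SELECT COALESCE(MAX(sequence_number), 0) FROM session_messages WHERE session_id = $1`,
		sessionID).Scan(&seq); err != nil {
		return fmt.Errorf("reading max sequence: %w", err)
	}

	// 3. Insert the batch with the next sequence numbers.
	for i, m := range msgs {
		content, err := json.Marshal(m.Content)
		if err != nil {
			return fmt.Errorf("marshaling content: %w", err)
		}
		if _, err := tx.Exec(ctx,
			`INSERT INTO session_messages (session_id, role, content, sequence_number)
			 VALUES ($1, $2, $3, $4)`,
			sessionID, m.Role, content, seq+int32(i)+1); err != nil {
			return fmt.Errorf("inserting message: %w", err)
		}
	}

	// 4. Refresh cached metadata, then commit atomically.
	if _, err := tx.Exec(ctx,
		`UPDATE sessions SET message_count = message_count + $1, updated_at = now() WHERE id = $2`,
		int32(len(msgs)), sessionID); err != nil {
		return fmt.Errorf("updating session: %w", err)
	}
	return tx.Commit(ctx)
}
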
diff --git a/internal/session/errors.go b/internal/session/errors.go
deleted file mode 100644
index 94ead7e..0000000
--- a/internal/session/errors.go
+++ /dev/null
@@ -1,56 +0,0 @@
-package session
-
-import "errors"
-
-// Message status constants for streaming lifecycle.
-const (
- // StatusStreaming indicates the message is being streamed (AI generating).
- StatusStreaming = "streaming"
-
- // StatusCompleted indicates the message has been fully generated.
- StatusCompleted = "completed"
-
- // StatusFailed indicates the message generation failed.
- StatusFailed = "failed"
-)
-
-// History limit constants.
-const (
- // DefaultHistoryLimit is the default number of messages to load.
- DefaultHistoryLimit int32 = 100
-
- // MaxHistoryLimit is the absolute maximum to prevent OOM.
- MaxHistoryLimit int32 = 10000
-
- // MinHistoryLimit is the minimum allowed value for history limit.
- MinHistoryLimit int32 = 10
-)
-
-// Sentinel errors for session operations.
-// These errors are part of the Store's public API and should be checked using errors.Is().
-//
-// Example:
-//
-// sess, err := store.Session(ctx, id)
-// if errors.Is(err, session.ErrSessionNotFound) {
-// // Handle missing session
-// }
-//
-// ErrSessionNotFound indicates the requested session does not exist in the database.
-var ErrSessionNotFound = errors.New("session not found")
-
-// NormalizeHistoryLimit normalizes the history limit value.
-// Returns DefaultHistoryLimit for zero/negative values.
-// Clamps to MinHistoryLimit/MaxHistoryLimit as bounds.
-func NormalizeHistoryLimit(limit int32) int32 {
- if limit <= 0 {
- return DefaultHistoryLimit
- }
- if limit < MinHistoryLimit {
- return MinHistoryLimit
- }
- if limit > MaxHistoryLimit {
- return MaxHistoryLimit
- }
- return limit
-}
diff --git a/internal/session/errors_test.go b/internal/session/errors_test.go
deleted file mode 100644
index e9592ce..0000000
--- a/internal/session/errors_test.go
+++ /dev/null
@@ -1,121 +0,0 @@
-package session
-
-import (
- "testing"
-)
-
-func TestNormalizeHistoryLimit(t *testing.T) {
- t.Parallel()
-
- tests := []struct {
- name string
- input int32
- want int32
- }{
- // Default cases
- {"zero defaults", 0, DefaultHistoryLimit},
- {"negative defaults", -1, DefaultHistoryLimit},
- {"large negative defaults", -999, DefaultHistoryLimit},
-
- // Clamping to minimum
- {"below min clamped", MinHistoryLimit - 1, MinHistoryLimit},
- {"exactly min", MinHistoryLimit, MinHistoryLimit},
-
- // Valid middle values
- {"valid 50", 50, 50},
- {"valid 100", 100, 100},
- {"valid 500", 500, 500},
- {"valid 5000", 5000, 5000},
-
- // Clamping to maximum
- {"exactly max", MaxHistoryLimit, MaxHistoryLimit},
- {"above max clamped", MaxHistoryLimit + 1, MaxHistoryLimit},
- {"large above max", MaxHistoryLimit * 2, MaxHistoryLimit},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := NormalizeHistoryLimit(tt.input)
- if got != tt.want {
- t.Errorf("NormalizeHistoryLimit(%d) = %d, want %d", tt.input, got, tt.want)
- }
- })
- }
-}
-
-func TestConstants(t *testing.T) {
- t.Parallel()
-
- t.Run("DefaultHistoryLimit", func(t *testing.T) {
- if DefaultHistoryLimit != 100 {
- t.Errorf("DefaultHistoryLimit = %d, want %d", DefaultHistoryLimit, 100)
- }
- })
-
- t.Run("MaxHistoryLimit", func(t *testing.T) {
- if MaxHistoryLimit != 10000 {
- t.Errorf("MaxHistoryLimit = %d, want %d", MaxHistoryLimit, 10000)
- }
- })
-
- t.Run("MinHistoryLimit", func(t *testing.T) {
- if MinHistoryLimit != 10 {
- t.Errorf("MinHistoryLimit = %d, want %d", MinHistoryLimit, 10)
- }
- })
-
- t.Run("StatusConstants", func(t *testing.T) {
- if StatusStreaming != "streaming" {
- t.Errorf("StatusStreaming = %q, want %q", StatusStreaming, "streaming")
- }
- if StatusCompleted != "completed" {
- t.Errorf("StatusCompleted = %q, want %q", StatusCompleted, "completed")
- }
- if StatusFailed != "failed" {
- t.Errorf("StatusFailed = %q, want %q", StatusFailed, "failed")
- }
- })
-}
-
-// BenchmarkNormalizeHistoryLimit benchmarks limit normalization.
-func BenchmarkNormalizeHistoryLimit(b *testing.B) {
- limits := []int32{0, -1, 50, 100, 10001}
-
- b.ResetTimer()
- for b.Loop() {
- for _, limit := range limits {
- _ = NormalizeHistoryLimit(limit)
- }
- }
-}
-
-// TestNormalizeRole tests the Genkit role normalization function.
-// Genkit uses "model" for AI responses, but we store "assistant" in the database
-// for consistency with the CHECK constraint.
-func TestNormalizeRole(t *testing.T) {
- t.Parallel()
-
- tests := []struct {
- name string
- input string
- expected string
- }{
- {"model to assistant", "model", "assistant"},
- {"user unchanged", "user", "user"},
- {"assistant unchanged", "assistant", "assistant"},
- {"system unchanged", "system", "system"},
- {"tool unchanged", "tool", "tool"},
- {"empty passthrough", "", ""},
- {"unknown passthrough", "unknown", "unknown"},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- t.Parallel()
- got := normalizeRole(tt.input)
- if got != tt.expected {
- t.Errorf("normalizeRole(%q) = %q, want %q", tt.input, got, tt.expected)
- }
- })
- }
-}
diff --git a/internal/session/integration_test.go b/internal/session/integration_test.go
index fbad17e..35593f5 100644
--- a/internal/session/integration_test.go
+++ b/internal/session/integration_test.go
@@ -20,31 +20,22 @@ import (
"github.com/koopa0/koopa/internal/testutil"
)
-// =============================================================================
-// Test Setup Helper
-// =============================================================================
-
// setupIntegrationTest creates a Store with test database connection.
// All integration tests should use this unified setup.
-func setupIntegrationTest(t *testing.T) (*Store, func()) {
+// Cleanup is automatic via tb.Cleanup — no manual cleanup needed.
+func setupIntegrationTest(t *testing.T) *Store {
t.Helper()
- dbContainer, cleanup := testutil.SetupTestDB(t)
- store := New(sqlc.New(dbContainer.Pool), dbContainer.Pool, slog.Default())
- return store, cleanup
+ dbContainer := testutil.SetupTestDB(t)
+ return New(sqlc.New(dbContainer.Pool), dbContainer.Pool, slog.Default())
}
-// =============================================================================
-// Basic CRUD Tests
-// =============================================================================
-
// TestStore_CreateAndGet tests creating and retrieving a session
func TestStore_CreateAndGet(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Create a session
- session, err := store.CreateSession(ctx, "Test Session", "gemini-2.5-flash", "You are a helpful assistant")
+ session, err := store.CreateSession(ctx, "Test Session")
if err != nil {
t.Fatalf("CreateSession() unexpected error: %v", err)
}
@@ -57,12 +48,6 @@ func TestStore_CreateAndGet(t *testing.T) {
if session.Title != "Test Session" {
t.Errorf("CreateSession() Title = %q, want %q", session.Title, "Test Session")
}
- if session.ModelName != "gemini-2.5-flash" {
- t.Errorf("CreateSession() ModelName = %q, want %q", session.ModelName, "gemini-2.5-flash")
- }
- if session.SystemPrompt != "You are a helpful assistant" {
- t.Errorf("CreateSession() SystemPrompt = %q, want %q", session.SystemPrompt, "You are a helpful assistant")
- }
if session.CreatedAt.IsZero() {
t.Error("CreateSession() CreatedAt should be set")
}
@@ -73,33 +58,26 @@ func TestStore_CreateAndGet(t *testing.T) {
// Retrieve the session
retrieved, err := store.Session(ctx, session.ID)
if err != nil {
- t.Fatalf("GetSession(%v) unexpected error: %v", session.ID, err)
+ t.Fatalf("Session(%v) unexpected error: %v", session.ID, err)
}
if retrieved == nil {
- t.Fatal("GetSession() returned nil session")
+ t.Fatal("Session() returned nil session")
}
if retrieved.ID != session.ID {
- t.Errorf("GetSession() ID = %v, want %v", retrieved.ID, session.ID)
+ t.Errorf("Session() ID = %v, want %v", retrieved.ID, session.ID)
}
if retrieved.Title != session.Title {
- t.Errorf("GetSession() Title = %q, want %q", retrieved.Title, session.Title)
- }
- if retrieved.ModelName != session.ModelName {
- t.Errorf("GetSession() ModelName = %q, want %q", retrieved.ModelName, session.ModelName)
- }
- if retrieved.SystemPrompt != session.SystemPrompt {
- t.Errorf("GetSession() SystemPrompt = %q, want %q", retrieved.SystemPrompt, session.SystemPrompt)
+ t.Errorf("Session() Title = %q, want %q", retrieved.Title, session.Title)
}
}
// TestStore_CreateWithEmptyFields tests creating session with empty optional fields
func TestStore_CreateWithEmptyFields(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
- // Create session with empty title and system prompt
- session, err := store.CreateSession(ctx, "", "gemini-2.5-flash", "")
+ // Create session with empty title
+ session, err := store.CreateSession(ctx, "")
if err != nil {
t.Fatalf("CreateSession() with empty fields unexpected error: %v", err)
}
@@ -109,79 +87,65 @@ func TestStore_CreateWithEmptyFields(t *testing.T) {
if session.Title != "" {
t.Errorf("CreateSession() Title = %q, want empty string", session.Title)
}
- if session.ModelName != "gemini-2.5-flash" {
- t.Errorf("CreateSession() ModelName = %q, want %q", session.ModelName, "gemini-2.5-flash")
- }
- if session.SystemPrompt != "" {
- t.Errorf("CreateSession() SystemPrompt = %q, want empty string", session.SystemPrompt)
- }
// Retrieve should work
retrieved, err := store.Session(ctx, session.ID)
if err != nil {
- t.Fatalf("GetSession(%v) unexpected error: %v", session.ID, err)
+ t.Fatalf("Session(%v) unexpected error: %v", session.ID, err)
}
if retrieved.Title != "" {
- t.Errorf("GetSession() Title = %q, want empty string", retrieved.Title)
- }
- if retrieved.SystemPrompt != "" {
- t.Errorf("GetSession() SystemPrompt = %q, want empty string", retrieved.SystemPrompt)
+ t.Errorf("Session() Title = %q, want empty string", retrieved.Title)
}
}
// TestStore_ListSessions tests listing sessions with pagination
func TestStore_ListSessions_Integration(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Create multiple sessions
for i := 0; i < 5; i++ {
- _, err := store.CreateSession(ctx,
- fmt.Sprintf("Session %d", i+1),
- "gemini-2.5-flash",
- "")
+ _, err := store.CreateSession(ctx, fmt.Sprintf("Session %d", i+1))
if err != nil {
t.Fatalf("CreateSession(%d) unexpected error: %v", i+1, err)
}
}
// List all sessions
- sessions, err := store.ListSessions(ctx, 10, 0)
+ sessions, err := store.Sessions(ctx, 10, 0)
if err != nil {
- t.Fatalf("ListSessions(10, 0) unexpected error: %v", err)
+ t.Fatalf("Sessions(10, 0) unexpected error: %v", err)
}
if len(sessions) < 5 {
- t.Errorf("ListSessions(10, 0) returned %d sessions, want at least 5", len(sessions))
+ t.Errorf("Sessions(10, 0) returned %d sessions, want at least 5", len(sessions))
}
// Test pagination - first 3
- sessions, err = store.ListSessions(ctx, 3, 0)
+ sessions, err = store.Sessions(ctx, 3, 0)
if err != nil {
- t.Fatalf("ListSessions(3, 0) unexpected error: %v", err)
+ t.Fatalf("Sessions(3, 0) unexpected error: %v", err)
}
if len(sessions) != 3 {
- t.Errorf("ListSessions(3, 0) returned %d sessions, want exactly 3", len(sessions))
+ t.Errorf("Sessions(3, 0) returned %d sessions, want exactly 3", len(sessions))
}
// Test pagination - next 2
- sessions, err = store.ListSessions(ctx, 3, 3)
+ sessions, err = store.Sessions(ctx, 3, 3)
if err != nil {
- t.Fatalf("ListSessions(3, 3) unexpected error: %v", err)
+ t.Fatalf("Sessions(3, 3) unexpected error: %v", err)
}
if len(sessions) < 2 {
- t.Errorf("ListSessions(3, 3) returned %d sessions, want at least 2", len(sessions))
+ t.Errorf("Sessions(3, 3) returned %d sessions, want at least 2", len(sessions))
}
}
// TestStore_DeleteSession tests deleting a session
func TestStore_DeleteSession_Integration(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Create a session
- session, err := store.CreateSession(ctx, "To Be Deleted", "gemini-2.5-flash", "")
+ session, err := store.CreateSession(ctx, "To Be Deleted")
if err != nil {
t.Fatalf("CreateSession() unexpected error: %v", err)
}
@@ -189,7 +153,7 @@ func TestStore_DeleteSession_Integration(t *testing.T) {
// Verify it exists
_, err = store.Session(ctx, session.ID)
if err != nil {
- t.Fatalf("GetSession(%v) before delete unexpected error: %v", session.ID, err)
+ t.Fatalf("Session(%v) before delete unexpected error: %v", session.ID, err)
}
// Delete the session
@@ -201,22 +165,17 @@ func TestStore_DeleteSession_Integration(t *testing.T) {
// Verify it no longer exists
_, err = store.Session(ctx, session.ID)
if err == nil {
- t.Errorf("GetSession(%v) after delete should return error", session.ID)
+ t.Errorf("Session(%v) after delete should return error", session.ID)
}
}
-// =============================================================================
-// Message Tests
-// =============================================================================
-
// TestStore_AddMessage tests adding messages to a session
func TestStore_AddMessage(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Create a session
- session, err := store.CreateSession(ctx, "Message Test", "gemini-2.5-flash", "")
+ session, err := store.CreateSession(ctx, "Message Test")
if err != nil {
t.Fatalf("CreateSession() unexpected error: %v", err)
}
@@ -250,35 +209,34 @@ func TestStore_AddMessage(t *testing.T) {
// Retrieve messages
messages, err := store.Messages(ctx, session.ID, 10, 0)
if err != nil {
- t.Fatalf("GetMessages(%v, 10, 0) unexpected error: %v", session.ID, err)
+ t.Fatalf("Messages(%v, 10, 0) unexpected error: %v", session.ID, err)
}
if len(messages) != 2 {
- t.Fatalf("GetMessages() returned %d messages, want 2", len(messages))
+ t.Fatalf("Messages() returned %d messages, want 2", len(messages))
}
// Verify order (should be chronological)
if messages[0].Role != string(ai.RoleUser) {
- t.Errorf("GetMessages()[0].Role = %q, want %q", messages[0].Role, string(ai.RoleUser))
+ t.Errorf("Messages()[0].Role = %q, want %q", messages[0].Role, string(ai.RoleUser))
}
if messages[0].Content[0].Text != "Hello, how are you?" {
- t.Errorf("GetMessages()[0].Content[0].Text = %q, want %q", messages[0].Content[0].Text, "Hello, how are you?")
+ t.Errorf("Messages()[0].Content[0].Text = %q, want %q", messages[0].Content[0].Text, "Hello, how are you?")
}
if messages[1].Role != string(ai.RoleModel) {
- t.Errorf("GetMessages()[1].Role = %q, want %q", messages[1].Role, string(ai.RoleModel))
+ t.Errorf("Messages()[1].Role = %q, want %q", messages[1].Role, string(ai.RoleModel))
}
if messages[1].Content[0].Text != "I'm doing well, thank you!" {
- t.Errorf("GetMessages()[1].Content[0].Text = %q, want %q", messages[1].Content[0].Text, "I'm doing well, thank you!")
+ t.Errorf("Messages()[1].Content[0].Text = %q, want %q", messages[1].Content[0].Text, "I'm doing well, thank you!")
}
}
// TestStore_GetMessages tests retrieving messages with pagination
func TestStore_GetMessages_Integration(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Create a session
- session, err := store.CreateSession(ctx, "Pagination Test", "gemini-2.5-flash", "")
+ session, err := store.CreateSession(ctx, "Pagination Test")
if err != nil {
t.Fatalf("CreateSession() unexpected error: %v", err)
}
@@ -302,36 +260,35 @@ func TestStore_GetMessages_Integration(t *testing.T) {
// Get first 5 messages
retrieved, err := store.Messages(ctx, session.ID, 5, 0)
if err != nil {
- t.Fatalf("GetMessages(%v, 5, 0) unexpected error: %v", session.ID, err)
+ t.Fatalf("Messages(%v, 5, 0) unexpected error: %v", session.ID, err)
}
if len(retrieved) != 5 {
- t.Errorf("GetMessages(%v, 5, 0) returned %d messages, want 5", session.ID, len(retrieved))
+ t.Errorf("Messages(%v, 5, 0) returned %d messages, want 5", session.ID, len(retrieved))
}
if retrieved[0].Content[0].Text != "Message 1" {
- t.Errorf("GetMessages(%v, 5, 0)[0].Content[0].Text = %q, want %q", session.ID, retrieved[0].Content[0].Text, "Message 1")
+ t.Errorf("Messages(%v, 5, 0)[0].Content[0].Text = %q, want %q", session.ID, retrieved[0].Content[0].Text, "Message 1")
}
// Get next 5 messages
retrieved, err = store.Messages(ctx, session.ID, 5, 5)
if err != nil {
- t.Fatalf("GetMessages(%v, 5, 5) unexpected error: %v", session.ID, err)
+ t.Fatalf("Messages(%v, 5, 5) unexpected error: %v", session.ID, err)
}
if len(retrieved) != 5 {
- t.Errorf("GetMessages(%v, 5, 5) returned %d messages, want 5", session.ID, len(retrieved))
+ t.Errorf("Messages(%v, 5, 5) returned %d messages, want 5", session.ID, len(retrieved))
}
if retrieved[0].Content[0].Text != "Message 6" {
- t.Errorf("GetMessages(%v, 5, 5)[0].Content[0].Text = %q, want %q", session.ID, retrieved[0].Content[0].Text, "Message 6")
+ t.Errorf("Messages(%v, 5, 5)[0].Content[0].Text = %q, want %q", session.ID, retrieved[0].Content[0].Text, "Message 6")
}
}
// TestStore_MessageOrdering tests that messages maintain chronological order
func TestStore_MessageOrdering(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Create a session
- session, err := store.CreateSession(ctx, "Ordering Test", "gemini-2.5-flash", "")
+ session, err := store.CreateSession(ctx, "Ordering Test")
if err != nil {
t.Fatalf("CreateSession() unexpected error: %v", err)
}
@@ -361,10 +318,10 @@ func TestStore_MessageOrdering(t *testing.T) {
// Retrieve all messages
retrieved, err := store.Messages(ctx, session.ID, 100, 0)
if err != nil {
- t.Fatalf("GetMessages(%v, 100, 0) unexpected error: %v", session.ID, err)
+ t.Fatalf("Messages(%v, 100, 0) unexpected error: %v", session.ID, err)
}
if len(retrieved) != 6 {
- t.Fatalf("GetMessages() returned %d messages, want 6", len(retrieved))
+ t.Fatalf("Messages() returned %d messages, want 6", len(retrieved))
}
// Verify order
@@ -392,12 +349,11 @@ func TestStore_MessageOrdering(t *testing.T) {
// TestStore_LargeMessageContent tests handling large message content
func TestStore_LargeMessageContent(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Create a session
- session, err := store.CreateSession(ctx, "Large Content Test", "gemini-2.5-flash", "")
+ session, err := store.CreateSession(ctx, "Large Content Test")
if err != nil {
t.Fatalf("CreateSession() unexpected error: %v", err)
}
@@ -420,24 +376,23 @@ func TestStore_LargeMessageContent(t *testing.T) {
// Retrieve and verify
messages, err := store.Messages(ctx, session.ID, 10, 0)
if err != nil {
- t.Fatalf("GetMessages(%v, 10, 0) unexpected error: %v", session.ID, err)
+ t.Fatalf("Messages(%v, 10, 0) unexpected error: %v", session.ID, err)
}
if len(messages) != 1 {
- t.Fatalf("GetMessages() returned %d messages, want 1", len(messages))
+ t.Fatalf("Messages() returned %d messages, want 1", len(messages))
}
if messages[0].Content[0].Text != largeText {
- t.Errorf("GetMessages()[0].Content[0].Text length = %d, want %d (large content not preserved)", len(messages[0].Content[0].Text), len(largeText))
+ t.Errorf("Messages()[0].Content[0].Text length = %d, want %d (large content not preserved)", len(messages[0].Content[0].Text), len(largeText))
}
}
// TestStore_DeleteSessionWithMessages tests that deleting a session also deletes messages
func TestStore_DeleteSessionWithMessages(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Create a session with messages
- session, err := store.CreateSession(ctx, "Cascade Delete Test", "gemini-2.5-flash", "")
+ session, err := store.CreateSession(ctx, "Cascade Delete Test")
if err != nil {
t.Fatalf("CreateSession() unexpected error: %v", err)
}
@@ -456,10 +411,10 @@ func TestStore_DeleteSessionWithMessages(t *testing.T) {
// Verify message exists
retrieved, err := store.Messages(ctx, session.ID, 10, 0)
if err != nil {
- t.Fatalf("GetMessages(%v, 10, 0) before delete unexpected error: %v", session.ID, err)
+ t.Fatalf("Messages(%v, 10, 0) before delete unexpected error: %v", session.ID, err)
}
if len(retrieved) != 1 {
- t.Errorf("GetMessages() before delete returned %d messages, want 1", len(retrieved))
+ t.Errorf("Messages() before delete returned %d messages, want 1", len(retrieved))
}
// Delete session
@@ -471,24 +426,19 @@ func TestStore_DeleteSessionWithMessages(t *testing.T) {
// Verify session is deleted
_, err = store.Session(ctx, session.ID)
if err == nil {
- t.Error("GetSession() after delete should return error")
+ t.Error("Session() after delete should return error")
}
}
-// =============================================================================
-// Race Condition Tests
-// =============================================================================
-
// TestStore_ConcurrentSessionCreation tests that multiple goroutines can create
// sessions simultaneously without data corruption or race conditions.
func TestStore_ConcurrentSessionCreation(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
const numGoroutines = 10
var wg sync.WaitGroup
- errors := make(chan error, numGoroutines)
+ errs := make(chan error, numGoroutines)
sessionIDs := make(chan string, numGoroutines)
// Create sessions concurrently
@@ -497,9 +447,9 @@ func TestStore_ConcurrentSessionCreation(t *testing.T) {
go func(id int) {
defer wg.Done()
title := fmt.Sprintf("Race-Session-%d", id)
- session, err := store.CreateSession(ctx, title, "test-model", "test-prompt")
+ session, err := store.CreateSession(ctx, title)
if err != nil {
- errors <- fmt.Errorf("goroutine %d: %w", id, err)
+ errs <- fmt.Errorf("goroutine %d: %w", id, err)
return
}
sessionIDs <- session.ID.String()
@@ -507,12 +457,12 @@ func TestStore_ConcurrentSessionCreation(t *testing.T) {
}
wg.Wait()
- close(errors)
+ close(errs)
close(sessionIDs)
// Check for errors
var errCount int
- for err := range errors {
+ for err := range errs {
t.Errorf("concurrent creation error: %v", err)
errCount++
}
@@ -530,18 +480,16 @@ func TestStore_ConcurrentSessionCreation(t *testing.T) {
if len(ids) != expectedCount {
t.Errorf("created %d unique sessions, want %d", len(ids), expectedCount)
}
- t.Logf("Successfully created %d sessions concurrently", len(ids))
}
// TestStore_ConcurrentHistoryUpdate tests that multiple goroutines can add
// messages to the same session without data corruption.
func TestStore_ConcurrentHistoryUpdate(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Create a test session
- session, err := store.CreateSession(ctx, "Race-History-Test", "", "")
+ session, err := store.CreateSession(ctx, "Race-History-Test")
if err != nil {
t.Fatalf("CreateSession() unexpected error: %v", err)
}
@@ -549,7 +497,7 @@ func TestStore_ConcurrentHistoryUpdate(t *testing.T) {
const numGoroutines = 10
const messagesPerGoroutine = 5
var wg sync.WaitGroup
- errors := make(chan error, numGoroutines)
+ errs := make(chan error, numGoroutines)
var successCount atomic.Int32
// Add messages concurrently
@@ -567,7 +515,7 @@ func TestStore_ConcurrentHistoryUpdate(t *testing.T) {
}
if err := store.AddMessages(ctx, session.ID, messages); err != nil {
- errors <- fmt.Errorf("goroutine %d: %w", goroutineID, err)
+ errs <- fmt.Errorf("goroutine %d: %w", goroutineID, err)
return
}
successCount.Add(1)
@@ -575,22 +523,22 @@ func TestStore_ConcurrentHistoryUpdate(t *testing.T) {
}
wg.Wait()
- close(errors)
+ close(errs)
// Check for errors
- for err := range errors {
+ for err := range errs {
t.Errorf("concurrent history update error: %v", err)
}
// Verify message integrity
allMessages, err := store.Messages(ctx, session.ID, 1000, 0)
if err != nil {
- t.Fatalf("GetMessages(%v, 1000, 0) unexpected error: %v", session.ID, err)
+ t.Fatalf("Messages(%v, 1000, 0) unexpected error: %v", session.ID, err)
}
expectedCount := int(successCount.Load()) * messagesPerGoroutine
if len(allMessages) != expectedCount {
- t.Errorf("GetMessages() returned %d messages, want %d", len(allMessages), expectedCount)
+ t.Errorf("Messages() returned %d messages, want %d", len(allMessages), expectedCount)
}
// Verify sequence numbers are unique and sequential
@@ -608,19 +556,16 @@ func TestStore_ConcurrentHistoryUpdate(t *testing.T) {
t.Errorf("missing sequence number: %d", i)
}
}
-
- t.Logf("Successfully added %d messages from %d goroutines", len(allMessages), successCount.Load())
}
// TestStore_ConcurrentLoadAndSaveHistory tests simultaneous load and save
// operations on the same session.
func TestStore_ConcurrentLoadAndSaveHistory(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Create a test session with initial messages
- session, err := store.CreateSession(ctx, "Race-LoadSave-Test", "", "")
+ session, err := store.CreateSession(ctx, "Race-LoadSave-Test")
if err != nil {
t.Fatalf("CreateSession() unexpected error: %v", err)
}
@@ -653,9 +598,9 @@ func TestStore_ConcurrentLoadAndSaveHistory(t *testing.T) {
return
}
// Verify we got at least the initial messages
- if len(history.Messages()) < 2 {
+ if len(history) < 2 {
loadErrors <- fmt.Errorf("load goroutine %d: expected at least 2 messages, got %d",
- id, len(history.Messages()))
+ id, len(history))
}
}(i)
@@ -687,22 +632,19 @@ func TestStore_ConcurrentLoadAndSaveHistory(t *testing.T) {
// Verify final state
finalHistory, err := store.History(ctx, sessionID)
if err != nil {
- t.Fatalf("GetHistory(%v) final state unexpected error: %v", sessionID, err)
+ t.Fatalf("History(%v) final state unexpected error: %v", sessionID, err)
}
// Should have at least initial messages + some concurrent messages
- if len(finalHistory.Messages()) < 2 {
- t.Errorf("GetHistory() final state returned %d messages, want at least 2", len(finalHistory.Messages()))
+ if len(finalHistory) < 2 {
+ t.Errorf("History() final state returned %d messages, want at least 2", len(finalHistory))
}
-
- t.Logf("Final history has %d messages after concurrent load/save", len(finalHistory.Messages()))
}
// TestStore_ConcurrentSessionDeletion tests that deleting sessions while
// other operations are in progress doesn't cause crashes or data corruption.
func TestStore_ConcurrentSessionDeletion(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
const numSessions = 5
@@ -710,7 +652,7 @@ func TestStore_ConcurrentSessionDeletion(t *testing.T) {
// Create test sessions
for i := 0; i < numSessions; i++ {
- session, err := store.CreateSession(ctx, fmt.Sprintf("Race-Delete-Test-%d", i), "", "")
+ session, err := store.CreateSession(ctx, fmt.Sprintf("Race-Delete-Test-%d", i))
if err != nil {
t.Fatalf("CreateSession() for session %d unexpected error: %v", i, err)
}
@@ -744,7 +686,7 @@ func TestStore_ConcurrentSessionDeletion(t *testing.T) {
// List goroutine
go func() {
defer wg.Done()
- _, _ = store.ListSessions(ctx, 100, 0)
+ _, _ = store.Sessions(ctx, 100, 0)
}()
// Get goroutine
@@ -757,9 +699,9 @@ func TestStore_ConcurrentSessionDeletion(t *testing.T) {
wg.Wait()
// Verify all sessions are deleted
- remaining, err := store.ListSessions(ctx, 100, 0)
+ remaining, err := store.Sessions(ctx, 100, 0)
if err != nil {
- t.Fatalf("ListSessions(100, 0) after concurrent deletion unexpected error: %v", err)
+ t.Fatalf("Sessions(100, 0) after concurrent deletion unexpected error: %v", err)
}
for _, session := range sessions {
@@ -775,8 +717,6 @@ func TestStore_ConcurrentSessionDeletion(t *testing.T) {
_ = store.DeleteSession(ctx, session.ID)
}
}
-
- t.Log("Concurrent deletion test completed without crashes")
}
// TestStore_RaceDetector is a comprehensive test designed to trigger
@@ -784,12 +724,11 @@ func TestStore_ConcurrentSessionDeletion(t *testing.T) {
//
// Run with: go test -race -tags=integration ./internal/session/...
func TestStore_RaceDetector(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Create a shared session
- session, err := store.CreateSession(ctx, "Race-Detector-Test", "", "")
+ session, err := store.CreateSession(ctx, "Race-Detector-Test")
if err != nil {
t.Fatalf("CreateSession() unexpected error: %v", err)
}
@@ -833,20 +772,18 @@ func TestStore_RaceDetector(t *testing.T) {
}
wg.Wait()
- t.Log("Race detector test completed - if no race detected, Store is thread-safe")
}
// TestStore_ConcurrentWrites tests concurrent writes to different sessions
func TestStore_ConcurrentWrites(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Create multiple sessions
numSessions := 5
sessions := make([]*Session, numSessions)
for i := 0; i < numSessions; i++ {
- session, err := store.CreateSession(ctx, fmt.Sprintf("Concurrent Session %d", i+1), "gemini-2.5-flash", "")
+ session, err := store.CreateSession(ctx, fmt.Sprintf("Concurrent Session %d", i+1))
if err != nil {
t.Fatalf("CreateSession() for session %d unexpected error: %v", i+1, err)
}
@@ -855,7 +792,7 @@ func TestStore_ConcurrentWrites(t *testing.T) {
// Concurrently write messages to different sessions
var wg sync.WaitGroup
- errors := make(chan error, numSessions*10)
+ errs := make(chan error, numSessions*10)
for i := 0; i < numSessions; i++ {
sessionID := sessions[i].ID
@@ -872,17 +809,17 @@ func TestStore_ConcurrentWrites(t *testing.T) {
}
if err := store.AddMessages(ctx, sid, []*Message{message}); err != nil {
- errors <- err
+ errs <- err
}
}
}(sessionID, i)
}
wg.Wait()
- close(errors)
+ close(errs)
// Check for errors
- for err := range errors {
+ for err := range errs {
t.Errorf("Concurrent write error: %v", err)
}
@@ -890,7 +827,7 @@ func TestStore_ConcurrentWrites(t *testing.T) {
for i, session := range sessions {
messages, err := store.Messages(ctx, session.ID, 100, 0)
if err != nil {
- t.Fatalf("GetMessages(%v, 100, 0) for session %d unexpected error: %v", session.ID, i+1, err)
+ t.Fatalf("Messages(%v, 100, 0) for session %d unexpected error: %v", session.ID, i+1, err)
}
if len(messages) != 10 {
t.Errorf("Session %d has %d messages, want 10", i+1, len(messages))
@@ -898,28 +835,23 @@ func TestStore_ConcurrentWrites(t *testing.T) {
}
}
-// =============================================================================
-// SQL Injection Prevention Tests
-// =============================================================================
-
// TestStore_SQLInjectionPrevention verifies that SQL injection attacks are blocked.
// Session store uses sqlc parameterized queries which should prevent all injection.
func TestStore_SQLInjectionPrevention(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// First, create a legitimate session
- legitSession, err := store.CreateSession(ctx, "Legitimate Session", "", "")
+ legitSession, err := store.CreateSession(ctx, "Legitimate Session")
if err != nil {
t.Fatalf("CreateSession() for legitimate session unexpected error: %v", err)
}
t.Logf("Created legitimate session: %s", legitSession.ID)
// Count sessions before attacks
- sessions, err := store.ListSessions(ctx, 100, 0)
+ sessions, err := store.Sessions(ctx, 100, 0)
if err != nil {
- t.Fatalf("ListSessions(100, 0) before attacks unexpected error: %v", err)
+ t.Fatalf("Sessions(100, 0) before attacks unexpected error: %v", err)
}
t.Logf("Initial session count: %d", len(sessions))
@@ -957,7 +889,7 @@ func TestStore_SQLInjectionPrevention(t *testing.T) {
for _, tc := range maliciousTitles {
t.Run("title_"+tc.name, func(t *testing.T) {
// Attempt SQL injection via session title
- session, err := store.CreateSession(ctx, tc.title, "", "")
+ session, err := store.CreateSession(ctx, tc.title)
// Should either succeed (with escaped title) or fail safely
if err != nil {
@@ -971,42 +903,12 @@ func TestStore_SQLInjectionPrevention(t *testing.T) {
})
}
- // SQL injection via model name
- maliciousModels := []string{
- "'; DROP TABLE sessions; --",
- "model' UNION SELECT password FROM users --",
- }
-
- for i, model := range maliciousModels {
- t.Run("model_"+string(rune('a'+i)), func(t *testing.T) {
- session, err := store.CreateSession(ctx, "Test", model, "")
- if err == nil {
- _ = store.DeleteSession(ctx, session.ID)
- }
- })
- }
-
- // SQL injection via system prompt
- maliciousPrompts := []string{
- "'; DELETE FROM session_messages; --",
- "You are helpful'); DROP TABLE sessions; --",
- }
-
- for i, prompt := range maliciousPrompts {
- t.Run("prompt_"+string(rune('a'+i)), func(t *testing.T) {
- session, err := store.CreateSession(ctx, "Test", "", prompt)
- if err == nil {
- _ = store.DeleteSession(ctx, session.ID)
- }
- })
- }
-
// Verify database integrity
t.Run("verify database integrity", func(t *testing.T) {
// Sessions table should still exist
- sessions, err := store.ListSessions(ctx, 100, 0)
+ sessions, err := store.Sessions(ctx, 100, 0)
if err != nil {
- t.Fatalf("ListSessions(100, 0) after attacks unexpected error: %v (sessions table should still exist)", err)
+ t.Fatalf("Sessions(100, 0) after attacks unexpected error: %v (sessions table should still exist)", err)
}
// Legitimate session should still exist
@@ -1024,22 +926,21 @@ func TestStore_SQLInjectionPrevention(t *testing.T) {
// Should be able to load the session
loaded, err := store.Session(ctx, legitSession.ID)
if err != nil {
- t.Fatalf("GetSession(%v) after attacks unexpected error: %v", legitSession.ID, err)
+ t.Fatalf("Session(%v) after attacks unexpected error: %v", legitSession.ID, err)
}
if loaded.Title != "Legitimate Session" {
- t.Errorf("GetSession(%v) Title = %q, want %q", legitSession.ID, loaded.Title, "Legitimate Session")
+ t.Errorf("Session(%v) Title = %q, want %q", legitSession.ID, loaded.Title, "Legitimate Session")
}
})
}
// TestStore_SQLInjectionViaSessionID tests injection through session IDs.
func TestStore_SQLInjectionViaSessionID(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Create a test session
- session, err := store.CreateSession(ctx, "Test Session", "", "")
+ session, err := store.CreateSession(ctx, "Test Session")
if err != nil {
t.Fatalf("CreateSession() unexpected error: %v", err)
}
@@ -1072,82 +973,70 @@ func TestStore_SQLInjectionViaSessionID(t *testing.T) {
// Verify the test session still exists
loaded, err := store.Session(ctx, session.ID)
if err != nil {
- t.Fatalf("GetSession(%v) after malicious ID attempts unexpected error: %v", session.ID, err)
+ t.Fatalf("Session(%v) after malicious ID attempts unexpected error: %v", session.ID, err)
}
if loaded.Title != "Test Session" {
- t.Errorf("GetSession(%v) Title = %q, want %q", session.ID, loaded.Title, "Test Session")
+ t.Errorf("Session(%v) Title = %q, want %q", session.ID, loaded.Title, "Test Session")
}
}
-// =============================================================================
-// Error Handling Tests
-// =============================================================================
-
-// TestStore_GetHistory_SessionNotFound verifies that GetHistory returns ErrSessionNotFound
-// sentinel error when the session doesn't exist. This test validates the A3 fix from
-// Proposal 056 - proper sentinel error propagation without double-wrapping.
+// TestStore_GetHistory_SessionNotFound verifies that History returns ErrNotFound
+// sentinel error when the session doesn't exist.
func TestStore_GetHistory_SessionNotFound(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Use a non-existent session ID
nonExistentID := uuid.New()
- // GetHistory should return ErrSessionNotFound
+ // History should return ErrNotFound
_, err := store.History(ctx, nonExistentID)
if err == nil {
- t.Fatal("GetHistory() with non-existent session should return error")
+ t.Fatal("History() with non-existent session should return error")
}
- // Verify the error is the sentinel ErrSessionNotFound (errors.Is check)
- if !errors.Is(err, ErrSessionNotFound) {
- t.Errorf("GetHistory(%v) error = %v (type: %T), want ErrSessionNotFound sentinel", nonExistentID, err, err)
+ // Verify the error is the sentinel ErrNotFound (errors.Is check)
+ if !errors.Is(err, ErrNotFound) {
+ t.Errorf("History(%v) error = %v (type: %T), want ErrNotFound sentinel", nonExistentID, err, err)
}
// Verify error message is not double-wrapped
errStr := err.Error()
if strings.Contains(errStr, "session not found: session not found") {
- t.Errorf("GetHistory(%v) error message is double-wrapped: %v", nonExistentID, err)
+ t.Errorf("History(%v) error message is double-wrapped: %v", nonExistentID, err)
}
}
-// TestStore_GetSession_NotFound verifies GetSession returns ErrSessionNotFound sentinel.
+// TestStore_GetSession_NotFound verifies Session returns ErrNotFound sentinel.
func TestStore_GetSession_NotFound(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
nonExistentID := uuid.New()
_, err := store.Session(ctx, nonExistentID)
if err == nil {
- t.Fatal("GetSession() with non-existent session should return error")
+ t.Fatal("Session() with non-existent session should return error")
}
- if !errors.Is(err, ErrSessionNotFound) {
- t.Errorf("GetSession(%v) error = %v, want ErrSessionNotFound sentinel", nonExistentID, err)
+ if !errors.Is(err, ErrNotFound) {
+ t.Errorf("Session(%v) error = %v, want ErrNotFound sentinel", nonExistentID, err)
}
}
-// =============================================================================
-// SQL Injection Prevention Tests
-// =============================================================================
-
// TestStore_SQLInjectionViaMessageContent tests injection through message content.
func TestStore_SQLInjectionViaMessageContent(t *testing.T) {
- store, cleanup := setupIntegrationTest(t)
- defer cleanup()
+ store := setupIntegrationTest(t)
ctx := context.Background()
// Create a test session
- session, err := store.CreateSession(ctx, "Message Test", "", "")
+ session, err := store.CreateSession(ctx, "Message Test")
if err != nil {
t.Fatalf("CreateSession() unexpected error: %v", err)
}
// Malicious message content
maliciousMessages := []string{
- "'; DROP TABLE session_messages; --",
+ "'; DROP TABLE messages; --",
"Hello'); DELETE FROM sessions WHERE '1'='1",
"Test' UNION SELECT password FROM users --",
"Message\x00'; DROP TABLE sessions; --",
@@ -1174,15 +1063,15 @@ func TestStore_SQLInjectionViaMessageContent(t *testing.T) {
// Session should still exist
_, err := store.Session(ctx, session.ID)
if err != nil {
- t.Fatalf("GetSession(%v) after malicious message attempts unexpected error: %v (session should still exist)", session.ID, err)
+ t.Fatalf("Session(%v) after malicious message attempts unexpected error: %v (session should still exist)", session.ID, err)
}
// Should be able to load messages
sessionID := session.ID
history, err := store.History(ctx, sessionID)
if err != nil {
- t.Fatalf("GetHistory(%v) after malicious message attempts unexpected error: %v (should be able to load history)", sessionID, err)
+ t.Fatalf("History(%v) after malicious message attempts unexpected error: %v (should be able to load history)", sessionID, err)
}
- t.Logf("loaded history with %d messages", len(history.Messages()))
+ t.Logf("loaded history with %d messages", len(history))
})
}
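
For context: every test above replaces the old two-value setup (store, cleanup := setupIntegrationTest(t) followed by defer cleanup()) with a single store := setupIntegrationTest(t), which implies the helper now registers its own teardown. A minimal sketch of such a helper, assuming a pgx pool, a DSN taken from an environment variable, and a discarded logger (none of which are shown in this patch):

package session

import (
	"context"
	"io"
	"log/slog"
	"os"
	"testing"

	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/koopa0/koopa/internal/sqlc"
)

// setupIntegrationTest is a sketch of a helper that registers its own teardown
// via t.Cleanup instead of returning a cleanup func. The env var name and the
// discarded logger are assumptions, not part of this patch.
func setupIntegrationTest(t *testing.T) *Store {
	t.Helper()

	pool, err := pgxpool.New(context.Background(), os.Getenv("KOOPA_TEST_DATABASE_URL"))
	if err != nil {
		t.Fatalf("connecting to test database: %v", err)
	}
	// t.Cleanup runs even when the test later calls t.Fatal, which is why the
	// callers above no longer need a defer cleanup().
	t.Cleanup(pool.Close)

	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	return New(sqlc.New(pool), pool, logger)
}
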
diff --git a/internal/session/session.go b/internal/session/session.go
new file mode 100644
index 0000000..4975c54
--- /dev/null
+++ b/internal/session/session.go
@@ -0,0 +1,31 @@
+package session
+
+import (
+ "errors"
+ "time"
+
+ "github.com/firebase/genkit/go/ai"
+ "github.com/google/uuid"
+)
+
+// ErrNotFound indicates the requested session does not exist in the database.
+var ErrNotFound = errors.New("session not found")
+
+// Session represents a conversation session (application-level type).
+type Session struct {
+ ID uuid.UUID
+ Title string
+ CreatedAt time.Time
+ UpdatedAt time.Time
+}
+
+// Message represents a single conversation message (application-level type).
+// Content field stores Genkit's ai.Part slice, serialized as JSONB in database.
+type Message struct {
+ ID uuid.UUID
+ SessionID uuid.UUID
+ Role string // "user" | "assistant" | "system" | "tool"
+ Content []*ai.Part // Genkit Part slice (stored as JSONB)
+ SequenceNumber int
+ CreatedAt time.Time
+}
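
For context: ErrNotFound is returned unwrapped by Store.Session (see the store.go hunk further down), so callers are expected to branch with errors.Is rather than by matching error strings. A minimal caller-side sketch; loadOrCreate is a hypothetical helper name, while the store methods are the ones added in this patch:

package chatexample // hypothetical consumer of internal/session

import (
	"context"
	"errors"
	"fmt"

	"github.com/google/uuid"

	"github.com/koopa0/koopa/internal/session"
)

// loadOrCreate returns the stored session for id, or a fresh one when the
// ErrNotFound sentinel is reported (e.g. the state file points at a deleted row).
func loadOrCreate(ctx context.Context, store *session.Store, id uuid.UUID) (*session.Session, error) {
	sess, err := store.Session(ctx, id)
	if errors.Is(err, session.ErrNotFound) {
		return store.CreateSession(ctx, "")
	}
	if err != nil {
		return nil, fmt.Errorf("loading session %s: %w", id, err)
	}
	return sess, nil
}
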
diff --git a/internal/session/state.go b/internal/session/state.go
index 91a1876..db30f7f 100644
--- a/internal/session/state.go
+++ b/internal/session/state.go
@@ -64,7 +64,7 @@ func getStateFilePath() (string, error) {
// LoadCurrentSessionID loads the currently active session ID from local state file.
//
-// Acquires shared file lock to allow concurrent reads but prevent writes during read.
+// Acquires exclusive file lock to prevent concurrent access during read.
//
// Returns:
// - *uuid.UUID: Current session ID (nil if no current session)
@@ -146,9 +146,6 @@ func SaveCurrentSessionID(sessionID uuid.UUID) error {
return fmt.Errorf("saving session: %w", err)
}
- // Clean up any orphaned temp files from previous crashed sessions
- _ = cleanupOrphanedTempFiles()
-
// Acquire file lock to prevent concurrent access
lock, err := acquireStateLock()
if err != nil {
@@ -156,6 +153,9 @@ func SaveCurrentSessionID(sessionID uuid.UUID) error {
}
defer func() { _ = lock.Unlock() }()
+ // Clean up any orphaned temp files from previous crashed sessions (under lock)
+ _ = cleanupOrphanedTempFiles()
+
// Write to temporary file first (atomic write pattern)
tmpFile := filePath + ".tmp"
if err := os.WriteFile(tmpFile, []byte(sessionID.String()), 0o600); err != nil {
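
For context: the hunk above stops right after the temp-file write, but the atomic write pattern it names implies a rename step that simply falls outside the diff context. The presumed tail of SaveCurrentSessionID, with the rename call and error wording being assumptions:

	// Presumed continuation of SaveCurrentSessionID: rename the temp file over
	// the real state file so a crash mid-write never leaves a truncated ID behind.
	if err := os.Rename(tmpFile, filePath); err != nil {
		_ = os.Remove(tmpFile) // best-effort cleanup of the temp file on failure
		return fmt.Errorf("replacing state file: %w", err)
	}
	return nil
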
diff --git a/internal/session/state_test.go b/internal/session/state_test.go
index 6a46db1..8007259 100644
--- a/internal/session/state_test.go
+++ b/internal/session/state_test.go
@@ -26,19 +26,19 @@ func TestGetStateFilePath(t *testing.T) {
// Verify path is absolute
if !filepath.IsAbs(path) {
- t.Errorf("getStateFilePath() returned relative path: %s", path)
+ t.Errorf("getStateFilePath() returned relative path: %q", path)
}
// Verify path uses temp directory
rel, err := filepath.Rel(tempDir, path)
if err != nil || strings.HasPrefix(rel, "..") {
- t.Errorf("getStateFilePath() = %s, expected to be within %s", path, tempDir)
+ t.Errorf("getStateFilePath() = %q, want within %q", path, tempDir)
}
// Verify directory was created
dir := filepath.Dir(path)
if _, err := os.Stat(dir); os.IsNotExist(err) {
- t.Errorf("getStateFilePath() did not create directory: %s", dir)
+ t.Errorf("getStateFilePath() did not create directory: %q", dir)
}
}
@@ -64,7 +64,6 @@ func TestSaveAndLoadCurrentSessionID(t *testing.T) {
if loadedID == nil {
t.Fatal("LoadCurrentSessionID() returned nil")
- return
}
if *loadedID != testID {
@@ -110,7 +109,6 @@ func TestSaveAndLoadCurrentSessionID(t *testing.T) {
if loadedID == nil {
t.Fatal("LoadCurrentSessionID() returned nil")
- return
}
if *loadedID != secondID {
diff --git a/internal/session/store.go b/internal/session/store.go
index a48b28b..edb13e2 100644
--- a/internal/session/store.go
+++ b/internal/session/store.go
@@ -12,6 +12,7 @@ import (
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
+ "github.com/koopa0/koopa/internal/config"
"github.com/koopa0/koopa/internal/sqlc"
)
@@ -52,32 +53,19 @@ func New(queries *sqlc.Queries, pool *pgxpool.Pool, logger *slog.Logger) *Store
// Parameters:
// - ctx: Context for the operation
// - title: Session title (empty string = no title)
-// - modelName: Model name used for this session (empty string = default)
-// - systemPrompt: System prompt for this session (empty string = default)
//
// Returns:
// - *Session: Created session with generated UUID
// - error: If creation fails
-func (s *Store) CreateSession(ctx context.Context, title, modelName, systemPrompt string) (*Session, error) {
- // Convert empty strings to nil for nullable fields
- var titlePtr, modelNamePtr, systemPromptPtr *string
+func (s *Store) CreateSession(ctx context.Context, title string) (*Session, error) {
+ var titlePtr *string
if title != "" {
titlePtr = &title
}
- if modelName != "" {
- modelNamePtr = &modelName
- }
- if systemPrompt != "" {
- systemPromptPtr = &systemPrompt
- }
- sqlcSession, err := s.queries.CreateSession(ctx, sqlc.CreateSessionParams{
- Title: titlePtr,
- ModelName: modelNamePtr,
- SystemPrompt: systemPromptPtr,
- })
+ sqlcSession, err := s.queries.CreateSession(ctx, titlePtr)
if err != nil {
- return nil, fmt.Errorf("failed to create session: %w", err)
+ return nil, fmt.Errorf("creating session: %w", err)
}
session := s.sqlcSessionToSession(sqlcSession)
@@ -86,21 +74,21 @@ func (s *Store) CreateSession(ctx context.Context, title, modelName, systemPromp
}
// Session retrieves a session by ID.
-// Returns ErrSessionNotFound if the session does not exist.
+// Returns ErrNotFound if the session does not exist.
func (s *Store) Session(ctx context.Context, sessionID uuid.UUID) (*Session, error) {
sqlcSession, err := s.queries.Session(ctx, sessionID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
// Return sentinel error directly (no wrapping per reviewer guidance)
- return nil, ErrSessionNotFound
+ return nil, ErrNotFound
}
- return nil, fmt.Errorf("failed to get session %s: %w", sessionID, err)
+ return nil, fmt.Errorf("getting session %s: %w", sessionID, err)
}
return s.sqlcSessionToSession(sqlcSession), nil
}
-// ListSessions lists sessions with pagination, ordered by updated_at descending.
+// Sessions lists sessions with pagination, ordered by updated_at descending.
//
// Parameters:
// - ctx: Context for the operation
@@ -110,13 +98,13 @@ func (s *Store) Session(ctx context.Context, sessionID uuid.UUID) (*Session, err
// Returns:
// - []*Session: List of sessions
// - error: If listing fails
-func (s *Store) ListSessions(ctx context.Context, limit, offset int32) ([]*Session, error) {
- sqlcSessions, err := s.queries.ListSessions(ctx, sqlc.ListSessionsParams{
+func (s *Store) Sessions(ctx context.Context, limit, offset int32) ([]*Session, error) {
+ sqlcSessions, err := s.queries.Sessions(ctx, sqlc.SessionsParams{
ResultLimit: limit,
ResultOffset: offset,
})
if err != nil {
- return nil, fmt.Errorf("failed to list sessions: %w", err)
+ return nil, fmt.Errorf("listing sessions: %w", err)
}
sessions := make([]*Session, 0, len(sqlcSessions))
@@ -128,35 +116,6 @@ func (s *Store) ListSessions(ctx context.Context, limit, offset int32) ([]*Sessi
return sessions, nil
}
-// ListSessionsWithMessages lists sessions that have messages or titles.
-// This is used for sidebar display where empty "New Chat" placeholder sessions should be hidden.
-//
-// Parameters:
-// - ctx: Context for the operation
-// - limit: Maximum number of sessions to return
-// - offset: Number of sessions to skip (for pagination)
-//
-// Returns:
-// - []*Session: List of sessions with messages or titles
-// - error: If listing fails
-func (s *Store) ListSessionsWithMessages(ctx context.Context, limit, offset int32) ([]*Session, error) {
- sqlcSessions, err := s.queries.ListSessionsWithMessages(ctx, sqlc.ListSessionsWithMessagesParams{
- ResultLimit: limit,
- ResultOffset: offset,
- })
- if err != nil {
- return nil, fmt.Errorf("failed to list sessions with messages: %w", err)
- }
-
- sessions := make([]*Session, 0, len(sqlcSessions))
- for i := range sqlcSessions {
- sessions = append(sessions, s.sqlcSessionToSession(sqlcSessions[i]))
- }
-
- s.logger.Debug("listed sessions with messages", "count", len(sessions), "limit", limit, "offset", offset)
- return sessions, nil
-}
-
// DeleteSession deletes a session and all its messages (CASCADE).
//
// Parameters:
@@ -167,7 +126,7 @@ func (s *Store) ListSessionsWithMessages(ctx context.Context, limit, offset int3
// - error: If deletion fails
func (s *Store) DeleteSession(ctx context.Context, sessionID uuid.UUID) error {
if err := s.queries.DeleteSession(ctx, sessionID); err != nil {
- return fmt.Errorf("failed to delete session %s: %w", sessionID, err)
+ return fmt.Errorf("deleting session %s: %w", sessionID, err)
}
s.logger.Debug("deleted session", "id", sessionID)
@@ -194,7 +153,7 @@ func (s *Store) UpdateSessionTitle(ctx context.Context, sessionID uuid.UUID, tit
SessionID: sessionID,
Title: titlePtr,
}); err != nil {
- return fmt.Errorf("failed to update session title %s: %w", sessionID, err)
+ return fmt.Errorf("updating session title %s: %w", sessionID, err)
}
s.logger.Debug("updated session title", "id", sessionID, "title", title)
@@ -224,13 +183,13 @@ func (s *Store) AddMessages(ctx context.Context, sessionID uuid.UUID, messages [
// Database pool is required for transactional operations
// Tests should use pgxmock or Testcontainers for proper transaction testing
if s.pool == nil {
- return fmt.Errorf("database pool required for AddMessages: use pgxmock or real database for testing")
+ return fmt.Errorf("database pool is required for transactional operations")
}
// Begin transaction for atomicity
tx, err := s.pool.Begin(ctx)
if err != nil {
- return fmt.Errorf("failed to begin transaction: %w", err)
+ return fmt.Errorf("beginning transaction: %w", err)
}
// Rollback if not committed - log any rollback errors for debugging
defer func() {
@@ -246,16 +205,13 @@ func (s *Store) AddMessages(ctx context.Context, sessionID uuid.UUID, messages [
// This SELECT ... FOR UPDATE ensures that only one transaction can modify
// this session at a time, preventing race conditions on sequence numbers
if _, err = txQuerier.LockSession(ctx, sessionID); err != nil {
- return fmt.Errorf("failed to lock session: %w", err)
+ return fmt.Errorf("locking session: %w", err)
}
// 1. Get current max sequence number within transaction
- maxSeq, err := txQuerier.GetMaxSequenceNumber(ctx, sessionID)
+ maxSeq, err := txQuerier.MaxSequenceNumber(ctx, sessionID)
if err != nil {
- // If session doesn't exist yet or no messages, start from 0
- s.logger.Debug("no existing messages, starting from sequence 0",
- "session_id", sessionID)
- maxSeq = 0
+ return fmt.Errorf("getting max sequence number: %w", err)
}
// 2. Insert messages in batch within transaction
@@ -271,7 +227,7 @@ func (s *Store) AddMessages(ctx context.Context, sessionID uuid.UUID, messages [
contentJSON, marshalErr := json.Marshal(msg.Content)
if marshalErr != nil {
// Transaction will be rolled back by defer
- return fmt.Errorf("failed to marshal message content at index %d: %w", i, marshalErr)
+ return fmt.Errorf("marshaling message content at index %d: %w", i, marshalErr)
}
// Calculate sequence number (maxSeq is now int32 from sqlc)
@@ -285,24 +241,19 @@ func (s *Store) AddMessages(ctx context.Context, sessionID uuid.UUID, messages [
SequenceNumber: seqNum,
}); err != nil {
// Transaction will be rolled back by defer
- return fmt.Errorf("failed to insert message %d: %w", i, err)
+ return fmt.Errorf("inserting message %d: %w", i, err)
}
}
- // 3. Update session's updated_at and message_count within transaction
- // Safe conversion: len(messages) is bounded by practical limits (< millions)
- newCount := maxSeq + int32(len(messages)) // #nosec G115 -- len bounded by practical message limits
- if err = txQuerier.UpdateSessionUpdatedAt(ctx, sqlc.UpdateSessionUpdatedAtParams{
- MessageCount: &newCount,
- SessionID: sessionID,
- }); err != nil {
+ // 3. Update session's updated_at within transaction
+ if err = txQuerier.UpdateSessionUpdatedAt(ctx, sessionID); err != nil {
// Transaction will be rolled back by defer
- return fmt.Errorf("failed to update session metadata: %w", err)
+ return fmt.Errorf("updating session metadata: %w", err)
}
// 4. Commit transaction
if err := tx.Commit(ctx); err != nil {
- return fmt.Errorf("failed to commit transaction: %w", err)
+ return fmt.Errorf("committing transaction: %w", err)
}
s.logger.Debug("added messages", "session_id", sessionID, "count", len(messages))
@@ -327,14 +278,14 @@ func (s *Store) Messages(ctx context.Context, sessionID uuid.UUID, limit, offset
ResultOffset: offset,
})
if err != nil {
- return nil, fmt.Errorf("failed to get messages for session %s: %w", sessionID, err)
+ return nil, fmt.Errorf("getting messages for session %s: %w", sessionID, err)
}
messages := make([]*Message, 0, len(sqlcMessages))
for i := range sqlcMessages {
msg, err := s.sqlcMessageToMessage(sqlcMessages[i])
if err != nil {
- s.logger.Warn("failed to unmarshal message content",
+ s.logger.Warn("skipping malformed message",
"message_id", sqlcMessages[i].ID,
"error", err)
continue // Skip malformed messages
@@ -382,7 +333,7 @@ func (s *Store) AppendMessages(ctx context.Context, sessionID uuid.UUID, message
// Use AddMessages
if err := s.AddMessages(ctx, sessionID, sessionMessages); err != nil {
- return fmt.Errorf("failed to append messages: %w", err)
+ return err // AddMessages already wraps with context
}
s.logger.Debug("appended messages",
@@ -391,32 +342,21 @@ func (s *Store) AppendMessages(ctx context.Context, sessionID uuid.UUID, message
return nil
}
-// History retrieves the conversation history for a session.
-// Used by chat.Chat agent for session management.
-//
-// Parameters:
-// - ctx: Context for the operation
-// - sessionID: Session UUID
-//
-// Returns:
-// - *History: Conversation history
-// - error: If retrieval fails
-func (s *Store) History(ctx context.Context, sessionID uuid.UUID) (*History, error) {
+// History retrieves the conversation history for a session as a slice of ai.Message.
+// Used by chat.Agent for session management.
+func (s *Store) History(ctx context.Context, sessionID uuid.UUID) ([]*ai.Message, error) {
// Verify session exists before loading history
if _, err := s.Session(ctx, sessionID); err != nil {
- if errors.Is(err, ErrSessionNotFound) {
+ if errors.Is(err, ErrNotFound) {
return nil, err // Sentinel propagates unchanged
}
- return nil, fmt.Errorf("get history for session %s: %w", sessionID, err)
+ return nil, fmt.Errorf("getting history for session %s: %w", sessionID, err)
}
- // Use default limit for history retrieval
- limit := DefaultHistoryLimit
-
// Retrieve messages
- messages, err := s.Messages(ctx, sessionID, limit, 0)
+ messages, err := s.Messages(ctx, sessionID, config.DefaultMaxHistoryMessages, 0)
if err != nil {
- return nil, fmt.Errorf("failed to load history: %w", err)
+ return nil, fmt.Errorf("loading history: %w", err)
}
// Convert to ai.Message
@@ -430,11 +370,42 @@ func (s *Store) History(ctx context.Context, sessionID uuid.UUID) (*History, err
s.logger.Debug("loaded history",
"session_id", sessionID,
- "message_count", len(messages))
+ "count", len(messages))
- history := NewHistory()
- history.SetMessages(aiMessages)
- return history, nil
+ return aiMessages, nil
+}
+
+// ResolveCurrentSession loads the active session from the state file,
+// validates it exists in the database, and creates a new session if needed.
+// Returns the session ID.
+func (s *Store) ResolveCurrentSession(ctx context.Context) (uuid.UUID, error) {
+ //nolint:contextcheck // LoadCurrentSessionID manages its own lock timeout context
+ savedID, err := LoadCurrentSessionID()
+ if err != nil {
+ return uuid.Nil, fmt.Errorf("loading current session: %w", err)
+ }
+
+ if savedID != nil {
+ if _, err = s.Session(ctx, *savedID); err == nil {
+ return *savedID, nil
+ }
+ if !errors.Is(err, ErrNotFound) {
+ return uuid.Nil, fmt.Errorf("validating session: %w", err)
+ }
+ }
+
+ newSess, err := s.CreateSession(ctx, "")
+ if err != nil {
+ return uuid.Nil, fmt.Errorf("creating session: %w", err)
+ }
+
+ // best-effort: state file is non-critical, session already created in DB
+ //nolint:contextcheck // SaveCurrentSessionID manages its own lock timeout context
+ if saveErr := SaveCurrentSessionID(newSess.ID); saveErr != nil {
+ s.logger.Warn("saving session state", "error", saveErr)
+ }
+
+ return newSess.ID, nil
}
// sqlcSessionToSession converts sqlc.Session to Session (application type).
@@ -448,15 +419,6 @@ func (*Store) sqlcSessionToSession(ss sqlc.Session) *Session {
if ss.Title != nil {
session.Title = *ss.Title
}
- if ss.ModelName != nil {
- session.ModelName = *ss.ModelName
- }
- if ss.SystemPrompt != nil {
- session.SystemPrompt = *ss.SystemPrompt
- }
- if ss.MessageCount != nil {
- session.MessageCount = int(*ss.MessageCount)
- }
return session
}
@@ -466,7 +428,7 @@ func (*Store) sqlcMessageToMessage(sm sqlc.Message) (*Message, error) {
// Unmarshal JSONB content to ai.Part slice
var content []*ai.Part
if err := json.Unmarshal(sm.Content, &content); err != nil {
- return nil, fmt.Errorf("failed to unmarshal content: %w", err)
+ return nil, fmt.Errorf("unmarshaling content: %w", err)
}
return &Message{
@@ -474,169 +436,7 @@ func (*Store) sqlcMessageToMessage(sm sqlc.Message) (*Message, error) {
SessionID: sm.SessionID,
Role: sm.Role,
Content: content,
- Status: sm.Status,
SequenceNumber: int(sm.SequenceNumber),
CreatedAt: sm.CreatedAt.Time,
}, nil
}
-
-// =============================================================================
-// Streaming Message Operations (for SSE chat flow)
-// =============================================================================
-
-// MessagePair represents a user-assistant message pair created for streaming.
-// Used to track both messages atomically for SSE-based chat.
-type MessagePair struct {
- UserMsgID uuid.UUID
- AssistantMsgID uuid.UUID
- UserSeq int32
- AssistantSeq int32
-}
-
-// CreateMessagePair atomically creates a user message and empty assistant placeholder.
-// The user message is marked as "completed", the assistant message as "streaming".
-// This is used at the start of a chat turn before SSE streaming begins.
-//
-// Parameters:
-// - ctx: Context for the operation
-// - sessionID: UUID of the session
-// - userContent: User message content as ai.Part slice
-// - assistantID: Pre-generated UUID for the assistant message (used in SSE URL)
-//
-// Returns:
-// - *MessagePair: Contains both message IDs and sequence numbers
-// - error: If creation fails or session doesn't exist
-func (s *Store) CreateMessagePair(
- ctx context.Context,
- sessionID uuid.UUID,
- userContent []*ai.Part,
- assistantID uuid.UUID,
-) (*MessagePair, error) {
- // Database pool is required for transactional operations
- if s.pool == nil {
- return nil, fmt.Errorf("database pool required for CreateMessagePair")
- }
-
- // Begin transaction for atomicity
- tx, err := s.pool.Begin(ctx)
- if err != nil {
- return nil, fmt.Errorf("begin transaction: %w", err)
- }
- defer func() {
- if rollbackErr := tx.Rollback(ctx); rollbackErr != nil {
- s.logger.Debug("transaction rollback (may be already committed)", "error", rollbackErr)
- }
- }()
-
- txQuerier := sqlc.New(tx)
-
- // Lock session row to prevent concurrent modifications
- _, err = txQuerier.LockSession(ctx, sessionID)
- if err != nil {
- return nil, fmt.Errorf("lock session: %w", err)
- }
-
- // Get current max sequence number
- maxSeq, err := txQuerier.GetMaxSequenceNumber(ctx, sessionID)
- if err != nil {
- s.logger.Debug("no existing messages, starting from sequence 0",
- "session_id", sessionID)
- maxSeq = 0
- }
-
- // Marshal user content to JSON
- userContentJSON, err := json.Marshal(userContent)
- if err != nil {
- return nil, fmt.Errorf("marshal user content: %w", err)
- }
-
- // Generate user message ID
- userMsgID := uuid.New()
- userSeq := maxSeq + 1
- assistantSeq := maxSeq + 2
-
- // Insert user message (status = completed)
- _, err = txQuerier.AddMessageWithID(ctx, sqlc.AddMessageWithIDParams{
- ID: userMsgID,
- SessionID: sessionID,
- Role: "user",
- Content: userContentJSON,
- Status: StatusCompleted,
- SequenceNumber: userSeq,
- })
- if err != nil {
- return nil, fmt.Errorf("insert user message: %w", err)
- }
-
- // Insert empty assistant message placeholder (status = streaming)
- emptyContent := []byte("[]") // Empty ai.Part slice
- _, err = txQuerier.AddMessageWithID(ctx, sqlc.AddMessageWithIDParams{
- ID: assistantID,
- SessionID: sessionID,
- Role: RoleAssistant,
- Content: emptyContent,
- Status: StatusStreaming,
- SequenceNumber: assistantSeq,
- })
- if err != nil {
- return nil, fmt.Errorf("insert assistant message: %w", err)
- }
-
- // Update session metadata
- if err = txQuerier.UpdateSessionUpdatedAt(ctx, sqlc.UpdateSessionUpdatedAtParams{
- MessageCount: &assistantSeq,
- SessionID: sessionID,
- }); err != nil {
- return nil, fmt.Errorf("update session metadata: %w", err)
- }
-
- // Commit transaction
- if err := tx.Commit(ctx); err != nil {
- return nil, fmt.Errorf("commit transaction: %w", err)
- }
-
- s.logger.Debug("created message pair",
- "session_id", sessionID,
- "user_msg_id", userMsgID,
- "assistant_msg_id", assistantID)
-
- return &MessagePair{
- UserMsgID: userMsgID,
- AssistantMsgID: assistantID,
- UserSeq: userSeq,
- AssistantSeq: assistantSeq,
- }, nil
-}
-
-// UpdateMessageContent updates the content of a message and marks it as completed.
-// Used after streaming is finished to save the final AI response.
-func (s *Store) UpdateMessageContent(ctx context.Context, msgID uuid.UUID, content []*ai.Part) error {
- contentJSON, err := json.Marshal(content)
- if err != nil {
- return fmt.Errorf("marshal content: %w", err)
- }
-
- if err := s.queries.UpdateMessageContent(ctx, sqlc.UpdateMessageContentParams{
- ID: msgID,
- Content: contentJSON,
- }); err != nil {
- return fmt.Errorf("update message content: %w", err)
- }
-
- s.logger.Debug("updated message content", "msg_id", msgID)
- return nil
-}
-
-// UpdateMessageStatus updates the status of a message.
-// Used to mark streaming messages as failed if an error occurs.
-func (s *Store) UpdateMessageStatus(ctx context.Context, msgID uuid.UUID, status string) error {
- if err := s.queries.UpdateMessageStatus(ctx, sqlc.UpdateMessageStatusParams{
- ID: msgID,
- Status: status,
- }); err != nil {
- return fmt.Errorf("update message status: %w", err)
- }
-
- s.logger.Debug("updated message status", "msg_id", msgID, "status", status)
- return nil
-}
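
For context: with ResolveCurrentSession and the slice-returning History in place, caller-side wiring becomes a two-step lookup. A minimal sketch of a hypothetical command handler (startChat is an assumed name; both store methods come from the hunks above):

package chatexample // hypothetical consumer of internal/session

import (
	"context"

	"github.com/firebase/genkit/go/ai"

	"github.com/koopa0/koopa/internal/session"
)

// startChat resolves the active session (creating one if needed) and loads its
// history. History now returns []*ai.Message directly, which is why the tests
// above switched from len(history.Messages()) to len(history).
func startChat(ctx context.Context, store *session.Store) ([]*ai.Message, error) {
	sessionID, err := store.ResolveCurrentSession(ctx)
	if err != nil {
		return nil, err
	}
	return store.History(ctx, sessionID)
}
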
diff --git a/internal/session/store_test.go b/internal/session/store_test.go
new file mode 100644
index 0000000..a1cafda
--- /dev/null
+++ b/internal/session/store_test.go
@@ -0,0 +1,34 @@
+package session
+
+import "testing"
+
+// TestNormalizeRole tests the Genkit role normalization function.
+// Genkit uses "model" for AI responses, but we store "assistant" in the database
+// for consistency with the CHECK constraint.
+func TestNormalizeRole(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {name: "model to assistant", input: "model", want: "assistant"},
+ {name: "user unchanged", input: "user", want: "user"},
+ {name: "assistant unchanged", input: "assistant", want: "assistant"},
+ {name: "system unchanged", input: "system", want: "system"},
+ {name: "tool unchanged", input: "tool", want: "tool"},
+ {name: "empty passthrough", input: "", want: ""},
+ {name: "unknown passthrough", input: "unknown", want: "unknown"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := normalizeRole(tt.input)
+ if got != tt.want {
+ t.Errorf("normalizeRole(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
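
For context: the table above pins down the contract of normalizeRole: only Genkit's "model" role is rewritten to "assistant", and every other value, including empty and unknown strings, passes through so the database CHECK constraint stays the single validity gate. The implementation is presumably as small as:

// Presumed shape of normalizeRole, inferred from the test table above.
func normalizeRole(role string) string {
	if role == "model" {
		return "assistant"
	}
	return role
}
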
diff --git a/internal/session/types.go b/internal/session/types.go
deleted file mode 100644
index cbd9c6f..0000000
--- a/internal/session/types.go
+++ /dev/null
@@ -1,112 +0,0 @@
-// Package session provides session persistence functionality for conversation history.
-//
-// Responsibilities: Save/load conversation sessions to PostgreSQL database.
-// Thread Safety: Not thread-safe - caller must synchronize access.
-package session
-
-import (
- "sync"
- "time"
-
- "github.com/firebase/genkit/go/ai"
- "github.com/google/uuid"
-)
-
-// History encapsulates conversation history with thread-safe access.
-//
-// Note: The zero value is NOT useful - use NewHistory() to create instances.
-type History struct {
- mu sync.RWMutex
- messages []*ai.Message
-}
-
-// NewHistory creates a new History instance.
-func NewHistory() *History {
- return &History{
- messages: make([]*ai.Message, 0),
- }
-}
-
-// SetMessages replaces all messages in the history.
-// Used by SessionStore when loading history from the database.
-// Makes a defensive copy to prevent external modification.
-func (h *History) SetMessages(messages []*ai.Message) {
- h.mu.Lock()
- defer h.mu.Unlock()
- h.messages = make([]*ai.Message, len(messages))
- copy(h.messages, messages)
-}
-
-// Messages returns a copy of all messages for thread-safe access
-func (h *History) Messages() []*ai.Message {
- h.mu.RLock()
- defer h.mu.RUnlock()
- result := make([]*ai.Message, len(h.messages))
- copy(result, h.messages)
- return result
-}
-
-// Add appends user message and assistant response
-func (h *History) Add(userInput, assistantResponse string) {
- h.mu.Lock()
- defer h.mu.Unlock()
- h.messages = append(h.messages,
- ai.NewUserMessage(ai.NewTextPart(userInput)),
- ai.NewModelMessage(ai.NewTextPart(assistantResponse)),
- )
-}
-
-// AddMessage appends a single message
-// Returns without effect if msg is nil
-func (h *History) AddMessage(msg *ai.Message) {
- if msg == nil {
- return
- }
- h.mu.Lock()
- defer h.mu.Unlock()
- h.messages = append(h.messages, msg)
-}
-
-// Count returns the number of messages
-func (h *History) Count() int {
- h.mu.RLock()
- defer h.mu.RUnlock()
- return len(h.messages)
-}
-
-// Clear removes all messages
-func (h *History) Clear() {
- h.mu.Lock()
- defer h.mu.Unlock()
- h.messages = make([]*ai.Message, 0)
-}
-
-// Role constants define valid message roles for type safety.
-const (
- RoleUser = "user"
- RoleAssistant = "assistant"
- RoleTool = "tool"
-)
-
-// Session represents a conversation session (application-level type).
-type Session struct {
- ID uuid.UUID
- Title string
- CreatedAt time.Time
- UpdatedAt time.Time
- ModelName string
- SystemPrompt string
- MessageCount int
-}
-
-// Message represents a single conversation message (application-level type).
-// Content field stores Genkit's ai.Part slice, serialized as JSONB in database.
-type Message struct {
- ID uuid.UUID
- SessionID uuid.UUID
- Role string // "user" | "assistant" | "tool"
- Content []*ai.Part // Genkit Part slice (stored as JSONB)
- Status string // Message status: streaming/completed/failed
- SequenceNumber int
- CreatedAt time.Time
-}
diff --git a/internal/sqlc/documents.sql.go b/internal/sqlc/documents.sql.go
deleted file mode 100644
index 48ac5f5..0000000
--- a/internal/sqlc/documents.sql.go
+++ /dev/null
@@ -1,280 +0,0 @@
-// Code generated by sqlc. DO NOT EDIT.
-// versions:
-// sqlc v1.30.0
-// source: documents.sql
-
-package sqlc
-
-import (
- "context"
-
- "github.com/pgvector/pgvector-go"
-)
-
-const countDocuments = `-- name: CountDocuments :one
-SELECT COUNT(*)
-FROM documents
-WHERE metadata @> $1::jsonb
-`
-
-func (q *Queries) CountDocuments(ctx context.Context, dollar_1 []byte) (int64, error) {
- row := q.db.QueryRow(ctx, countDocuments, dollar_1)
- var count int64
- err := row.Scan(&count)
- return count, err
-}
-
-const countDocumentsAll = `-- name: CountDocumentsAll :one
-SELECT COUNT(*)
-FROM documents
-`
-
-func (q *Queries) CountDocumentsAll(ctx context.Context) (int64, error) {
- row := q.db.QueryRow(ctx, countDocumentsAll)
- var count int64
- err := row.Scan(&count)
- return count, err
-}
-
-const deleteDocument = `-- name: DeleteDocument :exec
-DELETE FROM documents
-WHERE id = $1
-`
-
-func (q *Queries) DeleteDocument(ctx context.Context, id string) error {
- _, err := q.db.Exec(ctx, deleteDocument, id)
- return err
-}
-
-const getDocument = `-- name: GetDocument :one
-SELECT id, content, metadata
-FROM documents
-WHERE id = $1
-`
-
-type GetDocumentRow struct {
- ID string `json:"id"`
- Content string `json:"content"`
- Metadata []byte `json:"metadata"`
-}
-
-func (q *Queries) GetDocument(ctx context.Context, id string) (GetDocumentRow, error) {
- row := q.db.QueryRow(ctx, getDocument, id)
- var i GetDocumentRow
- err := row.Scan(&i.ID, &i.Content, &i.Metadata)
- return i, err
-}
-
-const listDocumentsBySourceType = `-- name: ListDocumentsBySourceType :many
-SELECT id, content, metadata
-FROM documents
-WHERE source_type = $1::text
-LIMIT $2
-`
-
-type ListDocumentsBySourceTypeParams struct {
- SourceType string `json:"source_type"`
- ResultLimit int32 `json:"result_limit"`
-}
-
-type ListDocumentsBySourceTypeRow struct {
- ID string `json:"id"`
- Content string `json:"content"`
- Metadata []byte `json:"metadata"`
-}
-
-// List all documents by source_type using dedicated indexed column
-// Used for listing indexed files without needing embeddings
-func (q *Queries) ListDocumentsBySourceType(ctx context.Context, arg ListDocumentsBySourceTypeParams) ([]ListDocumentsBySourceTypeRow, error) {
- rows, err := q.db.Query(ctx, listDocumentsBySourceType, arg.SourceType, arg.ResultLimit)
- if err != nil {
- return nil, err
- }
- defer rows.Close()
- items := []ListDocumentsBySourceTypeRow{}
- for rows.Next() {
- var i ListDocumentsBySourceTypeRow
- if err := rows.Scan(&i.ID, &i.Content, &i.Metadata); err != nil {
- return nil, err
- }
- items = append(items, i)
- }
- if err := rows.Err(); err != nil {
- return nil, err
- }
- return items, nil
-}
-
-const searchBySourceType = `-- name: SearchBySourceType :many
-
-SELECT id, content, metadata,
- (1 - (embedding <=> $1::vector))::float8 AS similarity
-FROM documents
-WHERE source_type = $2::text
-ORDER BY similarity DESC
-LIMIT $3
-`
-
-type SearchBySourceTypeParams struct {
- QueryEmbedding *pgvector.Vector `json:"query_embedding"`
- SourceType string `json:"source_type"`
- ResultLimit int32 `json:"result_limit"`
-}
-
-type SearchBySourceTypeRow struct {
- ID string `json:"id"`
- Content string `json:"content"`
- Metadata []byte `json:"metadata"`
- Similarity float64 `json:"similarity"`
-}
-
-// ===== Optimized RAG Queries (SQL-level filtering) =====
-// Generic search by source_type using dedicated indexed column
-func (q *Queries) SearchBySourceType(ctx context.Context, arg SearchBySourceTypeParams) ([]SearchBySourceTypeRow, error) {
- rows, err := q.db.Query(ctx, searchBySourceType, arg.QueryEmbedding, arg.SourceType, arg.ResultLimit)
- if err != nil {
- return nil, err
- }
- defer rows.Close()
- items := []SearchBySourceTypeRow{}
- for rows.Next() {
- var i SearchBySourceTypeRow
- if err := rows.Scan(
- &i.ID,
- &i.Content,
- &i.Metadata,
- &i.Similarity,
- ); err != nil {
- return nil, err
- }
- items = append(items, i)
- }
- if err := rows.Err(); err != nil {
- return nil, err
- }
- return items, nil
-}
-
-const searchDocuments = `-- name: SearchDocuments :many
-SELECT id, content, metadata,
- (1 - (embedding <=> $1::vector))::float8 AS similarity
-FROM documents
-WHERE metadata @> $2::jsonb
-ORDER BY similarity DESC
-LIMIT $3
-`
-
-type SearchDocumentsParams struct {
- QueryEmbedding *pgvector.Vector `json:"query_embedding"`
- FilterMetadata []byte `json:"filter_metadata"`
- ResultLimit int32 `json:"result_limit"`
-}
-
-type SearchDocumentsRow struct {
- ID string `json:"id"`
- Content string `json:"content"`
- Metadata []byte `json:"metadata"`
- Similarity float64 `json:"similarity"`
-}
-
-func (q *Queries) SearchDocuments(ctx context.Context, arg SearchDocumentsParams) ([]SearchDocumentsRow, error) {
- rows, err := q.db.Query(ctx, searchDocuments, arg.QueryEmbedding, arg.FilterMetadata, arg.ResultLimit)
- if err != nil {
- return nil, err
- }
- defer rows.Close()
- items := []SearchDocumentsRow{}
- for rows.Next() {
- var i SearchDocumentsRow
- if err := rows.Scan(
- &i.ID,
- &i.Content,
- &i.Metadata,
- &i.Similarity,
- ); err != nil {
- return nil, err
- }
- items = append(items, i)
- }
- if err := rows.Err(); err != nil {
- return nil, err
- }
- return items, nil
-}
-
-const searchDocumentsAll = `-- name: SearchDocumentsAll :many
-SELECT id, content, metadata,
- (1 - (embedding <=> $1::vector))::float8 AS similarity
-FROM documents
-ORDER BY similarity DESC
-LIMIT $2
-`
-
-type SearchDocumentsAllParams struct {
- QueryEmbedding *pgvector.Vector `json:"query_embedding"`
- ResultLimit int32 `json:"result_limit"`
-}
-
-type SearchDocumentsAllRow struct {
- ID string `json:"id"`
- Content string `json:"content"`
- Metadata []byte `json:"metadata"`
- Similarity float64 `json:"similarity"`
-}
-
-func (q *Queries) SearchDocumentsAll(ctx context.Context, arg SearchDocumentsAllParams) ([]SearchDocumentsAllRow, error) {
- rows, err := q.db.Query(ctx, searchDocumentsAll, arg.QueryEmbedding, arg.ResultLimit)
- if err != nil {
- return nil, err
- }
- defer rows.Close()
- items := []SearchDocumentsAllRow{}
- for rows.Next() {
- var i SearchDocumentsAllRow
- if err := rows.Scan(
- &i.ID,
- &i.Content,
- &i.Metadata,
- &i.Similarity,
- ); err != nil {
- return nil, err
- }
- items = append(items, i)
- }
- if err := rows.Err(); err != nil {
- return nil, err
- }
- return items, nil
-}
-
-const upsertDocument = `-- name: UpsertDocument :exec
-
-INSERT INTO documents (id, content, embedding, source_type, metadata)
-VALUES ($1, $2, $3, $4, $5)
-ON CONFLICT (id) DO UPDATE SET
- content = EXCLUDED.content,
- embedding = EXCLUDED.embedding,
- source_type = EXCLUDED.source_type,
- metadata = EXCLUDED.metadata
-`
-
-type UpsertDocumentParams struct {
- ID string `json:"id"`
- Content string `json:"content"`
- Embedding *pgvector.Vector `json:"embedding"`
- SourceType *string `json:"source_type"`
- Metadata []byte `json:"metadata"`
-}
-
-// Documents queries for sqlc
-// Generated code will be in internal/sqlc/documents.sql.go
-func (q *Queries) UpsertDocument(ctx context.Context, arg UpsertDocumentParams) error {
- _, err := q.db.Exec(ctx, upsertDocument,
- arg.ID,
- arg.Content,
- arg.Embedding,
- arg.SourceType,
- arg.Metadata,
- )
- return err
-}
diff --git a/internal/sqlc/models.go b/internal/sqlc/models.go
index 90afee5..4ec4daf 100644
--- a/internal/sqlc/models.go
+++ b/internal/sqlc/models.go
@@ -25,16 +25,11 @@ type Message struct {
Content []byte `json:"content"`
SequenceNumber int32 `json:"sequence_number"`
CreatedAt pgtype.Timestamptz `json:"created_at"`
- Status string `json:"status"`
- UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
type Session struct {
- ID uuid.UUID `json:"id"`
- Title *string `json:"title"`
- CreatedAt pgtype.Timestamptz `json:"created_at"`
- UpdatedAt pgtype.Timestamptz `json:"updated_at"`
- ModelName *string `json:"model_name"`
- SystemPrompt *string `json:"system_prompt"`
- MessageCount *int32 `json:"message_count"`
+ ID uuid.UUID `json:"id"`
+ Title *string `json:"title"`
+ CreatedAt pgtype.Timestamptz `json:"created_at"`
+ UpdatedAt pgtype.Timestamptz `json:"updated_at"`
}
diff --git a/internal/sqlc/sessions.sql.go b/internal/sqlc/sessions.sql.go
index ade5590..df5e238 100644
--- a/internal/sqlc/sessions.sql.go
+++ b/internal/sqlc/sessions.sql.go
@@ -12,7 +12,7 @@ import (
)
const addMessage = `-- name: AddMessage :exec
-INSERT INTO message (session_id, role, content, sequence_number)
+INSERT INTO messages (session_id, role, content, sequence_number)
VALUES ($1, $2, $3, $4)
`
@@ -34,100 +34,27 @@ func (q *Queries) AddMessage(ctx context.Context, arg AddMessageParams) error {
return err
}
-const addMessageWithID = `-- name: AddMessageWithID :one
-INSERT INTO message (id, session_id, role, content, status, sequence_number)
-VALUES ($1, $2, $3, $4, $5, $6)
-RETURNING id, session_id, role, content, sequence_number, created_at, status, updated_at
-`
-
-type AddMessageWithIDParams struct {
- ID uuid.UUID `json:"id"`
- SessionID uuid.UUID `json:"session_id"`
- Role string `json:"role"`
- Content []byte `json:"content"`
- Status string `json:"status"`
- SequenceNumber int32 `json:"sequence_number"`
-}
-
-// Add message with pre-assigned ID and status (for streaming)
-func (q *Queries) AddMessageWithID(ctx context.Context, arg AddMessageWithIDParams) (Message, error) {
- row := q.db.QueryRow(ctx, addMessageWithID,
- arg.ID,
- arg.SessionID,
- arg.Role,
- arg.Content,
- arg.Status,
- arg.SequenceNumber,
- )
- var i Message
- err := row.Scan(
- &i.ID,
- &i.SessionID,
- &i.Role,
- &i.Content,
- &i.SequenceNumber,
- &i.CreatedAt,
- &i.Status,
- &i.UpdatedAt,
- )
- return i, err
-}
-
-const countMessages = `-- name: CountMessages :one
-SELECT COUNT(*)::integer AS count
-FROM message
-WHERE session_id = $1
-`
-
-// Count messages in a session
-func (q *Queries) CountMessages(ctx context.Context, sessionID uuid.UUID) (int32, error) {
- row := q.db.QueryRow(ctx, countMessages, sessionID)
- var count int32
- err := row.Scan(&count)
- return count, err
-}
-
const createSession = `-- name: CreateSession :one
-INSERT INTO sessions (title, model_name, system_prompt)
-VALUES ($1, $2, $3)
-RETURNING id, title, created_at, updated_at, model_name, system_prompt, message_count
+INSERT INTO sessions (title)
+VALUES ($1)
+RETURNING id, title, created_at, updated_at
`
-type CreateSessionParams struct {
- Title *string `json:"title"`
- ModelName *string `json:"model_name"`
- SystemPrompt *string `json:"system_prompt"`
-}
-
-// Sessions queries for sqlc
+// Sessions and messages queries for sqlc
// Generated code will be in internal/sqlc/sessions.sql.go
-func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
- row := q.db.QueryRow(ctx, createSession, arg.Title, arg.ModelName, arg.SystemPrompt)
+func (q *Queries) CreateSession(ctx context.Context, title *string) (Session, error) {
+ row := q.db.QueryRow(ctx, createSession, title)
var i Session
err := row.Scan(
&i.ID,
&i.Title,
&i.CreatedAt,
&i.UpdatedAt,
- &i.ModelName,
- &i.SystemPrompt,
- &i.MessageCount,
)
return i, err
}
-const deleteMessages = `-- name: DeleteMessages :exec
-DELETE FROM message
-WHERE session_id = $1
-`
-
-// Delete all messages in a session
-func (q *Queries) DeleteMessages(ctx context.Context, sessionID uuid.UUID) error {
- _, err := q.db.Exec(ctx, deleteMessages, sessionID)
- return err
-}
-
const deleteSession = `-- name: DeleteSession :exec
DELETE FROM sessions
WHERE id = $1
@@ -138,105 +65,20 @@ func (q *Queries) DeleteSession(ctx context.Context, id uuid.UUID) error {
return err
}
-const getMaxSequenceNumber = `-- name: GetMaxSequenceNumber :one
+const maxSequenceNumber = `-- name: MaxSequenceNumber :one
SELECT COALESCE(MAX(sequence_number), 0)::integer AS max_seq
-FROM message
+FROM messages
WHERE session_id = $1
`
-// Get max sequence number for a session
-func (q *Queries) GetMaxSequenceNumber(ctx context.Context, sessionID uuid.UUID) (int32, error) {
- row := q.db.QueryRow(ctx, getMaxSequenceNumber, sessionID)
+// MaxSequenceNumber returns the max sequence number for a session (returns 0 if no messages).
+func (q *Queries) MaxSequenceNumber(ctx context.Context, sessionID uuid.UUID) (int32, error) {
+ row := q.db.QueryRow(ctx, maxSequenceNumber, sessionID)
var max_seq int32
err := row.Scan(&max_seq)
return max_seq, err
}
-const listSessions = `-- name: ListSessions :many
-SELECT id, title, created_at, updated_at, model_name, system_prompt, message_count
-FROM sessions
-ORDER BY updated_at DESC
-LIMIT $2
-OFFSET $1
-`
-
-type ListSessionsParams struct {
- ResultOffset int32 `json:"result_offset"`
- ResultLimit int32 `json:"result_limit"`
-}
-
-func (q *Queries) ListSessions(ctx context.Context, arg ListSessionsParams) ([]Session, error) {
- rows, err := q.db.Query(ctx, listSessions, arg.ResultOffset, arg.ResultLimit)
- if err != nil {
- return nil, err
- }
- defer rows.Close()
- items := []Session{}
- for rows.Next() {
- var i Session
- if err := rows.Scan(
- &i.ID,
- &i.Title,
- &i.CreatedAt,
- &i.UpdatedAt,
- &i.ModelName,
- &i.SystemPrompt,
- &i.MessageCount,
- ); err != nil {
- return nil, err
- }
- items = append(items, i)
- }
- if err := rows.Err(); err != nil {
- return nil, err
- }
- return items, nil
-}
-
-const listSessionsWithMessages = `-- name: ListSessionsWithMessages :many
-SELECT id, title, created_at, updated_at, model_name, system_prompt, message_count
-FROM sessions
-WHERE message_count > 0 OR title IS NOT NULL
-ORDER BY updated_at DESC
-LIMIT $2
-OFFSET $1
-`
-
-type ListSessionsWithMessagesParams struct {
- ResultOffset int32 `json:"result_offset"`
- ResultLimit int32 `json:"result_limit"`
-}
-
-// Only list sessions that have messages or titles (not empty sessions)
-// This is used for sidebar to hide "New Chat" placeholder sessions
-func (q *Queries) ListSessionsWithMessages(ctx context.Context, arg ListSessionsWithMessagesParams) ([]Session, error) {
- rows, err := q.db.Query(ctx, listSessionsWithMessages, arg.ResultOffset, arg.ResultLimit)
- if err != nil {
- return nil, err
- }
- defer rows.Close()
- items := []Session{}
- for rows.Next() {
- var i Session
- if err := rows.Scan(
- &i.ID,
- &i.Title,
- &i.CreatedAt,
- &i.UpdatedAt,
- &i.ModelName,
- &i.SystemPrompt,
- &i.MessageCount,
- ); err != nil {
- return nil, err
- }
- items = append(items, i)
- }
- if err := rows.Err(); err != nil {
- return nil, err
- }
- return items, nil
-}
-
const lockSession = `-- name: LockSession :one
SELECT id FROM sessions WHERE id = $1 FOR UPDATE
`
@@ -249,8 +91,8 @@ func (q *Queries) LockSession(ctx context.Context, id uuid.UUID) (uuid.UUID, err
}
const messages = `-- name: Messages :many
-SELECT id, session_id, role, content, sequence_number, created_at, status, updated_at
-FROM message
+SELECT id, session_id, role, content, sequence_number, created_at
+FROM messages
WHERE session_id = $1
ORDER BY sequence_number ASC
LIMIT $3
@@ -280,8 +122,6 @@ func (q *Queries) Messages(ctx context.Context, arg MessagesParams) ([]Message,
&i.Content,
&i.SequenceNumber,
&i.CreatedAt,
- &i.Status,
- &i.UpdatedAt,
); err != nil {
return nil, err
}
@@ -294,7 +134,7 @@ func (q *Queries) Messages(ctx context.Context, arg MessagesParams) ([]Message,
}
const session = `-- name: Session :one
-SELECT id, title, created_at, updated_at, model_name, system_prompt, message_count
+SELECT id, title, created_at, updated_at
FROM sessions
WHERE id = $1
`
@@ -307,48 +147,46 @@ func (q *Queries) Session(ctx context.Context, id uuid.UUID) (Session, error) {
&i.Title,
&i.CreatedAt,
&i.UpdatedAt,
- &i.ModelName,
- &i.SystemPrompt,
- &i.MessageCount,
)
return i, err
}
-const updateMessageContent = `-- name: UpdateMessageContent :exec
-UPDATE message
-SET content = $2,
- status = 'completed',
- updated_at = NOW()
-WHERE id = $1
-`
-
-type UpdateMessageContentParams struct {
- ID uuid.UUID `json:"id"`
- Content []byte `json:"content"`
-}
-
-// Update message content and mark as completed
-func (q *Queries) UpdateMessageContent(ctx context.Context, arg UpdateMessageContentParams) error {
- _, err := q.db.Exec(ctx, updateMessageContent, arg.ID, arg.Content)
- return err
-}
-
-const updateMessageStatus = `-- name: UpdateMessageStatus :exec
-UPDATE message
-SET status = $2,
- updated_at = NOW()
-WHERE id = $1
+const sessions = `-- name: Sessions :many
+SELECT id, title, created_at, updated_at
+FROM sessions
+ORDER BY updated_at DESC
+LIMIT $2
+OFFSET $1
`
-type UpdateMessageStatusParams struct {
- ID uuid.UUID `json:"id"`
- Status string `json:"status"`
+type SessionsParams struct {
+ ResultOffset int32 `json:"result_offset"`
+ ResultLimit int32 `json:"result_limit"`
}
-// Update message status (streaming/completed/failed)
-func (q *Queries) UpdateMessageStatus(ctx context.Context, arg UpdateMessageStatusParams) error {
- _, err := q.db.Exec(ctx, updateMessageStatus, arg.ID, arg.Status)
- return err
+func (q *Queries) Sessions(ctx context.Context, arg SessionsParams) ([]Session, error) {
+ rows, err := q.db.Query(ctx, sessions, arg.ResultOffset, arg.ResultLimit)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ items := []Session{}
+ for rows.Next() {
+ var i Session
+ if err := rows.Scan(
+ &i.ID,
+ &i.Title,
+ &i.CreatedAt,
+ &i.UpdatedAt,
+ ); err != nil {
+ return nil, err
+ }
+ items = append(items, i)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
}
const updateSessionTitle = `-- name: UpdateSessionTitle :exec
@@ -371,17 +209,11 @@ func (q *Queries) UpdateSessionTitle(ctx context.Context, arg UpdateSessionTitle
const updateSessionUpdatedAt = `-- name: UpdateSessionUpdatedAt :exec
UPDATE sessions
-SET updated_at = NOW(),
- message_count = $1
-WHERE id = $2
+SET updated_at = NOW()
+WHERE id = $1
`
-type UpdateSessionUpdatedAtParams struct {
- MessageCount *int32 `json:"message_count"`
- SessionID uuid.UUID `json:"session_id"`
-}
-
-func (q *Queries) UpdateSessionUpdatedAt(ctx context.Context, arg UpdateSessionUpdatedAtParams) error {
- _, err := q.db.Exec(ctx, updateSessionUpdatedAt, arg.MessageCount, arg.SessionID)
+func (q *Queries) UpdateSessionUpdatedAt(ctx context.Context, sessionID uuid.UUID) error {
+ _, err := q.db.Exec(ctx, updateSessionUpdatedAt, sessionID)
return err
}
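
The regenerated session API above is deliberately small: create a session, append messages with caller-managed sequence numbers, and bump `updated_at`. A minimal usage sketch, assuming the standard sqlc-generated `New` constructor and the `AddMessageParams` field names implied by the query; the helper function and JSON payload are illustrative:

```go
package example

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/koopa0/koopa/internal/sqlc"
)

// appendUserMessage creates a session, appends one user message with the next
// sequence number, and touches updated_at so Sessions(...) sorts it first.
func appendUserMessage(ctx context.Context, pool *pgxpool.Pool) error {
	q := sqlc.New(pool) // sqlc-generated constructor; the pool satisfies its DBTX interface

	title := "demo session"
	sess, err := q.CreateSession(ctx, &title)
	if err != nil {
		return fmt.Errorf("creating session: %w", err)
	}

	maxSeq, err := q.MaxSequenceNumber(ctx, sess.ID)
	if err != nil {
		return fmt.Errorf("reading max sequence number: %w", err)
	}

	if err := q.AddMessage(ctx, sqlc.AddMessageParams{
		SessionID:      sess.ID,
		Role:           "user",
		Content:        []byte(`{"text":"hello"}`),
		SequenceNumber: maxSeq + 1,
	}); err != nil {
		return fmt.Errorf("adding message: %w", err)
	}

	return q.UpdateSessionUpdatedAt(ctx, sess.ID)
}
```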
diff --git a/internal/testutil/googleai.go b/internal/testutil/googleai.go
index 62669bb..bf1265e 100644
--- a/internal/testutil/googleai.go
+++ b/internal/testutil/googleai.go
@@ -51,9 +51,9 @@ func SetupGoogleAI(tb testing.TB) *GoogleAISetup {
ctx := context.Background()
// Find project root to get absolute path to prompts directory
- projectRoot, err := findProjectRoot()
+ projectRoot, err := FindProjectRoot()
if err != nil {
- tb.Fatalf("Failed to find project root: %v", err)
+ tb.Fatalf("finding project root: %v", err)
}
promptsDir := filepath.Join(projectRoot, "prompts")
@@ -64,19 +64,18 @@ func SetupGoogleAI(tb testing.TB) *GoogleAISetup {
// Nil check: genkit.Init returns nil on internal initialization failure
if g == nil {
- tb.Fatal("Failed to initialize Genkit: genkit.Init returned nil")
+ tb.Fatal("genkit.Init returned nil")
}
// Create embedder using config constant for maintainability
- embedder := googlegenai.GoogleAIEmbedder(g, config.DefaultEmbedderModel)
+ embedder := googlegenai.GoogleAIEmbedder(g, config.DefaultGeminiEmbedderModel)
// Nil check: GoogleAIEmbedder returns nil if model lookup fails
if embedder == nil {
- tb.Fatalf("Failed to create embedder: GoogleAIEmbedder returned nil for model %q", config.DefaultEmbedderModel)
+ tb.Fatalf("GoogleAIEmbedder returned nil for model %q", config.DefaultGeminiEmbedderModel)
}
- // Create quiet logger for tests (discard all logs)
- logger := slog.New(slog.DiscardHandler)
+ logger := DiscardLogger()
return &GoogleAISetup{
Embedder: embedder,
diff --git a/internal/testutil/logger.go b/internal/testutil/logger.go
index b1acf34..2bf5fba 100644
--- a/internal/testutil/logger.go
+++ b/internal/testutil/logger.go
@@ -4,15 +4,8 @@ import (
"log/slog"
)
-// DiscardLogger returns a slog.Logger that discards all output.
-// This is the standard library pattern for test loggers.
-//
-// Use this in tests to reduce noise. For components that use log.Logger
-// (which is a type alias for *slog.Logger), use log.NewNop() directly.
-//
-// Note: log.Logger is a type alias for *slog.Logger, so this function
-// and log.NewNop() return the same type. Prefer log.NewNop() when working
-// with the internal/log package.
+// DiscardLogger returns a *slog.Logger that discards all output.
+// Use this in tests to reduce log noise.
func DiscardLogger() *slog.Logger {
return slog.New(slog.DiscardHandler)
}
diff --git a/internal/testutil/postgres.go b/internal/testutil/postgres.go
index 30cf689..a056ec0 100644
--- a/internal/testutil/postgres.go
+++ b/internal/testutil/postgres.go
@@ -26,15 +26,14 @@ import (
// Provides:
// - Isolated PostgreSQL instance with pgvector extension
// - Connection pool for database operations
-// - Automatic cleanup via cleanup function
+// - Automatic cleanup via tb.Cleanup (no manual cleanup needed)
//
// Usage:
//
-// db, cleanup := testutil.SetupTestDB(t)
-// defer cleanup()
+// db := testutil.SetupTestDB(t)
// // Use db.Pool for database operations
type TestDBContainer struct {
- Container *postgres.PostgresContainer
+ container *postgres.PostgresContainer
Pool *pgxpool.Pool
ConnStr string
}
@@ -46,25 +45,24 @@ type TestDBContainer struct {
// - Test database schema (via migrations)
// - Connection pool ready for use
//
-// Returns:
-// - TestDBContainer: Container with connection pool
-// - cleanup function: Must be called to terminate container
+// Cleanup is registered via tb.Cleanup and runs automatically when the test ends.
//
// Example:
//
// func TestMyFeature(t *testing.T) {
-// db, cleanup := testutil.SetupTestDB(t)
-// defer cleanup()
+// db := testutil.SetupTestDB(t)
//
// // Use db.Pool for queries
// var count int
// err := db.Pool.QueryRow(ctx, "SELECT COUNT(*) FROM documents").Scan(&count)
-// require.NoError(t, err)
+// if err != nil {
+// t.Fatalf("QueryRow() unexpected error: %v", err)
+// }
// }
//
// Note: Accepts testing.TB interface to support both *testing.T (tests) and
// *testing.B (benchmarks). This allows the same setup to be used in both contexts.
-func SetupTestDB(tb testing.TB) (*TestDBContainer, func()) {
+func SetupTestDB(tb testing.TB) *TestDBContainer {
tb.Helper()
ctx := context.Background()
@@ -81,62 +79,58 @@ func SetupTestDB(tb testing.TB) (*TestDBContainer, func()) {
WithStartupTimeout(60*time.Second)),
)
if err != nil {
- tb.Fatalf("Failed to start PostgreSQL container: %v", err)
+ tb.Fatalf("starting PostgreSQL container: %v", err)
}
// Get connection string
connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
if err != nil {
- _ = pgContainer.Terminate(ctx)
- tb.Fatalf("Failed to get connection string: %v", err)
+ _ = pgContainer.Terminate(ctx) // best-effort cleanup
+ tb.Fatalf("getting connection string: %v", err)
}
// Create connection pool
pool, err := pgxpool.New(ctx, connStr)
if err != nil {
- _ = pgContainer.Terminate(ctx)
- tb.Fatalf("Failed to create connection pool: %v", err)
+ _ = pgContainer.Terminate(ctx) // best-effort cleanup
+ tb.Fatalf("creating connection pool: %v", err)
}
// Verify connection
if err := pool.Ping(ctx); err != nil {
pool.Close()
- _ = pgContainer.Terminate(ctx)
- tb.Fatalf("Failed to ping database: %v", err)
+ _ = pgContainer.Terminate(ctx) // best-effort cleanup
+ tb.Fatalf("pinging database: %v", err)
}
// Run migrations
if err := runMigrations(ctx, pool); err != nil {
pool.Close()
_ = pgContainer.Terminate(ctx)
- tb.Fatalf("Failed to run migrations: %v", err)
+ tb.Fatalf("running migrations: %v", err)
}
container := &TestDBContainer{
- Container: pgContainer,
+ container: pgContainer,
Pool: pool,
ConnStr: connStr,
}
- cleanup := func() {
- if pool != nil {
- pool.Close()
- }
- if pgContainer != nil {
- _ = pgContainer.Terminate(context.Background())
- }
- }
+ tb.Cleanup(func() {
+ pool.Close()
+ _ = pgContainer.Terminate(context.Background())
+ })
- return container, cleanup
+ return container
}
-// findProjectRoot finds the project root directory by looking for go.mod.
+// FindProjectRoot finds the project root directory by looking for go.mod.
// This allows tests to run from any subdirectory and still find migration files.
-func findProjectRoot() (string, error) {
+func FindProjectRoot() (string, error) {
// Start from the current file's directory
_, filename, _, ok := runtime.Caller(0)
if !ok {
- return "", fmt.Errorf("failed to get current file path")
+ return "", fmt.Errorf("getting current file path")
}
dir := filepath.Dir(filename)
@@ -168,9 +162,9 @@ func findProjectRoot() (string, error) {
//nolint:gocognit // Complex error handling necessary for transaction safety in test utility
func runMigrations(ctx context.Context, pool *pgxpool.Pool) error {
// Find project root to build absolute paths to migrations
- projectRoot, err := findProjectRoot()
+ projectRoot, err := FindProjectRoot()
if err != nil {
- return fmt.Errorf("failed to find project root: %w", err)
+ return fmt.Errorf("finding project root: %w", err)
}
// Read and execute migration files in order
@@ -183,7 +177,7 @@ func runMigrations(ctx context.Context, pool *pgxpool.Pool) error {
// #nosec G304 -- migration paths are hardcoded constants, not from user input
migrationSQL, readErr := os.ReadFile(migrationPath)
if readErr != nil {
- return fmt.Errorf("failed to read migration %s: %w", migrationPath, readErr)
+ return fmt.Errorf("reading migration %s: %w", migrationPath, readErr)
}
// Skip empty migration files to avoid unnecessary execution
@@ -198,7 +192,7 @@ func runMigrations(ctx context.Context, pool *pgxpool.Pool) error {
// This ensures that if a migration fails, changes are rolled back
tx, beginErr := pool.Begin(ctx)
if beginErr != nil {
- return fmt.Errorf("failed to begin transaction for migration %s: %w", migrationPath, beginErr)
+ return fmt.Errorf("beginning transaction for migration %s: %w", migrationPath, beginErr)
}
// Ensure transaction is always closed (rollback unless committed)
@@ -215,11 +209,11 @@ func runMigrations(ctx context.Context, pool *pgxpool.Pool) error {
_, execErr := tx.Exec(ctx, string(migrationSQL))
if execErr != nil {
- return fmt.Errorf("failed to execute migration %s: %w", migrationPath, execErr)
+ return fmt.Errorf("executing migration %s: %w", migrationPath, execErr)
}
if commitErr := tx.Commit(ctx); commitErr != nil {
- return fmt.Errorf("failed to commit migration %s: %w", migrationPath, commitErr)
+ return fmt.Errorf("committing migration %s: %w", migrationPath, commitErr)
}
committed = true
return nil
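
Because SetupTestDB now accepts testing.TB and registers teardown via tb.Cleanup, the same helper drops into benchmarks unchanged. A small sketch; the benchmark name and query are illustrative:

```go
package example_test

import (
	"context"
	"testing"

	"github.com/koopa0/koopa/internal/testutil"
)

func BenchmarkCountDocuments(b *testing.B) {
	db := testutil.SetupTestDB(b) // container teardown runs automatically via b.Cleanup
	ctx := context.Background()

	for b.Loop() {
		var count int
		if err := db.Pool.QueryRow(ctx, "SELECT COUNT(*) FROM documents").Scan(&count); err != nil {
			b.Fatalf("QueryRow() unexpected error: %v", err)
		}
	}
}
```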
diff --git a/internal/testutil/postgres_test.go b/internal/testutil/postgres_test.go
index 3a8f2e2..e5468ee 100644
--- a/internal/testutil/postgres_test.go
+++ b/internal/testutil/postgres_test.go
@@ -21,14 +21,13 @@ import (
// Run with: go test -tags=integration ./internal/testutil -v
func TestSetupTestDB_Integration(t *testing.T) {
// Setup test database
- dbContainer, cleanup := SetupTestDB(t)
- defer cleanup()
+ dbContainer := SetupTestDB(t)
// Verify database is accessible
ctx := context.Background()
err := dbContainer.Pool.Ping(ctx)
if err != nil {
- t.Fatalf("Failed to ping database: %v", err)
+ t.Fatalf("Pool.Ping() unexpected error: %v", err)
}
// Verify pgvector extension is installed
@@ -36,26 +35,24 @@ func TestSetupTestDB_Integration(t *testing.T) {
err = dbContainer.Pool.QueryRow(ctx,
"SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname = 'vector')").Scan(&hasExtension)
if err != nil {
- t.Fatalf("Failed to check for vector extension: %v", err)
+ t.Fatalf("QueryRow(vector extension check) unexpected error: %v", err)
}
if !hasExtension {
- t.Error("pgvector extension not installed")
+ t.Error("pgvector extension installed = false, want true")
}
// Verify all required tables exist
- tables := []string{"documents", "sessions", "session_messages"}
+ tables := []string{"documents", "sessions", "messages"}
for _, table := range tables {
var exists bool
err = dbContainer.Pool.QueryRow(ctx,
"SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name = $1)", table).Scan(&exists)
if err != nil {
- t.Fatalf("Failed to check for table %s: %v", table, err)
+ t.Fatalf("QueryRow(table %q check) unexpected error: %v", table, err)
}
if !exists {
- t.Errorf("Table %s does not exist", table)
+ t.Errorf("table %q exists = false, want true", table)
}
}
-
- t.Log("Database setup successful with all required tables")
}
diff --git a/internal/testutil/rag.go b/internal/testutil/rag.go
index 2b2c918..712d5bc 100644
--- a/internal/testutil/rag.go
+++ b/internal/testutil/rag.go
@@ -2,6 +2,8 @@ package testutil
import (
"context"
+ "os"
+ "path/filepath"
"testing"
"github.com/firebase/genkit/go/ai"
@@ -46,8 +48,7 @@ type RAGSetup struct {
// Example:
//
// func TestRAGFeature(t *testing.T) {
-// db, dbCleanup := testutil.SetupTestDB(t)
-// defer dbCleanup()
+// db := testutil.SetupTestDB(t)
//
// rag := testutil.SetupRAG(t, db.Pool)
//
@@ -64,10 +65,12 @@ type RAGSetup struct {
func SetupRAG(tb testing.TB, pool *pgxpool.Pool) *RAGSetup {
tb.Helper()
- // Check for API key (SetupGoogleAI skips test if not set)
- // We don't use the returned setup because we need to create a
- // Genkit instance with both GoogleAI and PostgreSQL plugins
- _ = SetupGoogleAI(tb)
+ // Check for API key — skip test if not set.
+ // We don't use SetupGoogleAI here because we need a Genkit instance
+ // with both GoogleAI and PostgreSQL plugins.
+ if os.Getenv("GEMINI_API_KEY") == "" {
+ tb.Skip("GEMINI_API_KEY not set - skipping test requiring embedder")
+ }
ctx := context.Background()
@@ -77,7 +80,7 @@ func SetupRAG(tb testing.TB, pool *pgxpool.Pool) *RAGSetup {
postgresql.WithDatabase("koopa_test"),
)
if err != nil {
- tb.Fatalf("Failed to create PostgresEngine: %v", err)
+ tb.Fatalf("creating PostgresEngine: %v", err)
}
// Create PostgreSQL plugin
@@ -85,30 +88,30 @@ func SetupRAG(tb testing.TB, pool *pgxpool.Pool) *RAGSetup {
// Re-initialize Genkit with both plugins
// We need to create a new Genkit instance that has both plugins
- projectRoot, err := findProjectRoot()
+ projectRoot, err := FindProjectRoot()
if err != nil {
- tb.Fatalf("Failed to find project root: %v", err)
+ tb.Fatalf("finding project root: %v", err)
}
g := genkit.Init(ctx,
genkit.WithPlugins(&googlegenai.GoogleAI{}, postgres),
- genkit.WithPromptDir(projectRoot+"/prompts"),
+ genkit.WithPromptDir(filepath.Join(projectRoot, "prompts")),
)
if g == nil {
- tb.Fatal("Failed to initialize Genkit with PostgreSQL plugin")
+ tb.Fatal("genkit.Init with PostgreSQL plugin returned nil")
}
// Create embedder
- embedder := googlegenai.GoogleAIEmbedder(g, config.DefaultEmbedderModel)
+ embedder := googlegenai.GoogleAIEmbedder(g, config.DefaultGeminiEmbedderModel)
if embedder == nil {
- tb.Fatalf("Failed to create embedder for model %q", config.DefaultEmbedderModel)
+ tb.Fatalf("GoogleAIEmbedder returned nil for model %q", config.DefaultGeminiEmbedderModel)
}
// Create DocStore and Retriever using shared config factory
cfg := rag.NewDocStoreConfig(embedder)
docStore, retriever, err := postgresql.DefineRetriever(ctx, g, postgres, cfg)
if err != nil {
- tb.Fatalf("Failed to define retriever: %v", err)
+ tb.Fatalf("defining retriever: %v", err)
}
return &RAGSetup{
diff --git a/internal/tools/README.md b/internal/tools/README.md
deleted file mode 100644
index b8a6772..0000000
--- a/internal/tools/README.md
+++ /dev/null
@@ -1,132 +0,0 @@
-# Tools Package
-
-Secure AI agent toolkit with structured error handling and zero-global-state design.
-
-[繁體中文](./README_ZH_TW.md)
-
----
-
-## Design Philosophy
-
-The Tools package is for building maintainable, testable, and LLM-friendly agent tools.
-
-### Core Principles
-
-**1. Config Struct Pattern**
-All dependencies injected via `KitConfig` structure—no global state, explicit dependencies, testable by design.
-
-**2. Functional Options Pattern**
-Optional features (logging, metrics) configured through functional options, preserving backward compatibility.
-
-**3. Structured Results**
-Standardized `Result{Status, Data, Error}` format enables consistent LLM interaction and rich error context.
-
-**4. Error Semantics: Agent vs System**
-- **Agent Errors**: Recoverable failures (file not found, permission denied) returned in `Result` → LLM can retry
-- **System Errors**: Infrastructure failures (database down, OOM) returned as Go `error` → requires human intervention
-
-**5. Zero Maintenance Tool Registry**
-Leverages `genkit.ListTools()` API—single source of truth, no manual list synchronization.
-
----
-
-## Architecture
-
-```mermaid
-graph TD
- Kit[Kit - Tool Collection<br/>Config Struct Pattern<br/>Functional Options<br/>No Global State]
-
- Kit --> FileTools[File Tools]
- Kit --> SystemTools[System Tools]
- Kit --> NetworkTools[Network Tools]
- Kit --> KnowledgeTools[Knowledge Tools]
-
- FileTools --> ReadFile[readFile]
- FileTools --> WriteFile[writeFile]
- FileTools --> ListFiles[listFiles]
- FileTools --> DeleteFile[deleteFile]
- FileTools --> GetFileInfo[getFileInfo]
-
- SystemTools --> ExecuteCommand[executeCommand]
- SystemTools --> GetEnv[getEnv]
- SystemTools --> CurrentTime[currentTime]
-
- NetworkTools --> HTTPGet[httpGet<br/>SSRF Protection]
-
- KnowledgeTools --> SearchHistory[searchHistory]
- KnowledgeTools --> SearchDocuments[searchDocuments]
- KnowledgeTools --> SearchSystemKnowledge[searchSystemKnowledge]
-```
-
-**12 tools across 4 categories**: File (5) • System (3) • Network (1) • Knowledge (3)
-
----
-
-## Design Decisions
-
-### Why Structured Result over Raw Returns?
-
-**Problem**: Inconsistent error handling across tools makes LLM interaction unpredictable.
-
-**Solution**: Standardized `Result` type with status, data, and structured errors.
-
-**Benefits**:
-- LLM can parse status semantically
-- Rich debugging context through error codes
-- Consistent interface across all tools
-- Enables programmatic error handling
-
-### Why Distinguish Agent Error from System Error?
-
-**Problem**: Treating all errors equally prevents LLM from recovering gracefully.
-
-**Solution**: Agent errors in `Result` (LLM-visible), system errors as Go `error` (human-visible).
-
-**Benefits**:
-- LLM can retry on agent errors (wrong path, missing permission)
-- System errors halt execution (database failure requires ops intervention)
-- Genkit framework handles each type appropriately
-
-### Why Zero Tool List Maintenance?
-
-**Problem**: Manual tool lists drift out of sync with registrations, causing runtime failures.
-
-**Solution**: Use Genkit's `ListTools()` API as single source of truth.
-
-**Benefits**:
-- DRY principle—register once, enumerate anywhere
-- No sync bugs when adding/removing tools
-- Runtime tool discovery support
-
----
-
-## Security Model
-
-All tools enforce **defense-in-depth** validation:
-
-- **Path Validation**: Prevents traversal attacks, enforces allow-list directories, resolves symlinks
-- **Command Validation**: Blocks dangerous commands (`rm -rf`, `dd`, `format`, `sudo`, ...)
-- **Environment Filtering**: Blocks sensitive variables (API keys, passwords, tokens)
-- **HTTP Protection**: SSRF defense (blocks internal IPs, localhost, metadata services), size limits, timeouts
-
-**Security Philosophy**: Fail closed—deny by default, allow explicitly validated operations only.
-
----
-
-## Design Influences
-
-### Genkit Framework
-
-**Source**: [Firebase Genkit](https://github.com/firebase/genkit) - Google's AI framework for building production-ready AI agents
-
-**Design Philosophy**:
-- **Tool-centric**: AI agents invoke tools to interact with the world
-- **Registry pattern**: Central tool registry (`ListTools()`, `LookupTool()`) eliminates manual list maintenance
-- **Structured I/O**: Typed inputs/outputs with JSON schema validation
-- **Framework-managed lifecycle**: Genkit handles tool registration, discovery, and invocation
-
-**Why we adopted it**:
-- Zero boilerplate—register once, use anywhere
-- Type-safe tool definitions with schema validation
-- Built-in observability and tracing
-- Production-ready error handling
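
The removed README's agent-vs-system error split still governs the tools package. A minimal sketch of a handler following that convention, using the package's Result, Error, and Status types; readTextFile and ErrCodeNotFound are hypothetical names for illustration only:

```go
package tools

import (
	"errors"
	"fmt"
	"os"
)

// readTextFile (hypothetical) shows the convention: recoverable problems are
// agent errors returned inside Result, infrastructure failures are Go errors.
func readTextFile(path string) (Result, error) {
	data, err := os.ReadFile(path)
	switch {
	case errors.Is(err, os.ErrNotExist):
		// Agent error: the LLM sees this and can retry with a corrected path.
		// ErrCodeNotFound is illustrative; use whichever error code the package defines.
		return Result{
			Status: StatusError,
			Error:  &Error{Code: ErrCodeNotFound, Message: fmt.Sprintf("reading %s: %v", path, err)},
		}, nil
	case err != nil:
		// System error: halts execution and is surfaced to operators, not retried by the model.
		return Result{}, fmt.Errorf("reading %s: %w", path, err)
	}
	return Result{Status: StatusSuccess, Data: map[string]any{"content": string(data)}}, nil
}
```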
diff --git a/internal/tools/README_ZH_TW.md b/internal/tools/README_ZH_TW.md
deleted file mode 100644
index 63cd50c..0000000
--- a/internal/tools/README_ZH_TW.md
+++ /dev/null
@@ -1,136 +0,0 @@
-# Tools 套件
-
-具備結構化錯誤處理和零全局狀態設計的安全 AI 代理工具包。
-
-[English](./README.md)
-
----
-
-## 設計理念
-
-Tools 用於構建可維護、可測試、LLM 友好的代理工具。
-
-### 核心原則
-
-**1. Config Struct Pattern**
-所有依賴通過 `KitConfig` 結構注入—無全局狀態、明確依賴、設計即可測試。
-
-**2. Functional Options Pattern**
-可選功能(日誌、指標)通過函數選項配置,保持向後兼容性。
-
-**3. 結構化結果**
-標準化的 `Result{Status, Data, Error}` 格式實現一致的 LLM 互動和豐富的錯誤上下文。
-
-**4. 錯誤語意:Agent vs System**
-
-- **Agent Errors**:可恢復故障(檔案不存在、權限拒絕)在 `Result` 中返回 → LLM 可重試
-- **System Errors**:基礎設施故障(資料庫停機、記憶體不足)作為 Go `error` 返回 → 需人工介入
-
-**5. 零維護工具註冊表**
-利用 `genkit.ListTools()` API—單一事實來源,無需手動列表同步。
-
----
-
-## 架構
-
-```mermaid
-graph TD
- Kit[Kit - 工具集合<br/>Config Struct Pattern<br/>Functional Options<br/>無全局狀態]
-
- Kit --> FileTools[檔案工具]
- Kit --> SystemTools[系統工具]
- Kit --> NetworkTools[網路工具]
- Kit --> KnowledgeTools[知識工具]
-
- FileTools --> ReadFile[readFile]
- FileTools --> WriteFile[writeFile]
- FileTools --> ListFiles[listFiles]
- FileTools --> DeleteFile[deleteFile]
- FileTools --> GetFileInfo[getFileInfo]
-
- SystemTools --> ExecuteCommand[executeCommand]
- SystemTools --> GetEnv[getEnv]
- SystemTools --> CurrentTime[currentTime]
-
- NetworkTools --> HTTPGet[httpGet<br/>SSRF 防護]
-
- KnowledgeTools --> SearchHistory[searchHistory]
- KnowledgeTools --> SearchDocuments[searchDocuments]
- KnowledgeTools --> SearchSystemKnowledge[searchSystemKnowledge]
-```
-
-**4 大類別 12 個工具**:檔案(5)• 系統(3)• 網路(1)• 知識(3)
-
----
-
-### 為什麼用結構化 Result 而非原始返回值?
-
-**問題**:工具間不一致的錯誤處理使 LLM 互動不可預測。
-
-**解決方案**:標準化的 `Result` 類型包含狀態、數據和結構化錯誤。
-
-**優點**:
-
-- LLM 可語意化解析狀態
-- 通過錯誤代碼提供豐富除錯上下文
-- 所有工具間一致的介面
-- 支援程式化錯誤處理
-
-### 為什麼區分 Agent Error 和 System Error?
-
-**問題**:平等對待所有錯誤阻止 LLM 優雅恢復。
-
-**解決方案**:Agent 錯誤放在 `Result` 中(LLM 可見),系統錯誤作為 Go `error`(人類可見)。
-
-**優點**:
-
-- LLM 可在 agent 錯誤時重試(錯誤路徑、缺少權限)
-- 系統錯誤停止執行(資料庫故障需運維介入)
-- Genkit 框架適當處理每種類型
-
-### 為什麼零工具列表維護?
-
-**問題**:手動工具列表與註冊不同步,導致執行時故障。
-
-**解決方案**:使用 Genkit 的 `ListTools()` API 作為單一事實來源。
-
-**優點**:
-
-- DRY 原則—註冊一次,隨處枚舉
-- 添加/刪除工具時無同步錯誤
-- 支援執行時工具發現
-
----
-
-## 安全模型
-
-所有工具強制執行**縱深防禦**驗證:
-
-- **路徑驗證**:防止穿越攻擊、強制允許列表目錄、解析符號連結
-- **命令驗證**:阻擋危險命令(`rm -rf`、`dd`、`format`、`sudo`、...)
-- **環境變數過濾**:阻擋敏感變數(API 金鑰、密碼、令牌)
-- **HTTP 防護**:SSRF 防禦(阻擋內部 IP、localhost、元數據服務)、大小限制、逾時
-
-**安全理念**:預設關閉—預設拒絕,僅允許明確驗證的操作。
-
----
-
-## 設計影響
-
-### Genkit 框架
-
-**來源**:[Firebase Genkit](https://github.com/firebase/genkit) - Google 的 AI 框架,用於構建生產級 AI 代理
-
-**設計理念**:
-
-- **以工具為中心**:AI 代理通過調用工具與世界互動
-- **註冊表模式**:中央工具註冊表(`ListTools()`、`LookupTool()`)消除手動列表維護
-- **結構化 I/O**:帶 JSON schema 驗證的類型化輸入/輸出
-- **框架管理生命週期**:Genkit 處理工具註冊、發現和調用
-
-**為什麼採用它**:
-
-- 零樣板代碼—註冊一次,隨處使用
-- 帶 schema 驗證的類型安全工具定義
-- 內建可觀察性和追蹤
-- 生產就緒的錯誤處理
diff --git a/internal/tools/benchmark_test.go b/internal/tools/benchmark_test.go
index 1b9e5c9..d6a4a0e 100644
--- a/internal/tools/benchmark_test.go
+++ b/internal/tools/benchmark_test.go
@@ -8,7 +8,7 @@ import (
// BenchmarkClampTopK benchmarks the clampTopK function.
func BenchmarkClampTopK(b *testing.B) {
b.ReportAllocs()
- for i := 0; i < b.N; i++ {
+ for b.Loop() {
_ = clampTopK(5, 3)
}
}
@@ -17,7 +17,7 @@ func BenchmarkClampTopK(b *testing.B) {
func BenchmarkResultConstruction(b *testing.B) {
b.Run("success", func(b *testing.B) {
b.ReportAllocs()
- for i := 0; i < b.N; i++ {
+ for b.Loop() {
_ = Result{
Status: StatusSuccess,
Data: map[string]any{"key": "value"},
@@ -27,7 +27,7 @@ func BenchmarkResultConstruction(b *testing.B) {
b.Run("error", func(b *testing.B) {
b.ReportAllocs()
- for i := 0; i < b.N; i++ {
+ for b.Loop() {
_ = Result{
Status: StatusError,
Error: &Error{Code: ErrCodeSecurity, Message: "test error"},
@@ -36,9 +36,9 @@ func BenchmarkResultConstruction(b *testing.B) {
})
}
-// BenchmarkNetworkToolsCreation benchmarks NetworkTools constructor.
-func BenchmarkNetworkToolsCreation(b *testing.B) {
- cfg := NetworkConfig{
+// BenchmarkNetworkCreation benchmarks Network constructor.
+func BenchmarkNetworkCreation(b *testing.B) {
+ cfg := NetConfig{
SearchBaseURL: "http://localhost:8080",
FetchParallelism: 2,
FetchDelay: time.Second,
@@ -48,20 +48,20 @@ func BenchmarkNetworkToolsCreation(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = NewNetworkTools(cfg, logger)
+ for b.Loop() {
+ _, _ = NewNetwork(cfg, logger)
}
}
// BenchmarkFilterURLs benchmarks URL filtering and validation.
func BenchmarkFilterURLs(b *testing.B) {
- cfg := NetworkConfig{
+ cfg := NetConfig{
SearchBaseURL: "http://localhost:8080",
FetchParallelism: 2,
FetchDelay: time.Second,
FetchTimeout: 30 * time.Second,
}
- nt, _ := NewNetworkTools(cfg, testLogger())
+ nt, _ := NewNetwork(cfg, testLogger())
urls := []string{
"https://example.com/",
@@ -74,21 +74,19 @@ func BenchmarkFilterURLs(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
- for i := 0; i < b.N; i++ {
+ for b.Loop() {
_, _ = nt.filterURLs(urls)
}
}
// BenchmarkExtractNonHTMLContent benchmarks non-HTML content extraction.
func BenchmarkExtractNonHTMLContent(b *testing.B) {
- nt := &NetworkTools{}
-
b.Run("json", func(b *testing.B) {
body := []byte(`{"key": "value", "nested": {"a": 1, "b": 2}}`)
b.ReportAllocs()
b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = nt.extractNonHTMLContent(body, "application/json")
+ for b.Loop() {
+ _, _ = extractNonHTMLContent(body, "application/json")
}
})
@@ -96,8 +94,8 @@ func BenchmarkExtractNonHTMLContent(b *testing.B) {
body := []byte("Plain text content here")
b.ReportAllocs()
b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = nt.extractNonHTMLContent(body, "text/plain")
+ for b.Loop() {
+ _, _ = extractNonHTMLContent(body, "text/plain")
}
})
@@ -112,8 +110,8 @@ func BenchmarkExtractNonHTMLContent(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _, _ = nt.extractNonHTMLContent(body, "application/json")
+ for b.Loop() {
+ _, _ = extractNonHTMLContent(body, "application/json")
}
})
}
@@ -133,7 +131,7 @@ func BenchmarkFetchState(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
- for i := 0; i < b.N; i++ {
+ for b.Loop() {
state.addResult(result)
}
})
@@ -145,23 +143,18 @@ func BenchmarkFetchState(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
- for i := 0; i < b.N; i++ {
+ i := 0
+ for b.Loop() {
state.markProcessed("https://example.com/page" + string(rune(i%100)))
+ i++
}
})
}
-// BenchmarkFileToolsCreation benchmarks FileTools constructor.
-func BenchmarkFileToolsCreation(b *testing.B) {
- // Note: This benchmark requires a valid path validator
- // Skip if we can't create one
- b.Skip("requires valid path validator setup")
-}
-
// BenchmarkKnowledgeSearchInput benchmarks the unified search input struct.
func BenchmarkKnowledgeSearchInput(b *testing.B) {
b.ReportAllocs()
- for i := 0; i < b.N; i++ {
+ for b.Loop() {
_ = KnowledgeSearchInput{
Query: "test query",
TopK: 5,
diff --git a/internal/tools/doc.go b/internal/tools/doc.go
index 533ab36..b29b97b 100644
--- a/internal/tools/doc.go
+++ b/internal/tools/doc.go
@@ -10,35 +10,36 @@
//
// Tools are organized into four categories:
//
-// - FileTools: File operations (read, write, list, delete, info)
-// - SystemTools: System operations (time, command execution, environment)
-// - NetworkTools: Network operations (web search, web fetch)
-// - knowledgeTools: Knowledge base operations (semantic search)
+// - File: File operations (read, write, list, delete, info)
+// - System: System operations (time, command execution, environment)
+// - Network: Network operations (web search, web fetch)
+// - Knowledge: Knowledge base operations (semantic search)
//
// Each tool struct is created with a constructor, then registered with Genkit.
//
// # Available Tools
//
-// File tools (FileTools):
+// File tools (File):
// - read_file: Read file contents (max 10MB)
// - write_file: Write or create files
// - list_files: List directory contents
// - delete_file: Delete a file
// - get_file_info: Get file metadata
//
-// System tools (SystemTools):
+// System tools (System):
// - current_time: Get current system time
// - execute_command: Execute shell commands (whitelist enforced)
// - get_env: Read environment variables (secrets protected)
//
-// Network tools (NetworkTools):
+// Network tools (Network):
// - web_search: Search via SearXNG
// - web_fetch: Fetch web content with SSRF protection
//
-// Knowledge tools (knowledgeTools):
+// Knowledge tools (Knowledge):
// - search_history: Search conversation history
// - search_documents: Search indexed documents
// - search_system_knowledge: Search system knowledge base
+// - knowledge_store: Store knowledge documents (when DocStore is available)
//
// # Security
//
@@ -50,7 +51,9 @@
//
// # Result Type
//
-// All tools return the unified Result type:
+// File, System, and Knowledge tools return the unified Result type.
+// Network tools (Search, Fetch) use typed output structs (SearchOutput, FetchOutput)
+// with an Error string field for LLM-facing business errors.
//
// type Result struct {
// Status Status // StatusSuccess or StatusError
@@ -97,11 +100,11 @@
//
// // Create tools with security validators
// pathVal, _ := security.NewPath([]string{"/allowed/path"})
-// fileTools, err := tools.NewFileTools(pathVal, logger)
+// fileTools, err := tools.NewFile(pathVal, logger)
// if err != nil {
// return err
// }
//
// // Register with Genkit
-// fileToolList, _ := tools.RegisterFileTools(g, fileTools)
+// fileToolList, _ := tools.RegisterFile(g, fileTools)
package tools
diff --git a/internal/tools/emitter.go b/internal/tools/emitter.go
index b130b66..4dc6ca3 100644
--- a/internal/tools/emitter.go
+++ b/internal/tools/emitter.go
@@ -1,25 +1,22 @@
-// Package tools provides tool abstractions for AI agent interactions.
package tools
import (
"context"
)
-// emitterKey uses empty struct for zero-allocation context key.
-// Per Rob Pike: empty struct is idiomatic for context keys.
+// emitterKey is an unexported context key for zero-allocation type safety.
type emitterKey struct{}
-// ToolEventEmitter receives tool lifecycle events.
+// Emitter receives tool lifecycle events.
// Interface is minimal - only tool name, no UI concerns.
-// Per architecture-master: Interface for loose coupling between tools and SSE layer.
-// UI presentation logic moved to web/handlers layer.
+// UI presentation logic is handled by the SSE/API layer.
//
// Usage:
// 1. Handler creates emitter bound to SSE writer
// 2. Handler stores emitter in context via ContextWithEmitter()
// 3. Wrapped tool retrieves emitter via EmitterFromContext()
// 4. Tool calls OnToolStart/Complete/Error during execution
-type ToolEventEmitter interface {
+type Emitter interface {
// OnToolStart signals that a tool has started execution.
// name: tool name (e.g., "web_search")
// UI presentation (messages, icons) handled by web layer.
@@ -35,16 +32,14 @@ type ToolEventEmitter interface {
OnToolError(name string)
}
-// EmitterFromContext retrieves ToolEventEmitter from context.
+// EmitterFromContext retrieves Emitter from context.
// Returns nil if not set, allowing graceful degradation (no events emitted).
-// Per architecture-master: Non-streaming code paths won't have emitter set.
-func EmitterFromContext(ctx context.Context) ToolEventEmitter {
- emitter, _ := ctx.Value(emitterKey{}).(ToolEventEmitter)
+func EmitterFromContext(ctx context.Context) Emitter {
+ emitter, _ := ctx.Value(emitterKey{}).(Emitter)
return emitter
}
-// ContextWithEmitter stores ToolEventEmitter in context.
-// Per architecture-master: Per-request binding via context.Context.
-func ContextWithEmitter(ctx context.Context, emitter ToolEventEmitter) context.Context {
+// ContextWithEmitter stores Emitter in context for per-request binding.
+func ContextWithEmitter(ctx context.Context, emitter Emitter) context.Context {
return context.WithValue(ctx, emitterKey{}, emitter)
}
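
To make the renamed interface concrete, here is a sketch of a minimal Emitter implementation plus the per-request wiring described in the usage comment above; logEmitter is illustrative (a real implementation would write SSE events):

```go
package tools_test

import (
	"context"
	"log/slog"

	"github.com/koopa0/koopa/internal/tools"
)

// logEmitter is a hypothetical Emitter that only logs lifecycle events.
type logEmitter struct{ logger *slog.Logger }

func (e *logEmitter) OnToolStart(name string)    { e.logger.Info("tool start", "tool", name) }
func (e *logEmitter) OnToolComplete(name string) { e.logger.Info("tool complete", "tool", name) }
func (e *logEmitter) OnToolError(name string)    { e.logger.Info("tool error", "tool", name) }

// Compile-time check that logEmitter satisfies the interface.
var _ tools.Emitter = (*logEmitter)(nil)

func Example_emitterWiring() {
	// 1-2. The handler creates the emitter and binds it to the request context.
	ctx := tools.ContextWithEmitter(context.Background(), &logEmitter{logger: slog.Default()})

	// 3-4. Inside a wrapped tool, the emitter is retrieved and notified.
	if em := tools.EmitterFromContext(ctx); em != nil { // nil when the code path is non-streaming
		em.OnToolStart("web_search")
		em.OnToolComplete("web_search")
	}
}
```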
diff --git a/internal/tools/emitter_test.go b/internal/tools/emitter_test.go
index f4761e1..1a318f7 100644
--- a/internal/tools/emitter_test.go
+++ b/internal/tools/emitter_test.go
@@ -7,7 +7,7 @@ import (
"github.com/koopa0/koopa/internal/tools"
)
-// mockEmitter is a test implementation of ToolEventEmitter.
+// mockEmitter is a test implementation of Emitter.
// Interface simplified to only tool name parameter.
type mockEmitter struct {
startCalls []string
@@ -27,8 +27,8 @@ func (m *mockEmitter) OnToolError(name string) {
m.errorCalls = append(m.errorCalls, name)
}
-// Verify mockEmitter implements ToolEventEmitter.
-var _ tools.ToolEventEmitter = (*mockEmitter)(nil)
+// Verify mockEmitter implements Emitter.
+var _ tools.Emitter = (*mockEmitter)(nil)
func TestContextWithEmitter(t *testing.T) {
t.Parallel()
@@ -43,7 +43,7 @@ func TestContextWithEmitter(t *testing.T) {
retrieved := tools.EmitterFromContext(ctxWithEmitter)
if retrieved == nil {
- t.Fatal("expected emitter to be retrieved from context")
+ t.Fatal("EmitterFromContext() = nil, want non-nil")
}
// Compare via interface method behavior instead of pointer equality
retrieved.OnToolStart("test")
@@ -65,7 +65,7 @@ func TestContextWithEmitter(t *testing.T) {
retrieved.OnToolStart("test")
// emitter2 should receive the call, not emitter1
if len(emitter2.startCalls) != 1 {
- t.Error("expected second emitter to overwrite first")
+ t.Error("ContextWithEmitter() did not overwrite previous emitter")
}
if len(emitter1.startCalls) != 0 {
t.Error("first emitter should not receive calls")
@@ -83,7 +83,7 @@ func TestEmitterFromContext(t *testing.T) {
emitter := tools.EmitterFromContext(ctx)
if emitter != nil {
- t.Error("expected nil emitter from empty context")
+ t.Error("EmitterFromContext(empty) = non-nil, want nil")
}
})
@@ -96,7 +96,7 @@ func TestEmitterFromContext(t *testing.T) {
emitter := tools.EmitterFromContext(ctx)
if emitter != nil {
- t.Error("expected nil for missing emitter")
+ t.Error("EmitterFromContext(no-emitter) = non-nil, want nil")
}
})
@@ -124,7 +124,7 @@ func TestEmitterInterface(t *testing.T) {
emitter.OnToolStart("web_search")
if len(emitter.startCalls) != 1 {
- t.Fatalf("expected 1 start call, got %d", len(emitter.startCalls))
+ t.Fatalf("OnToolStart() call count = %d, want 1", len(emitter.startCalls))
}
if emitter.startCalls[0] != "web_search" {
@@ -139,7 +139,7 @@ func TestEmitterInterface(t *testing.T) {
emitter.OnToolComplete("web_search")
if len(emitter.completeCalls) != 1 {
- t.Fatalf("expected 1 complete call, got %d", len(emitter.completeCalls))
+ t.Fatalf("OnToolComplete() call count = %d, want 1", len(emitter.completeCalls))
}
if emitter.completeCalls[0] != "web_search" {
@@ -154,7 +154,7 @@ func TestEmitterInterface(t *testing.T) {
emitter.OnToolError("web_search")
if len(emitter.errorCalls) != 1 {
- t.Fatalf("expected 1 error call, got %d", len(emitter.errorCalls))
+ t.Fatalf("OnToolError() call count = %d, want 1", len(emitter.errorCalls))
}
if emitter.errorCalls[0] != "web_search" {
diff --git a/internal/tools/events_test.go b/internal/tools/events_test.go
index 130e9e4..f66bfa3 100644
--- a/internal/tools/events_test.go
+++ b/internal/tools/events_test.go
@@ -8,7 +8,7 @@ import (
"github.com/firebase/genkit/go/ai"
)
-// mockEmitterForEvents is a test implementation of ToolEventEmitter.
+// mockEmitterForEvents is a test implementation of Emitter.
type mockEmitterForEvents struct {
startCalls []string
completeCalls []string
@@ -27,8 +27,8 @@ func (m *mockEmitterForEvents) OnToolError(name string) {
m.errorCalls = append(m.errorCalls, name)
}
-// Verify mockEmitterForEvents implements ToolEventEmitter.
-var _ ToolEventEmitter = (*mockEmitterForEvents)(nil)
+// Verify mockEmitterForEvents implements Emitter.
+var _ Emitter = (*mockEmitterForEvents)(nil)
func TestWithEvents_Success(t *testing.T) {
emitter := &mockEmitterForEvents{}
diff --git a/internal/tools/file.go b/internal/tools/file.go
index 9f7c81c..2d92d10 100644
--- a/internal/tools/file.go
+++ b/internal/tools/file.go
@@ -3,13 +3,13 @@ package tools
import (
"fmt"
"io"
+ "log/slog"
"os"
"path/filepath"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
- "github.com/koopa0/koopa/internal/log"
"github.com/koopa0/koopa/internal/security"
)
@@ -25,12 +25,18 @@ type FileEntry struct {
Type string `json:"type"` // "file" or "directory"
}
+// Tool name constants for file operations registered with Genkit.
const (
- ToolReadFile = "read_file"
- ToolWriteFile = "write_file"
- ToolListFiles = "list_files"
- ToolDeleteFile = "delete_file"
- ToolGetFileInfo = "get_file_info"
+ // ReadFileName is the Genkit tool name for reading file contents.
+ ReadFileName = "read_file"
+ // WriteFileName is the Genkit tool name for writing file contents.
+ WriteFileName = "write_file"
+ // ListFilesName is the Genkit tool name for listing directory contents.
+ ListFilesName = "list_files"
+ // DeleteFileName is the Genkit tool name for deleting files.
+ DeleteFileName = "delete_file"
+ // FileInfoName is the Genkit tool name for retrieving file metadata.
+ FileInfoName = "get_file_info"
)
// MaxReadFileSize is the maximum file size allowed for ReadFile (10 MB).
@@ -63,78 +69,78 @@ type GetFileInfoInput struct {
Path string `json:"path" jsonschema_description:"The file path to get info for"`
}
-// FileTools provides file operation handlers.
-// Use NewFileTools to create an instance, then either:
+// File provides file operation handlers.
+// Use NewFile to create an instance, then either:
// - Call methods directly (for MCP)
-// - Use RegisterFileTools to register with Genkit
-type FileTools struct {
+// - Use RegisterFile to register with Genkit
+type File struct {
pathVal *security.Path
- logger log.Logger
+ logger *slog.Logger
}
-// NewFileTools creates a FileTools instance.
-func NewFileTools(pathVal *security.Path, logger log.Logger) (*FileTools, error) {
+// NewFile creates a File instance.
+func NewFile(pathVal *security.Path, logger *slog.Logger) (*File, error) {
if pathVal == nil {
return nil, fmt.Errorf("path validator is required")
}
if logger == nil {
return nil, fmt.Errorf("logger is required")
}
- return &FileTools{pathVal: pathVal, logger: logger}, nil
+ return &File{pathVal: pathVal, logger: logger}, nil
}
-// RegisterFileTools registers all file operation tools with Genkit.
-func RegisterFileTools(g *genkit.Genkit, ft *FileTools) ([]ai.Tool, error) {
+// RegisterFile registers all file operation tools with Genkit.
+func RegisterFile(g *genkit.Genkit, ft *File) ([]ai.Tool, error) {
if g == nil {
return nil, fmt.Errorf("genkit instance is required")
}
if ft == nil {
- return nil, fmt.Errorf("FileTools is required")
+ return nil, fmt.Errorf("File is required")
}
return []ai.Tool{
- genkit.DefineTool(g, ToolReadFile,
+ genkit.DefineTool(g, ReadFileName,
"Read the complete content of a text-based file. "+
"Use this to examine source code, configuration files, logs, or documentation. "+
"Supports files up to 10MB. Binary files are not supported and will return an error. "+
"Returns: file path, content (UTF-8), size in bytes, and line count. "+
"Common errors: file not found (verify path with list_files), "+
"permission denied, file too large, binary file detected.",
- WithEvents(ToolReadFile, ft.ReadFile)),
- genkit.DefineTool(g, ToolWriteFile,
+ WithEvents(ReadFileName, ft.ReadFile)),
+ genkit.DefineTool(g, WriteFileName,
"Write or create a text-based file with the specified content. "+
"Creates parent directories automatically if they don't exist. "+
"Overwrites existing files without confirmation. "+
"Use this for: creating new files, updating configuration, saving generated content. "+
"Returns: file path, bytes written, whether file was created or updated. "+
"Common errors: permission denied, disk full, invalid path.",
- WithEvents(ToolWriteFile, ft.WriteFile)),
- genkit.DefineTool(g, ToolListFiles,
+ WithEvents(WriteFileName, ft.WriteFile)),
+ genkit.DefineTool(g, ListFilesName,
"List files and subdirectories in a directory. "+
"Returns file names, sizes, types (file/directory), and modification times. "+
"Does not recurse into subdirectories (use recursively for deep exploration). "+
"Use this to: explore project structure, find files by name, verify paths. "+
"Tip: Start from the project root and navigate down to find specific files.",
- WithEvents(ToolListFiles, ft.ListFiles)),
- genkit.DefineTool(g, ToolDeleteFile,
+ WithEvents(ListFilesName, ft.ListFiles)),
+ genkit.DefineTool(g, DeleteFileName,
"Permanently delete a file or empty directory. "+
"WARNING: This action cannot be undone. "+
"Only deletes empty directories (use with caution). "+
"Returns: confirmation of deletion with file path. "+
"Common errors: file not found, directory not empty, permission denied.",
- WithEvents(ToolDeleteFile, ft.DeleteFile)),
- genkit.DefineTool(g, ToolGetFileInfo,
+ WithEvents(DeleteFileName, ft.DeleteFile)),
+ genkit.DefineTool(g, FileInfoName,
"Get detailed metadata about a file without reading its contents. "+
"Returns: file size, modification time, permissions, and type (file/directory). "+
"Use this to: check if a file exists, verify file size before reading, "+
"determine file type without opening it. "+
"More efficient than read_file when you only need metadata.",
- WithEvents(ToolGetFileInfo, ft.GetFileInfo)),
+ WithEvents(FileInfoName, ft.GetFileInfo)),
}, nil
}
// ReadFile reads and returns the complete content of a file with security validation.
-func (f *FileTools) ReadFile(_ *ai.ToolContext, input ReadFileInput) (Result, error) {
+func (f *File) ReadFile(_ *ai.ToolContext, input ReadFileInput) (Result, error) {
f.logger.Info("ReadFile called", "path", input.Path)
safePath, err := f.pathVal.Validate(input.Path)
@@ -143,7 +149,7 @@ func (f *FileTools) ReadFile(_ *ai.ToolContext, input ReadFileInput) (Result, er
Status: StatusError,
Error: &Error{
Code: ErrCodeSecurity,
- Message: fmt.Sprintf("path validation failed: %v", err),
+ Message: fmt.Sprintf("validating path: %v", err),
},
}, nil
}
@@ -212,7 +218,7 @@ func (f *FileTools) ReadFile(_ *ai.ToolContext, input ReadFileInput) (Result, er
}
// WriteFile writes content to a file with security validation.
-func (f *FileTools) WriteFile(_ *ai.ToolContext, input WriteFileInput) (Result, error) {
+func (f *File) WriteFile(_ *ai.ToolContext, input WriteFileInput) (Result, error) {
f.logger.Info("WriteFile called", "path", input.Path)
safePath, err := f.pathVal.Validate(input.Path)
@@ -221,7 +227,7 @@ func (f *FileTools) WriteFile(_ *ai.ToolContext, input WriteFileInput) (Result,
Status: StatusError,
Error: &Error{
Code: ErrCodeSecurity,
- Message: fmt.Sprintf("path validation failed: %v", err),
+ Message: fmt.Sprintf("validating path: %v", err),
},
}, nil
}
@@ -270,7 +276,7 @@ func (f *FileTools) WriteFile(_ *ai.ToolContext, input WriteFileInput) (Result,
}
// ListFiles lists files in a directory.
-func (f *FileTools) ListFiles(_ *ai.ToolContext, input ListFilesInput) (Result, error) {
+func (f *File) ListFiles(_ *ai.ToolContext, input ListFilesInput) (Result, error) {
f.logger.Info("ListFiles called", "path", input.Path)
safePath, err := f.pathVal.Validate(input.Path)
@@ -279,7 +285,7 @@ func (f *FileTools) ListFiles(_ *ai.ToolContext, input ListFilesInput) (Result,
Status: StatusError,
Error: &Error{
Code: ErrCodeSecurity,
- Message: fmt.Sprintf("path validation failed: %v", err),
+ Message: fmt.Sprintf("validating path: %v", err),
},
}, nil
}
@@ -318,7 +324,7 @@ func (f *FileTools) ListFiles(_ *ai.ToolContext, input ListFilesInput) (Result,
}
// DeleteFile permanently deletes a file with security validation.
-func (f *FileTools) DeleteFile(_ *ai.ToolContext, input DeleteFileInput) (Result, error) {
+func (f *File) DeleteFile(_ *ai.ToolContext, input DeleteFileInput) (Result, error) {
f.logger.Info("DeleteFile called", "path", input.Path)
safePath, err := f.pathVal.Validate(input.Path)
@@ -327,7 +333,7 @@ func (f *FileTools) DeleteFile(_ *ai.ToolContext, input DeleteFileInput) (Result
Status: StatusError,
Error: &Error{
Code: ErrCodeSecurity,
- Message: fmt.Sprintf("path validation failed: %v", err),
+ Message: fmt.Sprintf("validating path: %v", err),
},
}, nil
}
@@ -351,7 +357,7 @@ func (f *FileTools) DeleteFile(_ *ai.ToolContext, input DeleteFileInput) (Result
}
// GetFileInfo gets file metadata.
-func (f *FileTools) GetFileInfo(_ *ai.ToolContext, input GetFileInfoInput) (Result, error) {
+func (f *File) GetFileInfo(_ *ai.ToolContext, input GetFileInfoInput) (Result, error) {
f.logger.Info("GetFileInfo called", "path", input.Path)
safePath, err := f.pathVal.Validate(input.Path)
@@ -360,7 +366,7 @@ func (f *FileTools) GetFileInfo(_ *ai.ToolContext, input GetFileInfoInput) (Resu
Status: StatusError,
Error: &Error{
Code: ErrCodeSecurity,
- Message: fmt.Sprintf("path validation failed: %v", err),
+ Message: fmt.Sprintf("validating path: %v", err),
},
}, nil
}
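
End to end, the renamed constructors compose as below. A minimal sketch: the allowed workspace directory and function name are illustrative, and genkit.Init is shown without plugins purely to keep the example short:

```go
package example

import (
	"context"
	"fmt"
	"log/slog"

	"github.com/firebase/genkit/go/genkit"

	"github.com/koopa0/koopa/internal/security"
	"github.com/koopa0/koopa/internal/tools"
)

// registerFileTools wires up the File tool set with a path validator and
// registers it with a Genkit instance.
func registerFileTools(ctx context.Context, logger *slog.Logger) error {
	pathVal, err := security.NewPath([]string{"/srv/koopa/workspace"})
	if err != nil {
		return fmt.Errorf("creating path validator: %w", err)
	}

	ft, err := tools.NewFile(pathVal, logger)
	if err != nil {
		return fmt.Errorf("creating file tools: %w", err)
	}

	g := genkit.Init(ctx)
	if g == nil {
		return fmt.Errorf("genkit.Init returned nil")
	}

	fileTools, err := tools.RegisterFile(g, ft)
	if err != nil {
		return fmt.Errorf("registering file tools: %w", err)
	}
	logger.Info("registered file tools", "count", len(fileTools))
	return nil
}
```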
diff --git a/internal/tools/file_integration_test.go b/internal/tools/file_integration_test.go
index 34f67a1..ee0b204 100644
--- a/internal/tools/file_integration_test.go
+++ b/internal/tools/file_integration_test.go
@@ -9,7 +9,7 @@ import (
"github.com/koopa0/koopa/internal/security"
)
-// fileTools provides test utilities for FileTools.
+// fileTools provides test utilities for File.
type fileTools struct {
t *testing.T
tempDir string
@@ -21,20 +21,20 @@ func newfileTools(t *testing.T) *fileTools {
// Resolve symlinks (macOS /var -> /private/var)
realTempDir, err := filepath.EvalSymlinks(tempDir)
if err != nil {
- t.Fatalf("failed to resolve temp dir symlinks: %v", err)
+ t.Fatalf("resolving temp dir symlinks: %v", err)
}
return &fileTools{t: t, tempDir: realTempDir}
}
-func (h *fileTools) createFileTools() *FileTools {
+func (h *fileTools) createFile() *File {
h.t.Helper()
pathVal, err := security.NewPath([]string{h.tempDir})
if err != nil {
- h.t.Fatalf("failed to create path validator: %v", err)
+ h.t.Fatalf("creating path validator: %v", err)
}
- ft, err := NewFileTools(pathVal, testLogger())
+ ft, err := NewFile(pathVal, testLogger())
if err != nil {
- h.t.Fatalf("failed to create file tools: %v", err)
+ h.t.Fatalf("creating file tools: %v", err)
}
return ft
}
@@ -44,16 +44,12 @@ func (h *fileTools) createTestFile(name, content string) string {
path := filepath.Join(h.tempDir, name)
err := os.WriteFile(path, []byte(content), 0o600)
if err != nil {
- h.t.Fatalf("failed to create test file: %v", err)
+ h.t.Fatalf("creating test file: %v", err)
}
return path
}
-// ============================================================================
-// ReadFile Integration Tests
-// ============================================================================
-
-func TestFileTools_ReadFile_PathSecurity(t *testing.T) {
+func TestFile_ReadFile_PathSecurity(t *testing.T) {
t.Parallel()
tests := []struct {
@@ -87,11 +83,11 @@ func TestFileTools_ReadFile_PathSecurity(t *testing.T) {
t.Parallel()
h := newfileTools(t)
- ft := h.createFileTools()
+ ft := h.createFile()
result, err := ft.ReadFile(nil, ReadFileInput{Path: tt.path})
- // FileTools returns business errors in Result, not Go errors
+ // File returns business errors in Result, not Go errors
if err != nil {
t.Fatalf("ReadFile(%q) unexpected Go error: %v (should not return Go error)", tt.path, err)
}
@@ -104,18 +100,18 @@ func TestFileTools_ReadFile_PathSecurity(t *testing.T) {
if got, want := result.Error.Code, tt.wantErrCode; got != want {
t.Errorf("ReadFile(%q).Error.Code = %v, want %v", tt.path, got, want)
}
- if !strings.Contains(result.Error.Message, "path validation failed") {
- t.Errorf("ReadFile(%q).Error.Message = %q, want contains %q", tt.path, result.Error.Message, "path validation failed")
+ if !strings.Contains(result.Error.Message, "validating path") {
+ t.Errorf("ReadFile(%q).Error.Message = %q, want contains %q", tt.path, result.Error.Message, "validating path")
}
})
}
}
-func TestFileTools_ReadFile_Success(t *testing.T) {
+func TestFile_ReadFile_Success(t *testing.T) {
t.Parallel()
h := newfileTools(t)
- ft := h.createFileTools()
+ ft := h.createFile()
// Create a test file
testContent := "Hello, World!"
@@ -143,11 +139,11 @@ func TestFileTools_ReadFile_Success(t *testing.T) {
}
}
-func TestFileTools_ReadFile_NotFound(t *testing.T) {
+func TestFile_ReadFile_NotFound(t *testing.T) {
t.Parallel()
h := newfileTools(t)
- ft := h.createFileTools()
+ ft := h.createFile()
// Try to read non-existent file within allowed directory
nonExistentPath := filepath.Join(h.tempDir, "does-not-exist.txt")
@@ -168,11 +164,11 @@ func TestFileTools_ReadFile_NotFound(t *testing.T) {
}
}
-func TestFileTools_ReadFile_FileTooLarge(t *testing.T) {
+func TestFile_ReadFile_FileTooLarge(t *testing.T) {
t.Parallel()
h := newfileTools(t)
- ft := h.createFileTools()
+ ft := h.createFile()
// Create a file larger than MaxReadFileSize (10MB)
largePath := filepath.Join(h.tempDir, "large.txt")
@@ -209,11 +205,7 @@ func TestFileTools_ReadFile_FileTooLarge(t *testing.T) {
}
}
-// ============================================================================
-// WriteFile Integration Tests
-// ============================================================================
-
-func TestFileTools_WriteFile_PathSecurity(t *testing.T) {
+func TestFile_WriteFile_PathSecurity(t *testing.T) {
t.Parallel()
tests := []struct {
@@ -247,7 +239,7 @@ func TestFileTools_WriteFile_PathSecurity(t *testing.T) {
t.Parallel()
h := newfileTools(t)
- ft := h.createFileTools()
+ ft := h.createFile()
result, err := ft.WriteFile(nil, WriteFileInput{
Path: tt.path,
@@ -270,11 +262,11 @@ func TestFileTools_WriteFile_PathSecurity(t *testing.T) {
}
}
-func TestFileTools_WriteFile_Success(t *testing.T) {
+func TestFile_WriteFile_Success(t *testing.T) {
t.Parallel()
h := newfileTools(t)
- ft := h.createFileTools()
+ ft := h.createFile()
testPath := filepath.Join(h.tempDir, "new-file.txt")
testContent := "New content"
@@ -304,11 +296,11 @@ func TestFileTools_WriteFile_Success(t *testing.T) {
}
}
-func TestFileTools_WriteFile_CreatesDirectories(t *testing.T) {
+func TestFile_WriteFile_CreatesDirectories(t *testing.T) {
t.Parallel()
h := newfileTools(t)
- ft := h.createFileTools()
+ ft := h.createFile()
// Write to a nested path that doesn't exist
nestedPath := filepath.Join(h.tempDir, "subdir", "nested", "file.txt")
@@ -336,11 +328,7 @@ func TestFileTools_WriteFile_CreatesDirectories(t *testing.T) {
}
}
-// ============================================================================
-// DeleteFile Integration Tests
-// ============================================================================
-
-func TestFileTools_DeleteFile_PathSecurity(t *testing.T) {
+func TestFile_DeleteFile_PathSecurity(t *testing.T) {
t.Parallel()
tests := []struct {
@@ -368,7 +356,7 @@ func TestFileTools_DeleteFile_PathSecurity(t *testing.T) {
t.Parallel()
h := newfileTools(t)
- ft := h.createFileTools()
+ ft := h.createFile()
result, err := ft.DeleteFile(nil, DeleteFileInput{Path: tt.path})
@@ -388,11 +376,11 @@ func TestFileTools_DeleteFile_PathSecurity(t *testing.T) {
}
}
-func TestFileTools_DeleteFile_Success(t *testing.T) {
+func TestFile_DeleteFile_Success(t *testing.T) {
t.Parallel()
h := newfileTools(t)
- ft := h.createFileTools()
+ ft := h.createFile()
// Create a file to delete
testPath := h.createTestFile("to-delete.txt", "content")
@@ -417,11 +405,7 @@ func TestFileTools_DeleteFile_Success(t *testing.T) {
}
}
-// ============================================================================
-// ListFiles Integration Tests
-// ============================================================================
-
-func TestFileTools_ListFiles_PathSecurity(t *testing.T) {
+func TestFile_ListFiles_PathSecurity(t *testing.T) {
t.Parallel()
tests := []struct {
@@ -449,7 +433,7 @@ func TestFileTools_ListFiles_PathSecurity(t *testing.T) {
t.Parallel()
h := newfileTools(t)
- ft := h.createFileTools()
+ ft := h.createFile()
result, err := ft.ListFiles(nil, ListFilesInput{Path: tt.path})
@@ -469,11 +453,11 @@ func TestFileTools_ListFiles_PathSecurity(t *testing.T) {
}
}
-func TestFileTools_ListFiles_Success(t *testing.T) {
+func TestFile_ListFiles_Success(t *testing.T) {
t.Parallel()
h := newfileTools(t)
- ft := h.createFileTools()
+ ft := h.createFile()
// Create some test files
h.createTestFile("file1.txt", "content1")
@@ -516,15 +500,11 @@ func TestFileTools_ListFiles_Success(t *testing.T) {
}
}
-// ============================================================================
-// GetFileInfo Integration Tests
-// ============================================================================
-
-func TestFileTools_GetFileInfo_PathSecurity(t *testing.T) {
+func TestFile_GetFileInfo_PathSecurity(t *testing.T) {
t.Parallel()
h := newfileTools(t)
- ft := h.createFileTools()
+ ft := h.createFile()
result, err := ft.GetFileInfo(nil, GetFileInfoInput{Path: "/etc/passwd"})
@@ -542,11 +522,11 @@ func TestFileTools_GetFileInfo_PathSecurity(t *testing.T) {
}
}
-func TestFileTools_GetFileInfo_Success(t *testing.T) {
+func TestFile_GetFileInfo_Success(t *testing.T) {
t.Parallel()
h := newfileTools(t)
- ft := h.createFileTools()
+ ft := h.createFile()
// Create a test file
testPath := h.createTestFile("info.txt", "test content")
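The test helper above resolves the temp dir through filepath.EvalSymlinks before handing it to the path validator. A small standalone sketch of why that matters (on macOS, t.TempDir() returns a /var/... path that is really /private/var/..., so both sides of a path comparison need the resolved form); the directory names printed here are illustrative only:

```go
// Demonstrates the symlink-resolution step used by newfileTools: compare the
// raw temp dir with its EvalSymlinks form. On macOS they typically differ
// (/var/... vs /private/var/...); on Linux they are usually identical.
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

func main() {
	tempDir, err := os.MkdirTemp("", "demo")
	if err != nil {
		panic(err)
	}
	defer os.RemoveAll(tempDir)

	realTempDir, err := filepath.EvalSymlinks(tempDir)
	if err != nil {
		panic(err)
	}

	fmt.Println("raw:     ", tempDir)
	fmt.Println("resolved:", realTempDir)
}
```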
diff --git a/internal/tools/file_test.go b/internal/tools/file_test.go
index 904ed3c..33ed506 100644
--- a/internal/tools/file_test.go
+++ b/internal/tools/file_test.go
@@ -6,44 +6,44 @@ import (
"github.com/koopa0/koopa/internal/security"
)
-func TestFileTools_Constructor(t *testing.T) {
+func TestFile_Constructor(t *testing.T) {
t.Run("valid inputs", func(t *testing.T) {
pathVal, err := security.NewPath([]string{"/tmp"})
if err != nil {
- t.Fatalf("failed to create path validator: %v", err)
+ t.Fatalf("creating path validator: %v", err)
}
- ft, err := NewFileTools(pathVal, testLogger())
+ ft, err := NewFile(pathVal, testLogger())
if err != nil {
- t.Errorf("NewFileTools() error = %v, want nil", err)
+ t.Errorf("NewFile() error = %v, want nil", err)
}
if ft == nil {
- t.Error("NewFileTools() returned nil, want non-nil")
+ t.Error("NewFile() returned nil, want non-nil")
}
})
t.Run("nil path validator", func(t *testing.T) {
- ft, err := NewFileTools(nil, testLogger())
+ ft, err := NewFile(nil, testLogger())
if err == nil {
- t.Error("NewFileTools() error = nil, want error")
+ t.Error("NewFile() error = nil, want error")
}
if ft != nil {
- t.Error("NewFileTools() returned non-nil, want nil")
+ t.Error("NewFile() returned non-nil, want nil")
}
})
t.Run("nil logger", func(t *testing.T) {
pathVal, err := security.NewPath([]string{"/tmp"})
if err != nil {
- t.Fatalf("failed to create path validator: %v", err)
+ t.Fatalf("creating path validator: %v", err)
}
- ft, err := NewFileTools(pathVal, nil)
+ ft, err := NewFile(pathVal, nil)
if err == nil {
- t.Error("NewFileTools() error = nil, want error")
+ t.Error("NewFile() error = nil, want error")
}
if ft != nil {
- t.Error("NewFileTools() returned non-nil, want nil")
+ t.Error("NewFile() returned non-nil, want nil")
}
})
}
@@ -51,27 +51,27 @@ func TestFileTools_Constructor(t *testing.T) {
func TestFileToolConstants(t *testing.T) {
// Verify tool name constants are correct
expectedNames := map[string]string{
- "ToolReadFile": "read_file",
- "ToolWriteFile": "write_file",
- "ToolListFiles": "list_files",
- "ToolDeleteFile": "delete_file",
- "ToolGetFileInfo": "get_file_info",
+ "ReadFileName": "read_file",
+ "WriteFileName": "write_file",
+ "ListFilesName": "list_files",
+ "DeleteFileName": "delete_file",
+ "FileInfoName": "get_file_info",
}
- if ToolReadFile != expectedNames["ToolReadFile"] {
- t.Errorf("ToolReadFile = %q, want %q", ToolReadFile, expectedNames["ToolReadFile"])
+ if ReadFileName != expectedNames["ReadFileName"] {
+ t.Errorf("ReadFileName = %q, want %q", ReadFileName, expectedNames["ReadFileName"])
}
- if ToolWriteFile != expectedNames["ToolWriteFile"] {
- t.Errorf("ToolWriteFile = %q, want %q", ToolWriteFile, expectedNames["ToolWriteFile"])
+ if WriteFileName != expectedNames["WriteFileName"] {
+ t.Errorf("WriteFileName = %q, want %q", WriteFileName, expectedNames["WriteFileName"])
}
- if ToolListFiles != expectedNames["ToolListFiles"] {
- t.Errorf("ToolListFiles = %q, want %q", ToolListFiles, expectedNames["ToolListFiles"])
+ if ListFilesName != expectedNames["ListFilesName"] {
+ t.Errorf("ListFilesName = %q, want %q", ListFilesName, expectedNames["ListFilesName"])
}
- if ToolDeleteFile != expectedNames["ToolDeleteFile"] {
- t.Errorf("ToolDeleteFile = %q, want %q", ToolDeleteFile, expectedNames["ToolDeleteFile"])
+ if DeleteFileName != expectedNames["DeleteFileName"] {
+ t.Errorf("DeleteFileName = %q, want %q", DeleteFileName, expectedNames["DeleteFileName"])
}
- if ToolGetFileInfo != expectedNames["ToolGetFileInfo"] {
- t.Errorf("ToolGetFileInfo = %q, want %q", ToolGetFileInfo, expectedNames["ToolGetFileInfo"])
+ if FileInfoName != expectedNames["FileInfoName"] {
+ t.Errorf("FileInfoName = %q, want %q", FileInfoName, expectedNames["FileInfoName"])
}
}
@@ -82,45 +82,3 @@ func TestMaxReadFileSize(t *testing.T) {
t.Errorf("MaxReadFileSize = %d, want %d (10MB)", MaxReadFileSize, expected)
}
}
-
-func TestReadFileInput(t *testing.T) {
- // Test that ReadFileInput struct can be created
- input := ReadFileInput{Path: "/tmp/test.txt"}
- if input.Path != "/tmp/test.txt" {
- t.Errorf("ReadFileInput.Path = %q, want %q", input.Path, "/tmp/test.txt")
- }
-}
-
-func TestWriteFileInput(t *testing.T) {
- input := WriteFileInput{
- Path: "/tmp/test.txt",
- Content: "hello world",
- }
- if input.Path != "/tmp/test.txt" {
- t.Errorf("WriteFileInput.Path = %q, want %q", input.Path, "/tmp/test.txt")
- }
- if input.Content != "hello world" {
- t.Errorf("WriteFileInput.Content = %q, want %q", input.Content, "hello world")
- }
-}
-
-func TestListFilesInput(t *testing.T) {
- input := ListFilesInput{Path: "/tmp"}
- if input.Path != "/tmp" {
- t.Errorf("ListFilesInput.Path = %q, want %q", input.Path, "/tmp")
- }
-}
-
-func TestDeleteFileInput(t *testing.T) {
- input := DeleteFileInput{Path: "/tmp/test.txt"}
- if input.Path != "/tmp/test.txt" {
- t.Errorf("DeleteFileInput.Path = %q, want %q", input.Path, "/tmp/test.txt")
- }
-}
-
-func TestGetFileInfoInput(t *testing.T) {
- input := GetFileInfoInput{Path: "/tmp/test.txt"}
- if input.Path != "/tmp/test.txt" {
- t.Errorf("GetFileInfoInput.Path = %q, want %q", input.Path, "/tmp/test.txt")
- }
-}
diff --git a/internal/tools/fuzz_test.go b/internal/tools/fuzz_test.go
index 639167d..3fb4026 100644
--- a/internal/tools/fuzz_test.go
+++ b/internal/tools/fuzz_test.go
@@ -22,17 +22,17 @@ func FuzzClampTopK(f *testing.F) {
if topK <= 0 {
if result != defaultVal {
- t.Errorf("clampTopK(%d, %d) = %d, expected defaultVal %d for zero/negative topK",
+ t.Errorf("clampTopK(%d, %d) = %d, want %d for zero/negative topK",
topK, defaultVal, result, defaultVal)
}
} else if topK > 10 {
if result != 10 {
- t.Errorf("clampTopK(%d, %d) = %d, expected 10 for topK > 10",
+ t.Errorf("clampTopK(%d, %d) = %d, want 10 for topK > 10",
topK, defaultVal, result)
}
} else {
if result != topK {
- t.Errorf("clampTopK(%d, %d) = %d, expected %d for valid topK",
+ t.Errorf("clampTopK(%d, %d) = %d, want %d for valid topK",
topK, defaultVal, result, topK)
}
}
@@ -63,10 +63,6 @@ func FuzzResultConstruction(f *testing.F) {
})
}
-// =============================================================================
-// Security Fuzz Tests
-// =============================================================================
-
// FuzzPathTraversal tests path validation never panics and handles edge cases.
// The validator allows relative paths (resolved from working directory) and
// absolute paths within allowed directories.
@@ -227,7 +223,7 @@ func FuzzCommandInjection(f *testing.F) {
f.Fuzz(func(t *testing.T, cmd, args string) {
validator := security.NewCommand()
argList := strings.Fields(args)
- err := validator.ValidateCommand(cmd, argList)
+ err := validator.Validate(cmd, argList)
// If validation passes, verify it's not a dangerous command
if err == nil {
@@ -278,7 +274,7 @@ func FuzzEnvVarBypass(f *testing.F) {
f.Fuzz(func(t *testing.T, envName string) {
validator := security.NewEnv()
- err := validator.ValidateEnvAccess(envName)
+ err := validator.Validate(envName)
// If validation passes, verify it's not a sensitive variable
if err == nil {
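FuzzClampTopK (and TestClampTopK further down) pin the clamping rules: non-positive values fall back to the default, values above 10 are capped at 10, everything else passes through. A standalone sketch mirroring only that asserted behavior; the real clampTopK lives in the tools package:

```go
// clampTopK sketch, matching the behavior the fuzz and table tests assert:
// topK <= 0 -> defaultVal, topK > 10 -> 10, otherwise topK unchanged.
package main

import "fmt"

func clampTopK(topK, defaultVal int) int {
	switch {
	case topK <= 0:
		return defaultVal
	case topK > 10:
		return 10
	default:
		return topK
	}
}

func main() {
	fmt.Println(clampTopK(0, 3), clampTopK(-5, 5), clampTopK(50, 3), clampTopK(7, 3)) // 3 5 10 7
}
```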
diff --git a/internal/tools/knowledge.go b/internal/tools/knowledge.go
index 18af242..f1bce84 100644
--- a/internal/tools/knowledge.go
+++ b/internal/tools/knowledge.go
@@ -10,20 +10,25 @@ import (
"context"
"crypto/sha256"
"fmt"
+ "log/slog"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/plugins/postgresql"
- "github.com/koopa0/koopa/internal/log"
"github.com/koopa0/koopa/internal/rag"
)
+// Tool name constants for knowledge operations registered with Genkit.
const (
- ToolSearchHistory = "search_history"
- ToolSearchDocuments = "search_documents"
- ToolSearchSystemKnowledge = "search_system_knowledge"
- ToolStoreKnowledge = "knowledge_store"
+ // SearchHistoryName is the Genkit tool name for searching conversation history.
+ SearchHistoryName = "search_history"
+ // SearchDocumentsName is the Genkit tool name for searching indexed documents.
+ SearchDocumentsName = "search_documents"
+ // SearchSystemKnowledgeName is the Genkit tool name for searching system knowledge.
+ SearchSystemKnowledgeName = "search_system_knowledge"
+ // StoreKnowledgeName is the Genkit tool name for storing new knowledge documents.
+ StoreKnowledgeName = "knowledge_store"
)
// Default TopK values for knowledge searches.
@@ -47,67 +52,73 @@ type KnowledgeStoreInput struct {
Content string `json:"content" jsonschema_description:"The knowledge content to store"`
}
-// KnowledgeTools holds dependencies for knowledge operation handlers.
-type KnowledgeTools struct {
+// Knowledge holds dependencies for knowledge operation handlers.
+type Knowledge struct {
retriever ai.Retriever
docStore *postgresql.DocStore // nil disables knowledge_store tool
- logger log.Logger
+ logger *slog.Logger
}
-// NewKnowledgeTools creates a KnowledgeTools instance.
+// HasDocStore reports whether the document store is available.
+// Used by MCP server to conditionally register the knowledge_store tool.
+func (k *Knowledge) HasDocStore() bool {
+ return k.docStore != nil
+}
+
+// NewKnowledge creates a Knowledge instance.
// docStore is optional: when nil, the knowledge_store tool is not registered.
-func NewKnowledgeTools(retriever ai.Retriever, docStore *postgresql.DocStore, logger log.Logger) (*KnowledgeTools, error) {
+func NewKnowledge(retriever ai.Retriever, docStore *postgresql.DocStore, logger *slog.Logger) (*Knowledge, error) {
if retriever == nil {
return nil, fmt.Errorf("retriever is required")
}
if logger == nil {
return nil, fmt.Errorf("logger is required")
}
- return &KnowledgeTools{retriever: retriever, docStore: docStore, logger: logger}, nil
+ return &Knowledge{retriever: retriever, docStore: docStore, logger: logger}, nil
}
-// RegisterKnowledgeTools registers all knowledge search tools with Genkit.
+// RegisterKnowledge registers all knowledge search tools with Genkit.
// Tools are registered with event emission wrappers for streaming support.
-func RegisterKnowledgeTools(g *genkit.Genkit, kt *KnowledgeTools) ([]ai.Tool, error) {
+func RegisterKnowledge(g *genkit.Genkit, kt *Knowledge) ([]ai.Tool, error) {
if g == nil {
return nil, fmt.Errorf("genkit instance is required")
}
if kt == nil {
- return nil, fmt.Errorf("KnowledgeTools is required")
+ return nil, fmt.Errorf("Knowledge is required")
}
tools := []ai.Tool{
- genkit.DefineTool(g, ToolSearchHistory,
+ genkit.DefineTool(g, SearchHistoryName,
"Search conversation history using semantic similarity. "+
"Finds past exchanges that are conceptually related to the query. "+
"Returns: matched conversation turns with timestamps and similarity scores. "+
"Use this to: recall past discussions, find context from earlier conversations. "+
"Default topK: 3. Maximum topK: 10.",
- WithEvents(ToolSearchHistory, kt.SearchHistory)),
- genkit.DefineTool(g, ToolSearchDocuments,
+ WithEvents(SearchHistoryName, kt.SearchHistory)),
+ genkit.DefineTool(g, SearchDocumentsName,
"Search indexed documents (PDFs, code files, notes) using semantic similarity. "+
"Finds document sections that are conceptually related to the query. "+
"Returns: document titles, content excerpts, and similarity scores. "+
"Use this to: find relevant documentation, locate code examples, research topics. "+
"Default topK: 5. Maximum topK: 10.",
- WithEvents(ToolSearchDocuments, kt.SearchDocuments)),
- genkit.DefineTool(g, ToolSearchSystemKnowledge,
+ WithEvents(SearchDocumentsName, kt.SearchDocuments)),
+ genkit.DefineTool(g, SearchSystemKnowledgeName,
"Search system knowledge base (tool usage, commands, patterns) using semantic similarity. "+
"Finds internal system documentation and usage patterns. "+
"Returns: knowledge entries with descriptions and examples. "+
"Use this to: understand tool capabilities, find command syntax, learn system patterns. "+
"Default topK: 3. Maximum topK: 10.",
- WithEvents(ToolSearchSystemKnowledge, kt.SearchSystemKnowledge)),
+ WithEvents(SearchSystemKnowledgeName, kt.SearchSystemKnowledge)),
}
// Register knowledge_store only when DocStore is available.
if kt.docStore != nil {
- tools = append(tools, genkit.DefineTool(g, ToolStoreKnowledge,
+ tools = append(tools, genkit.DefineTool(g, StoreKnowledgeName,
"Store a knowledge entry for later retrieval via search_documents. "+
"Use this to save important information, notes, or learnings "+
"that the user wants to remember across sessions. "+
"Each entry gets a unique ID and is indexed for semantic search.",
- WithEvents(ToolStoreKnowledge, kt.StoreKnowledge)))
+ WithEvents(StoreKnowledgeName, kt.StoreKnowledge)))
}
return tools, nil
@@ -135,13 +146,17 @@ var validSourceTypes = map[string]bool{
// search performs a knowledge search with the given source type filter.
// Returns error if sourceType is not in the allowed whitelist.
-func (k *KnowledgeTools) search(ctx context.Context, query string, topK int, sourceType string) ([]*ai.Document, error) {
+func (k *Knowledge) search(ctx context.Context, query string, topK int, sourceType string) ([]*ai.Document, error) {
// Validate source type against whitelist (SQL injection prevention)
if !validSourceTypes[sourceType] {
return nil, fmt.Errorf("invalid source type: %q", sourceType)
}
- // Build WHERE clause filter for source_type (safe: sourceType is validated)
+ // Build WHERE clause filter for source_type.
+ // SECURITY: sourceType is SQL injection-safe because it's validated against
+ // a hardcoded whitelist (validSourceTypes). This filter is passed to the
+ // Genkit PostgreSQL retriever which includes it in a SQL query.
+ // DO NOT bypass the whitelist validation above.
filter := fmt.Sprintf("source_type = '%s'", sourceType)
req := &ai.RetrieverRequest{
@@ -161,7 +176,7 @@ func (k *KnowledgeTools) search(ctx context.Context, query string, topK int, sou
}
// SearchHistory searches conversation history using semantic similarity.
-func (k *KnowledgeTools) SearchHistory(ctx *ai.ToolContext, input KnowledgeSearchInput) (Result, error) {
+func (k *Knowledge) SearchHistory(ctx *ai.ToolContext, input KnowledgeSearchInput) (Result, error) {
k.logger.Info("SearchHistory called", "query", input.Query, "topK", input.TopK)
topK := clampTopK(input.TopK, DefaultHistoryTopK)
@@ -173,7 +188,7 @@ func (k *KnowledgeTools) SearchHistory(ctx *ai.ToolContext, input KnowledgeSearc
Status: StatusError,
Error: &Error{
Code: ErrCodeExecution,
- Message: fmt.Sprintf("history search failed: %v", err),
+ Message: fmt.Sprintf("searching history: %v", err),
},
}, nil
}
@@ -190,7 +205,7 @@ func (k *KnowledgeTools) SearchHistory(ctx *ai.ToolContext, input KnowledgeSearc
}
// SearchDocuments searches indexed documents using semantic similarity.
-func (k *KnowledgeTools) SearchDocuments(ctx *ai.ToolContext, input KnowledgeSearchInput) (Result, error) {
+func (k *Knowledge) SearchDocuments(ctx *ai.ToolContext, input KnowledgeSearchInput) (Result, error) {
k.logger.Info("SearchDocuments called", "query", input.Query, "topK", input.TopK)
topK := clampTopK(input.TopK, DefaultDocumentsTopK)
@@ -202,7 +217,7 @@ func (k *KnowledgeTools) SearchDocuments(ctx *ai.ToolContext, input KnowledgeSea
Status: StatusError,
Error: &Error{
Code: ErrCodeExecution,
- Message: fmt.Sprintf("document search failed: %v", err),
+ Message: fmt.Sprintf("searching documents: %v", err),
},
}, nil
}
@@ -219,7 +234,7 @@ func (k *KnowledgeTools) SearchDocuments(ctx *ai.ToolContext, input KnowledgeSea
}
// StoreKnowledge stores a new knowledge document for later retrieval.
-func (k *KnowledgeTools) StoreKnowledge(ctx *ai.ToolContext, input KnowledgeStoreInput) (Result, error) {
+func (k *Knowledge) StoreKnowledge(ctx *ai.ToolContext, input KnowledgeStoreInput) (Result, error) {
k.logger.Info("StoreKnowledge called", "title", input.Title)
if k.docStore == nil {
@@ -252,6 +267,7 @@ func (k *KnowledgeTools) StoreKnowledge(ctx *ai.ToolContext, input KnowledgeStor
}
// Generate a deterministic document ID from the title using SHA-256.
+ // Changing the title creates a new document; the old entry remains.
// Prefix "user:" namespaces user-created knowledge (vs "system:" for built-in).
docID := fmt.Sprintf("user:%x", sha256.Sum256([]byte(input.Title)))
@@ -267,7 +283,7 @@ func (k *KnowledgeTools) StoreKnowledge(ctx *ai.ToolContext, input KnowledgeStor
Status: StatusError,
Error: &Error{
Code: ErrCodeExecution,
- Message: fmt.Sprintf("failed to store knowledge: %v", err),
+ Message: fmt.Sprintf("storing knowledge: %v", err),
},
}, nil
}
@@ -283,7 +299,7 @@ func (k *KnowledgeTools) StoreKnowledge(ctx *ai.ToolContext, input KnowledgeStor
}
// SearchSystemKnowledge searches system knowledge base using semantic similarity.
-func (k *KnowledgeTools) SearchSystemKnowledge(ctx *ai.ToolContext, input KnowledgeSearchInput) (Result, error) {
+func (k *Knowledge) SearchSystemKnowledge(ctx *ai.ToolContext, input KnowledgeSearchInput) (Result, error) {
k.logger.Info("SearchSystemKnowledge called", "query", input.Query, "topK", input.TopK)
topK := clampTopK(input.TopK, DefaultSystemKnowledgeTopK)
@@ -295,7 +311,7 @@ func (k *KnowledgeTools) SearchSystemKnowledge(ctx *ai.ToolContext, input Knowle
Status: StatusError,
Error: &Error{
Code: ErrCodeExecution,
- Message: fmt.Sprintf("system knowledge search failed: %v", err),
+ Message: fmt.Sprintf("searching system knowledge: %v", err),
},
}, nil
}
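Two details of knowledge.go above are worth isolating: the whitelist check that keeps the source_type filter SQL-injection-safe, and the deterministic "user:"-prefixed document ID derived from the title. A simplified standalone sketch (whitelist entries taken from the tests below; names and types are stand-ins, not the package's exported API):

```go
// Sketch of the whitelist-gated filter and the deterministic doc ID.
package main

import (
	"crypto/sha256"
	"fmt"
)

// validSourceTypes mirrors the hardcoded whitelist exercised in knowledge_test.go.
var validSourceTypes = map[string]bool{
	"conversation": true,
	"file":         true,
	"system":       true,
}

// buildFilter refuses any source type outside the whitelist before
// interpolating it into the retriever's WHERE-clause filter string.
func buildFilter(sourceType string) (string, error) {
	if !validSourceTypes[sourceType] {
		return "", fmt.Errorf("invalid source type: %q", sourceType)
	}
	return fmt.Sprintf("source_type = '%s'", sourceType), nil
}

// docID is stable for a given title and namespaced with "user:", so changing
// the title produces a new document while the old entry remains.
func docID(title string) string {
	return fmt.Sprintf("user:%x", sha256.Sum256([]byte(title)))
}

func main() {
	f, _ := buildFilter("file")
	fmt.Println(f)
	fmt.Println(docID("Go error wrapping"))
}
```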
diff --git a/internal/tools/knowledge_test.go b/internal/tools/knowledge_test.go
index 2e0abbe..4500edf 100644
--- a/internal/tools/knowledge_test.go
+++ b/internal/tools/knowledge_test.go
@@ -2,11 +2,11 @@ package tools
import (
"context"
+ "log/slog"
"testing"
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/core/api"
- "github.com/koopa0/koopa/internal/log"
)
// mockRetriever is a minimal ai.Retriever implementation for testing.
@@ -25,12 +25,12 @@ func TestClampTopK(t *testing.T) {
defaultVal int
want int
}{
- {"zero uses default", 0, 3, 3},
- {"negative uses default", -5, 5, 5},
- {"value in range unchanged", 5, 3, 5},
- {"max boundary", 10, 3, 10},
- {"exceeds max clamped to 10", 50, 3, 10},
- {"min value", 1, 3, 1},
+ {name: "zero uses default", topK: 0, defaultVal: 3, want: 3},
+ {name: "negative uses default", topK: -5, defaultVal: 5, want: 5},
+ {name: "value in range unchanged", topK: 5, defaultVal: 3, want: 5},
+ {name: "max boundary", topK: 10, defaultVal: 3, want: 10},
+ {name: "exceeds max clamped to 10", topK: 50, defaultVal: 3, want: 10},
+ {name: "min value", topK: 1, defaultVal: 3, want: 1},
}
for _, tt := range tests {
@@ -44,30 +44,30 @@ func TestClampTopK(t *testing.T) {
}
func TestKnowledgeToolConstants(t *testing.T) {
- if ToolSearchHistory != "search_history" {
- t.Errorf("ToolSearchHistory = %q, want %q", ToolSearchHistory, "search_history")
+ if SearchHistoryName != "search_history" {
+ t.Errorf("SearchHistoryName = %q, want %q", SearchHistoryName, "search_history")
}
- if ToolSearchDocuments != "search_documents" {
- t.Errorf("ToolSearchDocuments = %q, want %q", ToolSearchDocuments, "search_documents")
+ if SearchDocumentsName != "search_documents" {
+ t.Errorf("SearchDocumentsName = %q, want %q", SearchDocumentsName, "search_documents")
}
- if ToolSearchSystemKnowledge != "search_system_knowledge" {
- t.Errorf("ToolSearchSystemKnowledge = %q, want %q", ToolSearchSystemKnowledge, "search_system_knowledge")
+ if SearchSystemKnowledgeName != "search_system_knowledge" {
+ t.Errorf("SearchSystemKnowledgeName = %q, want %q", SearchSystemKnowledgeName, "search_system_knowledge")
}
- if ToolStoreKnowledge != "knowledge_store" {
- t.Errorf("ToolStoreKnowledge = %q, want %q", ToolStoreKnowledge, "knowledge_store")
+ if StoreKnowledgeName != "knowledge_store" {
+ t.Errorf("StoreKnowledgeName = %q, want %q", StoreKnowledgeName, "knowledge_store")
}
}
-func TestNewKnowledgeTools(t *testing.T) {
+func TestNewKnowledge(t *testing.T) {
t.Run("nil retriever returns error", func(t *testing.T) {
- if _, err := NewKnowledgeTools(nil, nil, log.NewNop()); err == nil {
- t.Error("expected error for nil retriever")
+ if _, err := NewKnowledge(nil, nil, slog.New(slog.DiscardHandler)); err == nil {
+ t.Error("NewKnowledge(nil, nil, logger) error = nil, want non-nil")
}
})
t.Run("nil logger returns error", func(t *testing.T) {
- if _, err := NewKnowledgeTools(&mockRetriever{}, nil, nil); err == nil {
- t.Error("expected error for nil logger")
+ if _, err := NewKnowledge(&mockRetriever{}, nil, nil); err == nil {
+ t.Error("NewKnowledge(retriever, nil, nil) error = nil, want non-nil")
}
})
}
@@ -92,7 +92,7 @@ func TestValidSourceTypes(t *testing.T) {
validTypes := []string{"conversation", "file", "system"}
for _, st := range validTypes {
if !validSourceTypes[st] {
- t.Errorf("expected %q to be valid source type", st)
+ t.Errorf("validSourceTypes[%q] = false, want true", st)
}
}
@@ -106,7 +106,7 @@ func TestValidSourceTypes(t *testing.T) {
}
for _, st := range invalidTypes {
if validSourceTypes[st] {
- t.Errorf("expected %q to be invalid source type (SQL injection risk)", st)
+ t.Errorf("validSourceTypes[%q] = true, want false (SQL injection risk)", st)
}
}
}
diff --git a/internal/tools/network.go b/internal/tools/network.go
index 6ef3986..32f3c64 100644
--- a/internal/tools/network.go
+++ b/internal/tools/network.go
@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
+ "log/slog"
"net/http"
"net/url"
"strings"
@@ -16,13 +17,15 @@ import (
"github.com/go-shiori/go-readability"
"github.com/gocolly/colly/v2"
- "github.com/koopa0/koopa/internal/log"
"github.com/koopa0/koopa/internal/security"
)
+// Tool name constants for network operations registered with Genkit.
const (
- ToolWebSearch = "web_search"
- ToolWebFetch = "web_fetch"
+ // WebSearchName is the Genkit tool name for performing web searches.
+ WebSearchName = "web_search"
+ // WebFetchName is the Genkit tool name for fetching web page content.
+ WebFetchName = "web_fetch"
)
// Content limits.
@@ -39,11 +42,11 @@ const (
MaxRedirects = 5
)
-// NetworkTools holds dependencies for network operation handlers.
-// Use NewNetworkTools to create an instance, then either:
+// Network holds dependencies for network operation handlers.
+// Use NewNetwork to create an instance, then either:
// - Call methods directly (for MCP)
-// - Use RegisterNetworkTools to register with Genkit
-type NetworkTools struct {
+// - Use RegisterNetwork to register with Genkit
+type Network struct {
// Search configuration (SearXNG)
searchBaseURL string
searchClient *http.Client
@@ -60,19 +63,19 @@ type NetworkTools struct {
// Only settable within the tools package (unexported field).
skipSSRFCheck bool
- logger log.Logger
+ logger *slog.Logger
}
-// NetworkConfig holds configuration for network tools.
-type NetworkConfig struct {
+// NetConfig holds configuration for network tools.
+type NetConfig struct {
SearchBaseURL string
FetchParallelism int
FetchDelay time.Duration
FetchTimeout time.Duration
}
-// NewNetworkTools creates a NetworkTools instance.
-func NewNetworkTools(cfg NetworkConfig, logger log.Logger) (*NetworkTools, error) {
+// NewNetwork creates a Network instance.
+func NewNetwork(cfg NetConfig, logger *slog.Logger) (*Network, error) {
if cfg.SearchBaseURL == "" {
return nil, fmt.Errorf("search base URL is required")
}
@@ -91,33 +94,38 @@ func NewNetworkTools(cfg NetworkConfig, logger log.Logger) (*NetworkTools, error
cfg.FetchTimeout = 30 * time.Second
}
- return &NetworkTools{
- searchBaseURL: strings.TrimSuffix(cfg.SearchBaseURL, "/"),
- searchClient: &http.Client{Timeout: 30 * time.Second},
+ urlValidator := security.NewURL()
+
+ return &Network{
+ searchBaseURL: strings.TrimSuffix(cfg.SearchBaseURL, "/"),
+ searchClient: &http.Client{
+ Timeout: 30 * time.Second,
+ Transport: urlValidator.SafeTransport(),
+ },
fetchParallelism: cfg.FetchParallelism,
fetchDelay: cfg.FetchDelay,
fetchTimeout: cfg.FetchTimeout,
- urlValidator: security.NewURL(),
+ urlValidator: urlValidator,
logger: logger,
}, nil
}
-// RegisterNetworkTools registers all network operation tools with Genkit.
+// RegisterNetwork registers all network operation tools with Genkit.
// Tools are registered with event emission wrappers for streaming support.
-func RegisterNetworkTools(g *genkit.Genkit, nt *NetworkTools) ([]ai.Tool, error) {
+func RegisterNetwork(g *genkit.Genkit, nt *Network) ([]ai.Tool, error) {
if g == nil {
return nil, fmt.Errorf("genkit instance is required")
}
if nt == nil {
- return nil, fmt.Errorf("NetworkTools is required")
+ return nil, fmt.Errorf("Network is required")
}
return []ai.Tool{
- genkit.DefineTool(g, ToolWebSearch,
+ genkit.DefineTool(g, WebSearchName,
"Search the web for information. Returns relevant results with titles, URLs, and content snippets. "+
"Use this to find current information, news, or facts from the internet.",
- WithEvents(ToolWebSearch, nt.Search)),
- genkit.DefineTool(g, ToolWebFetch,
+ WithEvents(WebSearchName, nt.Search)),
+ genkit.DefineTool(g, WebFetchName,
"Fetch and extract content from one or more URLs (max 10). "+
"Supports HTML pages, JSON APIs, and plain text. "+
"For HTML: uses Readability algorithm to extract main content. "+
@@ -126,14 +134,10 @@ func RegisterNetworkTools(g *genkit.Genkit, nt *NetworkTools) ([]ai.Tool, error)
"Supports parallel fetching with rate limiting. "+
"Returns extracted content (max 50KB per URL). "+
"Note: Does not render JavaScript - for SPA pages, content may be incomplete.",
- WithEvents(ToolWebFetch, nt.Fetch)),
+ WithEvents(WebFetchName, nt.Fetch)),
}, nil
}
-// ============================================================================
-// web_search: Search the web via SearXNG
-// ============================================================================
-
// SearchInput defines the input for web_search tool.
type SearchInput struct {
// Query is the search query string. Required.
@@ -166,7 +170,7 @@ type SearchResult struct {
}
// Search performs web search via SearXNG.
-func (n *NetworkTools) Search(ctx *ai.ToolContext, input SearchInput) (SearchOutput, error) {
+func (n *Network) Search(ctx *ai.ToolContext, input SearchInput) (SearchOutput, error) {
// Validate required fields
if strings.TrimSpace(input.Query) == "" {
return SearchOutput{Error: "Query is required. Please provide a search query."}, nil
@@ -202,8 +206,7 @@ func (n *NetworkTools) Search(ctx *ai.ToolContext, input SearchInput) (SearchOut
// Execute request
resp, err := n.searchClient.Do(req)
if err != nil {
- n.logger.Error("search request failed", "error", err)
- return SearchOutput{}, fmt.Errorf("search request failed: %w", err)
+ return SearchOutput{}, fmt.Errorf("executing search request: %w", err)
}
defer func() { _ = resp.Body.Close() }()
@@ -275,10 +278,6 @@ type searxngResponse struct {
Results []searxngResult `json:"results"`
}
-// ============================================================================
-// web_fetch: Fetch and extract content from URLs
-// ============================================================================
-
// FetchInput defines the input for web_fetch tool.
type FetchInput struct {
// URLs is one or more URLs to fetch. Required, max 10.
@@ -352,7 +351,7 @@ func (s *fetchState) markProcessed(urlStr string) bool {
// Fetch retrieves and extracts content from one or more URLs.
// Includes SSRF protection to block private IPs and cloud metadata endpoints.
-func (n *NetworkTools) Fetch(ctx *ai.ToolContext, input FetchInput) (FetchOutput, error) {
+func (n *Network) Fetch(ctx *ai.ToolContext, input FetchInput) (FetchOutput, error) {
// Validate input
if len(input.URLs) == 0 {
return FetchOutput{Error: "At least one URL is required. Please provide URLs to fetch."}, nil
@@ -403,7 +402,7 @@ func (n *NetworkTools) Fetch(ctx *ai.ToolContext, input FetchInput) (FetchOutput
}
// filterURLs validates and deduplicates URLs, returning safe and failed lists.
-func (n *NetworkTools) filterURLs(urls []string) (safe []string, failed []FailedURL) {
+func (n *Network) filterURLs(urls []string) (safe []string, failed []FailedURL) {
urlSet := make(map[string]struct{})
for _, u := range urls {
if _, exists := urlSet[u]; exists {
@@ -424,7 +423,7 @@ func (n *NetworkTools) filterURLs(urls []string) (safe []string, failed []Failed
}
// createCollector creates a configured Colly collector.
-func (n *NetworkTools) createCollector() *colly.Collector {
+func (n *Network) createCollector() *colly.Collector {
c := colly.NewCollector(
colly.Async(true),
colly.MaxDepth(1),
@@ -462,7 +461,7 @@ func (n *NetworkTools) createCollector() *colly.Collector {
}
// setupCallbacks configures all Colly callbacks.
-func (n *NetworkTools) setupCallbacks(c *colly.Collector, ctx *ai.ToolContext, state *fetchState, selector string) {
+func (n *Network) setupCallbacks(c *colly.Collector, ctx *ai.ToolContext, state *fetchState, selector string) {
c.OnRequest(func(r *colly.Request) {
select {
case <-ctx.Done():
@@ -472,7 +471,7 @@ func (n *NetworkTools) setupCallbacks(c *colly.Collector, ctx *ai.ToolContext, s
})
c.OnResponse(func(r *colly.Response) {
- n.handleNonHTMLResponse(r, state)
+ handleNonHTMLResponse(r, state)
})
c.OnHTML("html", func(e *colly.HTMLElement) {
@@ -485,7 +484,7 @@ func (n *NetworkTools) setupCallbacks(c *colly.Collector, ctx *ai.ToolContext, s
}
// handleNonHTMLResponse processes JSON, XML, and text responses.
-func (n *NetworkTools) handleNonHTMLResponse(r *colly.Response, state *fetchState) {
+func handleNonHTMLResponse(r *colly.Response, state *fetchState) {
contentType := r.Headers.Get("Content-Type")
if strings.Contains(contentType, "text/html") {
return
@@ -496,7 +495,7 @@ func (n *NetworkTools) handleNonHTMLResponse(r *colly.Response, state *fetchStat
return
}
- title, content := n.extractNonHTMLContent(r.Body, contentType)
+ title, content := extractNonHTMLContent(r.Body, contentType)
if len(content) > MaxContentLength {
content = content[:MaxContentLength] + "\n\n[Content truncated...]"
}
@@ -510,7 +509,7 @@ func (n *NetworkTools) handleNonHTMLResponse(r *colly.Response, state *fetchStat
}
// extractNonHTMLContent extracts content from non-HTML responses.
-func (*NetworkTools) extractNonHTMLContent(body []byte, contentType string) (title, content string) {
+func extractNonHTMLContent(body []byte, contentType string) (title, content string) {
switch {
case strings.Contains(contentType, "application/json"):
var jsonData any
@@ -532,7 +531,7 @@ func (*NetworkTools) extractNonHTMLContent(body []byte, contentType string) (tit
}
// handleHTMLResponse processes HTML responses using readability.
-func (n *NetworkTools) handleHTMLResponse(e *colly.HTMLElement, state *fetchState, selector string) {
+func (n *Network) handleHTMLResponse(e *colly.HTMLElement, state *fetchState, selector string) {
urlStr := e.Request.URL.String()
if state.markProcessed(urlStr) {
return
@@ -558,7 +557,7 @@ func (n *NetworkTools) handleHTMLResponse(e *colly.HTMLElement, state *fetchStat
}
// handleError processes fetch errors.
-func (n *NetworkTools) handleError(r *colly.Response, err error, state *fetchState) {
+func (n *Network) handleError(r *colly.Response, err error, state *fetchState) {
reason := err.Error()
statusCode := 0
if r.StatusCode > 0 {
@@ -578,7 +577,7 @@ func (n *NetworkTools) handleError(r *colly.Response, err error, state *fetchSta
// extractWithReadability extracts content using go-readability with CSS selector fallback.
// Returns (title, content).
// u should be the final URL after redirects (e.Request.URL from Colly).
-func (n *NetworkTools) extractWithReadability(u *url.URL, html string, e *colly.HTMLElement, selector string) (title, content string) {
+func (n *Network) extractWithReadability(u *url.URL, html string, e *colly.HTMLElement, selector string) (title, content string) {
// Try go-readability first (Mozilla Readability algorithm)
// u is already parsed and reflects the final URL after any redirects
article, err := readability.FromReader(bytes.NewReader([]byte(html)), u)
@@ -590,7 +589,7 @@ func (n *NetworkTools) extractWithReadability(u *url.URL, html string, e *colly.
}
// Convert HTML content to plain text (remove tags)
- content = n.htmlToText(article.Content)
+ content = htmlToText(article.Content)
if content != "" {
return title, content
}
@@ -598,12 +597,12 @@ func (n *NetworkTools) extractWithReadability(u *url.URL, html string, e *colly.
// Readability failed or returned empty - fallback to CSS selector
n.logger.Debug("readability fallback to selector", "url", u.String())
- return n.extractWithSelector(e, selector)
+ return extractWithSelector(e, selector)
}
// extractWithSelector extracts content using CSS selectors (fallback method).
// Returns (title, content).
-func (*NetworkTools) extractWithSelector(e *colly.HTMLElement, selector string) (extractedTitle, extractedContent string) {
+func extractWithSelector(e *colly.HTMLElement, selector string) (extractedTitle, extractedContent string) {
// Extract title
extractedTitle = e.ChildText("title")
if extractedTitle == "" {
@@ -631,7 +630,7 @@ func (*NetworkTools) extractWithSelector(e *colly.HTMLElement, selector string)
}
// htmlToText converts HTML content to plain text.
-func (*NetworkTools) htmlToText(html string) string {
+func htmlToText(html string) string {
doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
if err != nil {
return ""
diff --git a/internal/tools/network_integration_test.go b/internal/tools/network_integration_test.go
index 642c54f..3010d0e 100644
--- a/internal/tools/network_integration_test.go
+++ b/internal/tools/network_integration_test.go
@@ -12,7 +12,7 @@ import (
"github.com/firebase/genkit/go/ai"
)
-// networkTools provides test utilities for NetworkTools.
+// networkTools provides test utilities for Network.
type networkTools struct {
t *testing.T
}
@@ -22,17 +22,17 @@ func newnetworkTools(t *testing.T) *networkTools {
return &networkTools{t: t}
}
-func (h *networkTools) createNetworkTools(serverURL string) *NetworkTools {
+func (h *networkTools) createNetwork(serverURL string) *Network {
h.t.Helper()
- cfg := NetworkConfig{
+ cfg := NetConfig{
SearchBaseURL: serverURL,
FetchParallelism: 2,
FetchDelay: 10 * time.Millisecond,
FetchTimeout: 5 * time.Second,
}
- nt, err := NewNetworkTools(cfg, testLogger())
+ nt, err := NewNetwork(cfg, testLogger())
if err != nil {
- h.t.Fatalf("failed to create network tools: %v", err)
+ h.t.Fatalf("creating network tools: %v", err)
}
return nt
}
@@ -61,11 +61,7 @@ func (*networkTools) toolContext() *ai.ToolContext {
return &ai.ToolContext{Context: context.Background()}
}
-// ============================================================================
-// SSRF Protection Tests - Blocked Hosts
-// ============================================================================
-
-func TestNetworkTools_Fetch_SSRFBlockedHosts(t *testing.T) {
+func TestNetwork_Fetch_SSRFBlockedHosts(t *testing.T) {
t.Parallel()
tests := []struct {
@@ -165,7 +161,7 @@ func TestNetworkTools_Fetch_SSRFBlockedHosts(t *testing.T) {
h := newnetworkTools(t)
server := h.createMockServer()
- nt := h.createNetworkTools(server.URL)
+ nt := h.createNetwork(server.URL)
ctx := h.toolContext()
output, err := nt.Fetch(ctx, FetchInput{URLs: []string{tt.url}})
@@ -192,11 +188,7 @@ func TestNetworkTools_Fetch_SSRFBlockedHosts(t *testing.T) {
}
}
-// ============================================================================
-// SSRF Protection Tests - Scheme Validation
-// ============================================================================
-
-func TestNetworkTools_Fetch_SchemeValidation(t *testing.T) {
+func TestNetwork_Fetch_SchemeValidation(t *testing.T) {
t.Parallel()
tests := []struct {
@@ -237,7 +229,7 @@ func TestNetworkTools_Fetch_SchemeValidation(t *testing.T) {
h := newnetworkTools(t)
server := h.createMockServer()
- nt := h.createNetworkTools(server.URL)
+ nt := h.createNetwork(server.URL)
ctx := h.toolContext()
output, err := nt.Fetch(ctx, FetchInput{URLs: []string{tt.url}})
@@ -261,24 +253,20 @@ func TestNetworkTools_Fetch_SchemeValidation(t *testing.T) {
}
}
-// ============================================================================
-// SSRF Protection Tests - Mixed URLs
-// ============================================================================
-
-func TestNetworkTools_Fetch_MixedURLsFiltered(t *testing.T) {
+func TestNetwork_Fetch_MixedURLsFiltered(t *testing.T) {
t.Parallel()
h := newnetworkTools(t)
server := h.createMockServer()
- // Create NetworkTools with testing mode (SSRF protection enabled but using mock server)
- cfg := NetworkConfig{
+ // Create Network with testing mode (SSRF protection enabled but using mock server)
+ cfg := NetConfig{
SearchBaseURL: server.URL,
FetchParallelism: 2,
FetchDelay: 10 * time.Millisecond,
FetchTimeout: 5 * time.Second,
}
- nt := newNetworkToolsForTesting(t, cfg, testLogger())
+ nt := newNetworkForTesting(t, cfg, testLogger())
ctx := h.toolContext()
@@ -290,24 +278,13 @@ func TestNetworkTools_Fetch_MixedURLsFiltered(t *testing.T) {
"http://169.254.169.254/", // Cloud metadata - blocked
}
- output, err := nt.Fetch(ctx, FetchInput{URLs: urls})
-
+ _, err := nt.Fetch(ctx, FetchInput{URLs: urls})
if err != nil {
t.Fatalf("Fetch(mixed URLs) unexpected error: %v", err)
}
-
- // The mock server URL should succeed in testing mode
- // Private IPs should fail
- // Note: skipSSRFCheck affects URL validation, so in testing mode
- // even private URLs might pass. Let's verify the test setup.
- t.Logf("Results: %d, Failed: %d", len(output.Results), len(output.FailedURLs))
}
-// ============================================================================
-// SSRF Protection Tests - Redirect Protection
-// ============================================================================
-
-func TestNetworkTools_Fetch_RedirectSSRFProtection(t *testing.T) {
+func TestNetwork_Fetch_RedirectSSRFProtection(t *testing.T) {
t.Parallel()
h := newnetworkTools(t)
@@ -331,7 +308,7 @@ func TestNetworkTools_Fetch_RedirectSSRFProtection(t *testing.T) {
}))
t.Cleanup(func() { redirectServer.Close() })
- nt := h.createNetworkTools(redirectServer.URL)
+ nt := h.createNetwork(redirectServer.URL)
ctx := h.toolContext()
tests := []struct {
@@ -393,16 +370,12 @@ func TestNetworkTools_Fetch_RedirectSSRFProtection(t *testing.T) {
}
}
-// ============================================================================
-// Input Validation Tests
-// ============================================================================
-
-func TestNetworkTools_Fetch_InputValidation(t *testing.T) {
+func TestNetwork_Fetch_InputValidation(t *testing.T) {
t.Parallel()
h := newnetworkTools(t)
server := h.createMockServer()
- nt := h.createNetworkTools(server.URL)
+ nt := h.createNetwork(server.URL)
ctx := h.toolContext()
t.Run("empty URL list", func(t *testing.T) {
@@ -467,16 +440,12 @@ func TestNetworkTools_Fetch_InputValidation(t *testing.T) {
})
}
-// ============================================================================
-// Search Input Validation Tests
-// ============================================================================
-
-func TestNetworkTools_Search_InputValidation(t *testing.T) {
+func TestNetwork_Search_InputValidation(t *testing.T) {
t.Parallel()
h := newnetworkTools(t)
server := h.createMockServer()
- nt := h.createNetworkTools(server.URL)
+ nt := h.createNetwork(server.URL)
ctx := h.toolContext()
t.Run("empty query rejected", func(t *testing.T) {
@@ -514,24 +483,20 @@ func TestNetworkTools_Search_InputValidation(t *testing.T) {
})
}
-// ============================================================================
-// Public URL Success Test (using httptest)
-// ============================================================================
-
-func TestNetworkTools_Fetch_PublicURLSuccess(t *testing.T) {
+func TestNetwork_Fetch_PublicURLSuccess(t *testing.T) {
t.Parallel()
h := newnetworkTools(t)
server := h.createMockServer()
// Use ForTesting to allow httptest server (which uses localhost)
- cfg := NetworkConfig{
+ cfg := NetConfig{
SearchBaseURL: server.URL,
FetchParallelism: 2,
FetchDelay: 10 * time.Millisecond,
FetchTimeout: 5 * time.Second,
}
- nt := newNetworkToolsForTesting(t, cfg, testLogger())
+ nt := newNetworkForTesting(t, cfg, testLogger())
ctx := h.toolContext()
@@ -558,11 +523,7 @@ func TestNetworkTools_Fetch_PublicURLSuccess(t *testing.T) {
}
}
-// ============================================================================
-// Concurrent Fetch Test
-// ============================================================================
-
-func TestNetworkTools_Fetch_Concurrent(t *testing.T) {
+func TestNetwork_Fetch_Concurrent(t *testing.T) {
t.Parallel()
h := newnetworkTools(t)
@@ -576,13 +537,13 @@ func TestNetworkTools_Fetch_Concurrent(t *testing.T) {
}))
t.Cleanup(func() { server.Close() })
- cfg := NetworkConfig{
+ cfg := NetConfig{
SearchBaseURL: server.URL,
FetchParallelism: 5,
FetchDelay: 5 * time.Millisecond,
FetchTimeout: 5 * time.Second,
}
- nt := newNetworkToolsForTesting(t, cfg, testLogger())
+ nt := newNetworkForTesting(t, cfg, testLogger())
ctx := h.toolContext()
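The redirect-protection test above relies on an httptest server that bounces requests toward a private address, forcing the client to reject the redirected hop rather than the original URL. A hypothetical standalone version of that setup (the production code re-validates each redirected URL instead of refusing redirects outright):

```go
// Sketch of the redirect fixture: an httptest server that 302-redirects to the
// cloud metadata endpoint, plus a client that flags the redirect. Standalone
// illustration only; the real tests drive this through Network.Fetch.
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

func main() {
	redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Every request is bounced toward a link-local metadata address.
		http.Redirect(w, r, "http://169.254.169.254/latest/meta-data/", http.StatusFound)
	}))
	defer redirectServer.Close()

	// Bluntest possible guard for the sketch: refuse to follow any redirect.
	client := &http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return fmt.Errorf("refusing redirect to %s", req.URL)
		},
	}

	resp, err := client.Get(redirectServer.URL)
	if resp != nil {
		resp.Body.Close()
	}
	fmt.Println(err)
}
```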
diff --git a/internal/tools/network_test.go b/internal/tools/network_test.go
index d495cdc..fe29113 100644
--- a/internal/tools/network_test.go
+++ b/internal/tools/network_test.go
@@ -5,64 +5,64 @@ import (
"time"
)
-func TestNetworkTools_Constructor(t *testing.T) {
+func TestNetwork_Constructor(t *testing.T) {
t.Run("valid inputs", func(t *testing.T) {
- cfg := NetworkConfig{
+ cfg := NetConfig{
SearchBaseURL: "http://localhost:8080",
FetchParallelism: 2,
FetchDelay: time.Second,
FetchTimeout: 30 * time.Second,
}
- nt, err := NewNetworkTools(cfg, testLogger())
+ nt, err := NewNetwork(cfg, testLogger())
if err != nil {
- t.Errorf("NewNetworkTools() error = %v, want nil", err)
+ t.Errorf("NewNetwork() error = %v, want nil", err)
}
if nt == nil {
- t.Error("NewNetworkTools() returned nil, want non-nil")
+ t.Error("NewNetwork() returned nil, want non-nil")
}
})
t.Run("empty search URL", func(t *testing.T) {
- cfg := NetworkConfig{
+ cfg := NetConfig{
SearchBaseURL: "",
}
- nt, err := NewNetworkTools(cfg, testLogger())
+ nt, err := NewNetwork(cfg, testLogger())
if err == nil {
- t.Error("NewNetworkTools() error = nil, want error")
+ t.Error("NewNetwork() error = nil, want error")
}
if nt != nil {
- t.Error("NewNetworkTools() returned non-nil, want nil")
+ t.Error("NewNetwork() returned non-nil, want nil")
}
})
t.Run("nil logger", func(t *testing.T) {
- cfg := NetworkConfig{
+ cfg := NetConfig{
SearchBaseURL: "http://localhost:8080",
}
- nt, err := NewNetworkTools(cfg, nil)
+ nt, err := NewNetwork(cfg, nil)
if err == nil {
- t.Error("NewNetworkTools() error = nil, want error")
+ t.Error("NewNetwork() error = nil, want error")
}
if nt != nil {
- t.Error("NewNetworkTools() returned non-nil, want nil")
+ t.Error("NewNetwork() returned non-nil, want nil")
}
})
t.Run("defaults applied", func(t *testing.T) {
- cfg := NetworkConfig{
+ cfg := NetConfig{
SearchBaseURL: "http://localhost:8080",
// Leave other fields as zero values
}
- nt, err := NewNetworkTools(cfg, testLogger())
+ nt, err := NewNetwork(cfg, testLogger())
if err != nil {
- t.Errorf("NewNetworkTools() error = %v, want nil", err)
+ t.Errorf("NewNetwork() error = %v, want nil", err)
}
if nt == nil {
- t.Fatal("NewNetworkTools() returned nil")
+ t.Fatal("NewNetwork() returned nil")
}
// Verify defaults were applied (internal fields not accessible, but no error means success)
})
@@ -70,19 +70,19 @@ func TestNetworkTools_Constructor(t *testing.T) {
func TestNetworkToolConstants(t *testing.T) {
expectedNames := map[string]string{
- "ToolWebSearch": "web_search",
- "ToolWebFetch": "web_fetch",
+ "WebSearchName": "web_search",
+ "WebFetchName": "web_fetch",
}
- if ToolWebSearch != expectedNames["ToolWebSearch"] {
- t.Errorf("ToolWebSearch = %q, want %q", ToolWebSearch, expectedNames["ToolWebSearch"])
+ if WebSearchName != expectedNames["WebSearchName"] {
+ t.Errorf("WebSearchName = %q, want %q", WebSearchName, expectedNames["WebSearchName"])
}
- if ToolWebFetch != expectedNames["ToolWebFetch"] {
- t.Errorf("ToolWebFetch = %q, want %q", ToolWebFetch, expectedNames["ToolWebFetch"])
+ if WebFetchName != expectedNames["WebFetchName"] {
+ t.Errorf("WebFetchName = %q, want %q", WebFetchName, expectedNames["WebFetchName"])
}
}
-func TestNetworkConfigConstants(t *testing.T) {
+func TestNetConfigConstants(t *testing.T) {
// Verify content limits
if MaxURLsPerRequest != 10 {
t.Errorf("MaxURLsPerRequest = %d, want 10", MaxURLsPerRequest)
@@ -97,137 +97,3 @@ func TestNetworkConfigConstants(t *testing.T) {
t.Errorf("DefaultSearchResults = %d, want 10", DefaultSearchResults)
}
}
-
-func TestSearchInput(t *testing.T) {
- input := SearchInput{
- Query: "test query",
- Categories: []string{"general", "news"},
- Language: "en",
- MaxResults: 20,
- }
- if input.Query != "test query" {
- t.Errorf("SearchInput.Query = %q, want %q", input.Query, "test query")
- }
- if len(input.Categories) != 2 {
- t.Errorf("SearchInput.Categories length = %d, want 2", len(input.Categories))
- }
- if input.Language != "en" {
- t.Errorf("SearchInput.Language = %q, want %q", input.Language, "en")
- }
- if input.MaxResults != 20 {
- t.Errorf("SearchInput.MaxResults = %d, want 20", input.MaxResults)
- }
-}
-
-func TestFetchInput(t *testing.T) {
- input := FetchInput{
- URLs: []string{"https://example.com", "https://test.com"},
- Selector: "article",
- }
- if len(input.URLs) != 2 {
- t.Errorf("FetchInput.URLs length = %d, want 2", len(input.URLs))
- }
- if input.Selector != "article" {
- t.Errorf("FetchInput.Selector = %q, want %q", input.Selector, "article")
- }
-}
-
-func TestSearchOutput(t *testing.T) {
- output := SearchOutput{
- Results: []SearchResult{
- {
- Title: "Test",
- URL: "https://example.com",
- Content: "Test content",
- Engine: "google",
- },
- },
- Query: "test",
- }
- if len(output.Results) != 1 {
- t.Errorf("SearchOutput.Results length = %d, want 1", len(output.Results))
- }
- if output.Query != "test" {
- t.Errorf("SearchOutput.Query = %q, want %q", output.Query, "test")
- }
-}
-
-func TestFetchOutput(t *testing.T) {
- output := FetchOutput{
- Results: []FetchResult{
- {
- URL: "https://example.com",
- Title: "Example",
- Content: "Content",
- ContentType: "text/html",
- },
- },
- FailedURLs: []FailedURL{
- {
- URL: "https://failed.com",
- Reason: "connection refused",
- StatusCode: 503,
- },
- },
- }
- if len(output.Results) != 1 {
- t.Errorf("FetchOutput.Results length = %d, want 1", len(output.Results))
- }
- if len(output.FailedURLs) != 1 {
- t.Errorf("FetchOutput.FailedURLs length = %d, want 1", len(output.FailedURLs))
- }
-}
-
-func TestSearchResult(t *testing.T) {
- result := SearchResult{
- Title: "Test Title",
- URL: "https://example.com",
- Content: "Test content",
- Engine: "google",
- PublishedAt: "2024-01-01",
- }
- if result.Title != "Test Title" {
- t.Errorf("SearchResult.Title = %q, want %q", result.Title, "Test Title")
- }
- if result.URL != "https://example.com" {
- t.Errorf("SearchResult.URL = %q, want %q", result.URL, "https://example.com")
- }
- if result.Engine != "google" {
- t.Errorf("SearchResult.Engine = %q, want %q", result.Engine, "google")
- }
- if result.PublishedAt != "2024-01-01" {
- t.Errorf("SearchResult.PublishedAt = %q, want %q", result.PublishedAt, "2024-01-01")
- }
-}
-
-func TestFetchResult(t *testing.T) {
- result := FetchResult{
- URL: "https://example.com",
- Title: "Example",
- Content: "Content",
- ContentType: "text/html",
- }
- if result.URL != "https://example.com" {
- t.Errorf("FetchResult.URL = %q, want %q", result.URL, "https://example.com")
- }
- if result.ContentType != "text/html" {
- t.Errorf("FetchResult.ContentType = %q, want %q", result.ContentType, "text/html")
- }
-}
-
-func TestFailedURL(t *testing.T) {
- failed := FailedURL{
- URL: "https://failed.com",
- Reason: "timeout",
- StatusCode: 504,
- }
- if failed.URL != "https://failed.com" {
- t.Errorf("FailedURL.URL = %q, want %q", failed.URL, "https://failed.com")
- }
- if failed.Reason != "timeout" {
- t.Errorf("FailedURL.Reason = %q, want %q", failed.Reason, "timeout")
- }
- if failed.StatusCode != 504 {
- t.Errorf("FailedURL.StatusCode = %d, want 504", failed.StatusCode)
- }
-}
diff --git a/internal/tools/register_test.go b/internal/tools/register_test.go
index af4982e..abb2bb7 100644
--- a/internal/tools/register_test.go
+++ b/internal/tools/register_test.go
@@ -17,11 +17,7 @@ func setupTestGenkit(t *testing.T) *genkit.Genkit {
return genkit.Init(context.Background())
}
-// ============================================================================
-// NewFileTools Tests
-// ============================================================================
-
-func TestNewFileTools(t *testing.T) {
+func TestNewFile(t *testing.T) {
t.Parallel()
t.Run("successful creation", func(t *testing.T) {
@@ -31,27 +27,27 @@ func TestNewFileTools(t *testing.T) {
t.Fatalf("NewPath() unexpected error: %v", err)
}
- ft, err := NewFileTools(pathVal, testLogger())
+ ft, err := NewFile(pathVal, testLogger())
if err != nil {
- t.Fatalf("NewFileTools() unexpected error: %v", err)
+ t.Fatalf("NewFile() unexpected error: %v", err)
}
if ft == nil {
- t.Fatal("NewFileTools() = nil, want non-nil")
+ t.Fatal("NewFile() = nil, want non-nil")
}
})
t.Run("nil path validator", func(t *testing.T) {
t.Parallel()
- ft, err := NewFileTools(nil, testLogger())
+ ft, err := NewFile(nil, testLogger())
if err == nil {
- t.Fatal("NewFileTools(nil, logger) expected error, got nil")
+ t.Fatal("NewFile(nil, logger) expected error, got nil")
}
if ft != nil {
- t.Errorf("NewFileTools(nil, logger) = %v, want nil", ft)
+ t.Errorf("NewFile(nil, logger) = %v, want nil", ft)
}
if !strings.Contains(err.Error(), "path validator is required") {
- t.Errorf("NewFileTools(nil, logger) error = %q, want contains %q", err.Error(), "path validator is required")
+ t.Errorf("NewFile(nil, logger) error = %q, want contains %q", err.Error(), "path validator is required")
}
})
@@ -62,24 +58,20 @@ func TestNewFileTools(t *testing.T) {
t.Fatalf("NewPath() unexpected error: %v", err)
}
- ft, err := NewFileTools(pathVal, nil)
+ ft, err := NewFile(pathVal, nil)
if err == nil {
- t.Fatal("NewFileTools(pathVal, nil) expected error, got nil")
+ t.Fatal("NewFile(pathVal, nil) expected error, got nil")
}
if ft != nil {
- t.Errorf("NewFileTools(pathVal, nil) = %v, want nil", ft)
+ t.Errorf("NewFile(pathVal, nil) = %v, want nil", ft)
}
if !strings.Contains(err.Error(), "logger is required") {
- t.Errorf("NewFileTools(pathVal, nil) error = %q, want contains %q", err.Error(), "logger is required")
+ t.Errorf("NewFile(pathVal, nil) error = %q, want contains %q", err.Error(), "logger is required")
}
})
}
-// ============================================================================
-// RegisterFileTools Tests
-// ============================================================================
-
-func TestRegisterFileTools(t *testing.T) {
+func TestRegisterFile(t *testing.T) {
t.Parallel()
t.Run("successful registration", func(t *testing.T) {
@@ -90,24 +82,24 @@ func TestRegisterFileTools(t *testing.T) {
t.Fatalf("NewPath() unexpected error: %v", err)
}
- ft, err := NewFileTools(pathVal, testLogger())
+ ft, err := NewFile(pathVal, testLogger())
if err != nil {
- t.Fatalf("NewFileTools() unexpected error: %v", err)
+ t.Fatalf("NewFile() unexpected error: %v", err)
}
- tools, err := RegisterFileTools(g, ft)
+ tools, err := RegisterFile(g, ft)
if err != nil {
- t.Fatalf("RegisterFileTools() unexpected error: %v", err)
+ t.Fatalf("RegisterFile() unexpected error: %v", err)
}
if got, want := len(tools), 5; got != want {
- t.Errorf("RegisterFileTools() tool count = %d, want %d (should register 5 file tools)", got, want)
+ t.Errorf("RegisterFile() tool count = %d, want %d (should register 5 file tools)", got, want)
}
// Verify tool names
- expectedNames := []string{ToolReadFile, ToolWriteFile, ToolListFiles, ToolDeleteFile, ToolGetFileInfo}
+ expectedNames := []string{ReadFileName, WriteFileName, ListFilesName, DeleteFileName, FileInfoName}
actualNames := extractToolNames(tools)
if !slicesEqual(expectedNames, actualNames) {
- t.Errorf("RegisterFileTools() tool names = %v, want %v", actualNames, expectedNames)
+ t.Errorf("RegisterFile() tool names = %v, want %v", actualNames, expectedNames)
}
})
@@ -118,45 +110,41 @@ func TestRegisterFileTools(t *testing.T) {
t.Fatalf("NewPath() unexpected error: %v", err)
}
- ft, err := NewFileTools(pathVal, testLogger())
+ ft, err := NewFile(pathVal, testLogger())
if err != nil {
- t.Fatalf("NewFileTools() unexpected error: %v", err)
+ t.Fatalf("NewFile() unexpected error: %v", err)
}
- tools, err := RegisterFileTools(nil, ft)
+ tools, err := RegisterFile(nil, ft)
if err == nil {
- t.Fatal("RegisterFileTools(nil, ft) expected error, got nil")
+ t.Fatal("RegisterFile(nil, ft) expected error, got nil")
}
if tools != nil {
- t.Errorf("RegisterFileTools(nil, ft) = %v, want nil", tools)
+ t.Errorf("RegisterFile(nil, ft) = %v, want nil", tools)
}
if !strings.Contains(err.Error(), "genkit instance is required") {
- t.Errorf("RegisterFileTools(nil, ft) error = %q, want contains %q", err.Error(), "genkit instance is required")
+ t.Errorf("RegisterFile(nil, ft) error = %q, want contains %q", err.Error(), "genkit instance is required")
}
})
- t.Run("nil FileTools", func(t *testing.T) {
+ t.Run("nil File", func(t *testing.T) {
t.Parallel()
g := setupTestGenkit(t)
- tools, err := RegisterFileTools(g, nil)
+ tools, err := RegisterFile(g, nil)
if err == nil {
- t.Fatal("RegisterFileTools(g, nil) expected error, got nil")
+ t.Fatal("RegisterFile(g, nil) expected error, got nil")
}
if tools != nil {
- t.Errorf("RegisterFileTools(g, nil) = %v, want nil", tools)
+ t.Errorf("RegisterFile(g, nil) = %v, want nil", tools)
}
- if !strings.Contains(err.Error(), "FileTools is required") {
- t.Errorf("RegisterFileTools(g, nil) error = %q, want contains %q", err.Error(), "FileTools is required")
+ if !strings.Contains(err.Error(), "File is required") {
+ t.Errorf("RegisterFile(g, nil) error = %q, want contains %q", err.Error(), "File is required")
}
})
}
-// ============================================================================
-// NewSystemTools Tests
-// ============================================================================
-
-func TestNewSystemTools(t *testing.T) {
+func TestNewSystem(t *testing.T) {
t.Parallel()
t.Run("successful creation", func(t *testing.T) {
@@ -164,12 +152,12 @@ func TestNewSystemTools(t *testing.T) {
cmdVal := security.NewCommand()
envVal := security.NewEnv()
- st, err := NewSystemTools(cmdVal, envVal, testLogger())
+ st, err := NewSystem(cmdVal, envVal, testLogger())
if err != nil {
- t.Fatalf("NewSystemTools() unexpected error: %v", err)
+ t.Fatalf("NewSystem() unexpected error: %v", err)
}
if st == nil {
- t.Fatal("NewSystemTools() = nil, want non-nil")
+ t.Fatal("NewSystem() = nil, want non-nil")
}
})
@@ -177,15 +165,15 @@ func TestNewSystemTools(t *testing.T) {
t.Parallel()
envVal := security.NewEnv()
- st, err := NewSystemTools(nil, envVal, testLogger())
+ st, err := NewSystem(nil, envVal, testLogger())
if err == nil {
- t.Fatal("NewSystemTools(nil, envVal, logger) expected error, got nil")
+ t.Fatal("NewSystem(nil, envVal, logger) expected error, got nil")
}
if st != nil {
- t.Errorf("NewSystemTools(nil, envVal, logger) = %v, want nil", st)
+ t.Errorf("NewSystem(nil, envVal, logger) = %v, want nil", st)
}
if !strings.Contains(err.Error(), "command validator is required") {
- t.Errorf("NewSystemTools(nil, envVal, logger) error = %q, want contains %q", err.Error(), "command validator is required")
+ t.Errorf("NewSystem(nil, envVal, logger) error = %q, want contains %q", err.Error(), "command validator is required")
}
})
@@ -193,15 +181,15 @@ func TestNewSystemTools(t *testing.T) {
t.Parallel()
cmdVal := security.NewCommand()
- st, err := NewSystemTools(cmdVal, nil, testLogger())
+ st, err := NewSystem(cmdVal, nil, testLogger())
if err == nil {
- t.Fatal("NewSystemTools(cmdVal, nil, logger) expected error, got nil")
+ t.Fatal("NewSystem(cmdVal, nil, logger) expected error, got nil")
}
if st != nil {
- t.Errorf("NewSystemTools(cmdVal, nil, logger) = %v, want nil", st)
+ t.Errorf("NewSystem(cmdVal, nil, logger) = %v, want nil", st)
}
if !strings.Contains(err.Error(), "env validator is required") {
- t.Errorf("NewSystemTools(cmdVal, nil, logger) error = %q, want contains %q", err.Error(), "env validator is required")
+ t.Errorf("NewSystem(cmdVal, nil, logger) error = %q, want contains %q", err.Error(), "env validator is required")
}
})
@@ -210,24 +198,20 @@ func TestNewSystemTools(t *testing.T) {
cmdVal := security.NewCommand()
envVal := security.NewEnv()
- st, err := NewSystemTools(cmdVal, envVal, nil)
+ st, err := NewSystem(cmdVal, envVal, nil)
if err == nil {
- t.Fatal("NewSystemTools(cmdVal, envVal, nil) expected error, got nil")
+ t.Fatal("NewSystem(cmdVal, envVal, nil) expected error, got nil")
}
if st != nil {
- t.Errorf("NewSystemTools(cmdVal, envVal, nil) = %v, want nil", st)
+ t.Errorf("NewSystem(cmdVal, envVal, nil) = %v, want nil", st)
}
if !strings.Contains(err.Error(), "logger is required") {
- t.Errorf("NewSystemTools(cmdVal, envVal, nil) error = %q, want contains %q", err.Error(), "logger is required")
+ t.Errorf("NewSystem(cmdVal, envVal, nil) error = %q, want contains %q", err.Error(), "logger is required")
}
})
}
-// ============================================================================
-// RegisterSystemTools Tests
-// ============================================================================
-
-func TestRegisterSystemTools(t *testing.T) {
+func TestRegisterSystem(t *testing.T) {
t.Parallel()
t.Run("successful registration", func(t *testing.T) {
@@ -236,24 +220,24 @@ func TestRegisterSystemTools(t *testing.T) {
cmdVal := security.NewCommand()
envVal := security.NewEnv()
- st, err := NewSystemTools(cmdVal, envVal, testLogger())
+ st, err := NewSystem(cmdVal, envVal, testLogger())
if err != nil {
- t.Fatalf("NewSystemTools() unexpected error: %v", err)
+ t.Fatalf("NewSystem() unexpected error: %v", err)
}
- tools, err := RegisterSystemTools(g, st)
+ tools, err := RegisterSystem(g, st)
if err != nil {
- t.Fatalf("RegisterSystemTools() unexpected error: %v", err)
+ t.Fatalf("RegisterSystem() unexpected error: %v", err)
}
if got, want := len(tools), 3; got != want {
- t.Errorf("RegisterSystemTools() tool count = %d, want %d (should register 3 system tools)", got, want)
+ t.Errorf("RegisterSystem() tool count = %d, want %d (should register 3 system tools)", got, want)
}
// Verify tool names
- expectedNames := []string{ToolCurrentTime, ToolExecuteCommand, ToolGetEnv}
+ expectedNames := []string{CurrentTimeName, ExecuteCommandName, GetEnvName}
actualNames := extractToolNames(tools)
if !slicesEqual(expectedNames, actualNames) {
- t.Errorf("RegisterSystemTools() tool names = %v, want %v", actualNames, expectedNames)
+ t.Errorf("RegisterSystem() tool names = %v, want %v", actualNames, expectedNames)
}
})
@@ -262,235 +246,219 @@ func TestRegisterSystemTools(t *testing.T) {
cmdVal := security.NewCommand()
envVal := security.NewEnv()
- st, err := NewSystemTools(cmdVal, envVal, testLogger())
+ st, err := NewSystem(cmdVal, envVal, testLogger())
if err != nil {
- t.Fatalf("NewSystemTools() unexpected error: %v", err)
+ t.Fatalf("NewSystem() unexpected error: %v", err)
}
- tools, err := RegisterSystemTools(nil, st)
+ tools, err := RegisterSystem(nil, st)
if err == nil {
- t.Fatal("RegisterSystemTools(nil, st) expected error, got nil")
+ t.Fatal("RegisterSystem(nil, st) expected error, got nil")
}
if tools != nil {
- t.Errorf("RegisterSystemTools(nil, st) = %v, want nil", tools)
+ t.Errorf("RegisterSystem(nil, st) = %v, want nil", tools)
}
if !strings.Contains(err.Error(), "genkit instance is required") {
- t.Errorf("RegisterSystemTools(nil, st) error = %q, want contains %q", err.Error(), "genkit instance is required")
+ t.Errorf("RegisterSystem(nil, st) error = %q, want contains %q", err.Error(), "genkit instance is required")
}
})
- t.Run("nil SystemTools", func(t *testing.T) {
+ t.Run("nil System", func(t *testing.T) {
t.Parallel()
g := setupTestGenkit(t)
- tools, err := RegisterSystemTools(g, nil)
+ tools, err := RegisterSystem(g, nil)
if err == nil {
- t.Fatal("RegisterSystemTools(g, nil) expected error, got nil")
+ t.Fatal("RegisterSystem(g, nil) expected error, got nil")
}
if tools != nil {
- t.Errorf("RegisterSystemTools(g, nil) = %v, want nil", tools)
+ t.Errorf("RegisterSystem(g, nil) = %v, want nil", tools)
}
- if !strings.Contains(err.Error(), "SystemTools is required") {
- t.Errorf("RegisterSystemTools(g, nil) error = %q, want contains %q", err.Error(), "SystemTools is required")
+ if !strings.Contains(err.Error(), "System is required") {
+ t.Errorf("RegisterSystem(g, nil) error = %q, want contains %q", err.Error(), "System is required")
}
})
}
-// ============================================================================
-// NewNetworkTools Tests
-// ============================================================================
-
-func TestNewNetworkTools(t *testing.T) {
+func TestNewNetwork(t *testing.T) {
t.Parallel()
t.Run("successful creation", func(t *testing.T) {
t.Parallel()
- cfg := NetworkConfig{
+ cfg := NetConfig{
SearchBaseURL: "http://localhost:8080",
}
- nt, err := NewNetworkTools(cfg, testLogger())
+ nt, err := NewNetwork(cfg, testLogger())
if err != nil {
- t.Fatalf("NewNetworkTools() unexpected error: %v", err)
+ t.Fatalf("NewNetwork() unexpected error: %v", err)
}
if nt == nil {
- t.Fatal("NewNetworkTools() = nil, want non-nil")
+ t.Fatal("NewNetwork() = nil, want non-nil")
}
})
t.Run("empty search base URL", func(t *testing.T) {
t.Parallel()
- cfg := NetworkConfig{}
+ cfg := NetConfig{}
- nt, err := NewNetworkTools(cfg, testLogger())
+ nt, err := NewNetwork(cfg, testLogger())
if err == nil {
- t.Fatal("NewNetworkTools(empty config) expected error, got nil")
+ t.Fatal("NewNetwork(empty config) expected error, got nil")
}
if nt != nil {
- t.Errorf("NewNetworkTools(empty config) = %v, want nil", nt)
+ t.Errorf("NewNetwork(empty config) = %v, want nil", nt)
}
if !strings.Contains(err.Error(), "search base URL is required") {
- t.Errorf("NewNetworkTools(empty config) error = %q, want contains %q", err.Error(), "search base URL is required")
+ t.Errorf("NewNetwork(empty config) error = %q, want contains %q", err.Error(), "search base URL is required")
}
})
t.Run("nil logger", func(t *testing.T) {
t.Parallel()
- cfg := NetworkConfig{
+ cfg := NetConfig{
SearchBaseURL: "http://localhost:8080",
}
- nt, err := NewNetworkTools(cfg, nil)
+ nt, err := NewNetwork(cfg, nil)
if err == nil {
- t.Fatal("NewNetworkTools(cfg, nil) expected error, got nil")
+ t.Fatal("NewNetwork(cfg, nil) expected error, got nil")
}
if nt != nil {
- t.Errorf("NewNetworkTools(cfg, nil) = %v, want nil", nt)
+ t.Errorf("NewNetwork(cfg, nil) = %v, want nil", nt)
}
if !strings.Contains(err.Error(), "logger is required") {
- t.Errorf("NewNetworkTools(cfg, nil) error = %q, want contains %q", err.Error(), "logger is required")
+ t.Errorf("NewNetwork(cfg, nil) error = %q, want contains %q", err.Error(), "logger is required")
}
})
t.Run("default values applied", func(t *testing.T) {
t.Parallel()
- cfg := NetworkConfig{
+ cfg := NetConfig{
SearchBaseURL: "http://localhost:8080/",
// Leave other values at zero - should get defaults
}
- nt, err := NewNetworkTools(cfg, testLogger())
+ nt, err := NewNetwork(cfg, testLogger())
if err != nil {
- t.Fatalf("NewNetworkTools() unexpected error: %v", err)
+ t.Fatalf("NewNetwork() unexpected error: %v", err)
}
if nt == nil {
- t.Fatal("NewNetworkTools() = nil, want non-nil")
+ t.Fatal("NewNetwork() = nil, want non-nil")
}
})
}
-// ============================================================================
-// RegisterNetworkTools Tests
-// ============================================================================
-
-func TestRegisterNetworkTools(t *testing.T) {
+func TestRegisterNetwork(t *testing.T) {
t.Parallel()
t.Run("successful registration", func(t *testing.T) {
t.Parallel()
g := setupTestGenkit(t)
- cfg := NetworkConfig{
+ cfg := NetConfig{
SearchBaseURL: "http://localhost:8080",
}
- nt, err := NewNetworkTools(cfg, testLogger())
+ nt, err := NewNetwork(cfg, testLogger())
if err != nil {
- t.Fatalf("NewNetworkTools() unexpected error: %v", err)
+ t.Fatalf("NewNetwork() unexpected error: %v", err)
}
- tools, err := RegisterNetworkTools(g, nt)
+ tools, err := RegisterNetwork(g, nt)
if err != nil {
- t.Fatalf("RegisterNetworkTools() unexpected error: %v", err)
+ t.Fatalf("RegisterNetwork() unexpected error: %v", err)
}
if got, want := len(tools), 2; got != want {
- t.Errorf("RegisterNetworkTools() tool count = %d, want %d (should register 2 network tools)", got, want)
+ t.Errorf("RegisterNetwork() tool count = %d, want %d (should register 2 network tools)", got, want)
}
// Verify tool names
- expectedNames := []string{ToolWebSearch, ToolWebFetch}
+ expectedNames := []string{WebSearchName, WebFetchName}
actualNames := extractToolNames(tools)
if !slicesEqual(expectedNames, actualNames) {
- t.Errorf("RegisterNetworkTools() tool names = %v, want %v", actualNames, expectedNames)
+ t.Errorf("RegisterNetwork() tool names = %v, want %v", actualNames, expectedNames)
}
})
t.Run("nil genkit", func(t *testing.T) {
t.Parallel()
- cfg := NetworkConfig{
+ cfg := NetConfig{
SearchBaseURL: "http://localhost:8080",
}
- nt, err := NewNetworkTools(cfg, testLogger())
+ nt, err := NewNetwork(cfg, testLogger())
if err != nil {
- t.Fatalf("NewNetworkTools() unexpected error: %v", err)
+ t.Fatalf("NewNetwork() unexpected error: %v", err)
}
- tools, err := RegisterNetworkTools(nil, nt)
+ tools, err := RegisterNetwork(nil, nt)
if err == nil {
- t.Fatal("RegisterNetworkTools(nil, nt) expected error, got nil")
+ t.Fatal("RegisterNetwork(nil, nt) expected error, got nil")
}
if tools != nil {
- t.Errorf("RegisterNetworkTools(nil, nt) = %v, want nil", tools)
+ t.Errorf("RegisterNetwork(nil, nt) = %v, want nil", tools)
}
if !strings.Contains(err.Error(), "genkit instance is required") {
- t.Errorf("RegisterNetworkTools(nil, nt) error = %q, want contains %q", err.Error(), "genkit instance is required")
+ t.Errorf("RegisterNetwork(nil, nt) error = %q, want contains %q", err.Error(), "genkit instance is required")
}
})
- t.Run("nil NetworkTools", func(t *testing.T) {
+ t.Run("nil Network", func(t *testing.T) {
t.Parallel()
g := setupTestGenkit(t)
- tools, err := RegisterNetworkTools(g, nil)
+ tools, err := RegisterNetwork(g, nil)
if err == nil {
- t.Fatal("RegisterNetworkTools(g, nil) expected error, got nil")
+ t.Fatal("RegisterNetwork(g, nil) expected error, got nil")
}
if tools != nil {
- t.Errorf("RegisterNetworkTools(g, nil) = %v, want nil", tools)
+ t.Errorf("RegisterNetwork(g, nil) = %v, want nil", tools)
}
- if !strings.Contains(err.Error(), "NetworkTools is required") {
- t.Errorf("RegisterNetworkTools(g, nil) error = %q, want contains %q", err.Error(), "NetworkTools is required")
+ if !strings.Contains(err.Error(), "Network is required") {
+ t.Errorf("RegisterNetwork(g, nil) error = %q, want contains %q", err.Error(), "Network is required")
}
})
}
-// ============================================================================
-// RegisterKnowledgeTools Tests
-// ============================================================================
+// Note: Knowledge validation (nil store, nil logger) is tested in
+// TestNewKnowledge in knowledge_test.go. These tests verify
+// RegisterKnowledge parameter validation only.
-// Note: KnowledgeTools validation (nil store, nil logger) is tested in
-// TestNewKnowledgeTools in knowledge_test.go. These tests verify
-// RegisterKnowledgeTools parameter validation only.
-
-func TestRegisterKnowledgeTools(t *testing.T) {
+func TestRegisterKnowledge(t *testing.T) {
t.Parallel()
t.Run("nil genkit", func(t *testing.T) {
t.Parallel()
- tools, err := RegisterKnowledgeTools(nil, &KnowledgeTools{})
+ tools, err := RegisterKnowledge(nil, &Knowledge{})
if err == nil {
- t.Fatal("RegisterKnowledgeTools(nil, kt) expected error, got nil")
+ t.Fatal("RegisterKnowledge(nil, kt) expected error, got nil")
}
if tools != nil {
- t.Errorf("RegisterKnowledgeTools(nil, kt) = %v, want nil", tools)
+ t.Errorf("RegisterKnowledge(nil, kt) = %v, want nil", tools)
}
if !strings.Contains(err.Error(), "genkit instance is required") {
- t.Errorf("RegisterKnowledgeTools(nil, kt) error = %q, want contains %q", err.Error(), "genkit instance is required")
+ t.Errorf("RegisterKnowledge(nil, kt) error = %q, want contains %q", err.Error(), "genkit instance is required")
}
})
- t.Run("nil KnowledgeTools", func(t *testing.T) {
+ t.Run("nil Knowledge", func(t *testing.T) {
t.Parallel()
g := setupTestGenkit(t)
- tools, err := RegisterKnowledgeTools(g, nil)
+ tools, err := RegisterKnowledge(g, nil)
if err == nil {
- t.Fatal("RegisterKnowledgeTools(g, nil) expected error, got nil")
+ t.Fatal("RegisterKnowledge(g, nil) expected error, got nil")
}
if tools != nil {
- t.Errorf("RegisterKnowledgeTools(g, nil) = %v, want nil", tools)
+ t.Errorf("RegisterKnowledge(g, nil) = %v, want nil", tools)
}
- if !strings.Contains(err.Error(), "KnowledgeTools is required") {
- t.Errorf("RegisterKnowledgeTools(g, nil) error = %q, want contains %q", err.Error(), "KnowledgeTools is required")
+ if !strings.Contains(err.Error(), "Knowledge is required") {
+ t.Errorf("RegisterKnowledge(g, nil) error = %q, want contains %q", err.Error(), "Knowledge is required")
}
})
}
-// ============================================================================
-// Helper Functions
-// ============================================================================
-
func extractToolNames(tools []ai.Tool) []string {
names := make([]string, len(tools))
for i, tool := range tools {
diff --git a/internal/tools/setup_test.go b/internal/tools/setup_test.go
index b710699..491c48f 100644
--- a/internal/tools/setup_test.go
+++ b/internal/tools/setup_test.go
@@ -1,25 +1,25 @@
package tools
import (
+ "log/slog"
"testing"
-
- "github.com/koopa0/koopa/internal/log"
)
// testLogger returns a no-op logger for testing.
-func testLogger() log.Logger {
- return log.NewNop()
+func testLogger() *slog.Logger {
+ return slog.New(slog.DiscardHandler)
}
-// newNetworkToolsForTesting creates a NetworkTools instance with SSRF protection
+// newNetworkForTesting creates a Network instance with SSRF protection
// disabled. This allows tests to use httptest.Server (which binds to localhost).
// Only accessible within tools package tests (unexported).
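+//
+// Sketch of intended use (handler is an assumed http.Handler serving canned test responses):
+//
+//	srv := httptest.NewServer(handler)
+//	defer srv.Close()
+//	nt := newNetworkForTesting(t, NetConfig{SearchBaseURL: srv.URL}, testLogger())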
-func newNetworkToolsForTesting(tb testing.TB, cfg NetworkConfig, logger log.Logger) *NetworkTools {
+func newNetworkForTesting(tb testing.TB, cfg NetConfig, logger *slog.Logger) *Network {
tb.Helper()
- nt, err := NewNetworkTools(cfg, logger)
+ nt, err := NewNetwork(cfg, logger)
if err != nil {
- tb.Fatalf("NewNetworkTools() unexpected error: %v", err)
+ tb.Fatalf("NewNetwork() unexpected error: %v", err)
}
nt.skipSSRFCheck = true
+ nt.searchClient.Transport = nil // allow localhost in tests
return nt
}
diff --git a/internal/tools/system.go b/internal/tools/system.go
index 3dca4ff..9f6dbb5 100644
--- a/internal/tools/system.go
+++ b/internal/tools/system.go
@@ -3,6 +3,7 @@ package tools
import (
"context"
"fmt"
+ "log/slog"
"os"
"os/exec"
"strings"
@@ -11,14 +12,17 @@ import (
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
- "github.com/koopa0/koopa/internal/log"
"github.com/koopa0/koopa/internal/security"
)
+// Tool name constants for system operations registered with Genkit.
const (
- ToolCurrentTime = "current_time"
- ToolExecuteCommand = "execute_command"
- ToolGetEnv = "get_env"
+ // CurrentTimeName is the Genkit tool name for retrieving the current time.
+ CurrentTimeName = "current_time"
+ // ExecuteCommandName is the Genkit tool name for executing shell commands.
+ ExecuteCommandName = "execute_command"
+ // GetEnvName is the Genkit tool name for reading environment variables.
+ GetEnvName = "get_env"
)
// ExecuteCommandInput defines input for execute_command tool.
@@ -35,18 +39,18 @@ type GetEnvInput struct {
// CurrentTimeInput defines input for current_time tool (no input needed).
type CurrentTimeInput struct{}
-// SystemTools holds dependencies for system operation handlers.
-// Use NewSystemTools to create an instance, then either:
+// System holds dependencies for system operation handlers.
+// Use NewSystem to create an instance, then either:
// - Call methods directly (for MCP)
-// - Use RegisterSystemTools to register with Genkit
-type SystemTools struct {
+// - Use RegisterSystem to register with Genkit
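+//
+// Minimal sketch of both paths (logger and g are assumed *slog.Logger and
+// *genkit.Genkit values, not part of this change):
+//
+//	st, err := NewSystem(security.NewCommand(), security.NewEnv(), logger)
+//	if err != nil {
+//		// handle error
+//	}
+//	tools, err := RegisterSystem(g, st) // Genkit path; for MCP, call st methods directly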
+type System struct {
cmdVal *security.Command
envVal *security.Env
- logger log.Logger
+ logger *slog.Logger
}
-// NewSystemTools creates a SystemTools instance.
-func NewSystemTools(cmdVal *security.Command, envVal *security.Env, logger log.Logger) (*SystemTools, error) {
+// NewSystem creates a System instance.
+func NewSystem(cmdVal *security.Command, envVal *security.Env, logger *slog.Logger) (*System, error) {
if cmdVal == nil {
return nil, fmt.Errorf("command validator is required")
}
@@ -56,45 +60,45 @@ func NewSystemTools(cmdVal *security.Command, envVal *security.Env, logger log.L
if logger == nil {
return nil, fmt.Errorf("logger is required")
}
- return &SystemTools{cmdVal: cmdVal, envVal: envVal, logger: logger}, nil
+ return &System{cmdVal: cmdVal, envVal: envVal, logger: logger}, nil
}
-// RegisterSystemTools registers all system operation tools with Genkit.
+// RegisterSystem registers all system operation tools with Genkit.
// Tools are registered with event emission wrappers for streaming support.
-func RegisterSystemTools(g *genkit.Genkit, st *SystemTools) ([]ai.Tool, error) {
+func RegisterSystem(g *genkit.Genkit, st *System) ([]ai.Tool, error) {
if g == nil {
return nil, fmt.Errorf("genkit instance is required")
}
if st == nil {
- return nil, fmt.Errorf("SystemTools is required")
+ return nil, fmt.Errorf("System is required")
}
return []ai.Tool{
- genkit.DefineTool(g, ToolCurrentTime,
+ genkit.DefineTool(g, CurrentTimeName,
"Get the current system date and time. "+
"Returns: formatted time string, Unix timestamp, and ISO 8601 format. "+
"Use this to: check current time, calculate relative times, add timestamps to outputs. "+
"Always returns the server's local time zone.",
- WithEvents(ToolCurrentTime, st.CurrentTime)),
- genkit.DefineTool(g, ToolExecuteCommand,
+ WithEvents(CurrentTimeName, st.CurrentTime)),
+ genkit.DefineTool(g, ExecuteCommandName,
"Execute a shell command from the allowed list with security validation. "+
"Allowed commands: git, npm, yarn, go, make, docker, kubectl, ls, cat, grep, find, pwd, echo. "+
"Commands run with a timeout to prevent hanging. "+
"Returns: stdout, stderr, exit code, and execution time. "+
"Use this for: running builds, checking git status, listing processes, viewing file contents. "+
"Security: Dangerous commands (rm -rf, sudo, chmod, etc.) are blocked.",
- WithEvents(ToolExecuteCommand, st.ExecuteCommand)),
- genkit.DefineTool(g, ToolGetEnv,
+ WithEvents(ExecuteCommandName, st.ExecuteCommand)),
+ genkit.DefineTool(g, GetEnvName,
"Read an environment variable value from the system. "+
"Returns: the variable name and its value. "+
"Use this to: check configuration, verify paths, read non-sensitive settings. "+
"Security: Sensitive variables containing KEY, SECRET, TOKEN, or PASSWORD in their names are protected and will not be returned.",
- WithEvents(ToolGetEnv, st.GetEnv)),
+ WithEvents(GetEnvName, st.GetEnv)),
}, nil
}
// CurrentTime returns the current system date and time in multiple formats.
-func (s *SystemTools) CurrentTime(_ *ai.ToolContext, _ CurrentTimeInput) (Result, error) {
+func (s *System) CurrentTime(_ *ai.ToolContext, _ CurrentTimeInput) (Result, error) {
s.logger.Info("CurrentTime called")
now := time.Now()
s.logger.Info("CurrentTime succeeded")
@@ -112,11 +116,11 @@ func (s *SystemTools) CurrentTime(_ *ai.ToolContext, _ CurrentTimeInput) (Result
// Dangerous commands like rm -rf, sudo, and shutdown are blocked.
// Business errors (blocked commands, execution failures) are returned in Result.Error.
// Only context cancellation returns a Go error.
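+//
+// Caller sketch (command and args are illustrative):
+//
+//	res, err := st.ExecuteCommand(ctx, ExecuteCommandInput{Command: "git", Args: []string{"status"}})
+//	if err != nil {
+//		// infrastructure error, e.g. context canceled
+//	}
+//	if res.Error != nil {
+//		// business error: blocked or failed command (inspect res.Error.Code)
+//	}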
-func (s *SystemTools) ExecuteCommand(ctx *ai.ToolContext, input ExecuteCommandInput) (Result, error) {
+func (s *System) ExecuteCommand(ctx *ai.ToolContext, input ExecuteCommandInput) (Result, error) {
s.logger.Info("ExecuteCommand called", "command", input.Command, "args", input.Args)
// Command security validation (prevent command injection attacks CWE-78)
- if err := s.cmdVal.ValidateCommand(input.Command, input.Args); err != nil {
+ if err := s.cmdVal.Validate(input.Command, input.Args); err != nil {
s.logger.Error("ExecuteCommand dangerous command rejected", "command", input.Command, "args", input.Args, "error", err)
return Result{
Status: StatusError,
@@ -138,17 +142,16 @@ func (s *SystemTools) ExecuteCommand(ctx *ai.ToolContext, input ExecuteCommandIn
if err != nil {
// Check if it was canceled by context - this is infrastructure error
if execCtx.Err() != nil {
- s.logger.Error("ExecuteCommand canceled", "command", input.Command, "error", execCtx.Err())
return Result{}, fmt.Errorf("command execution canceled: %w", execCtx.Err())
}
// Command execution failure is a business error
- s.logger.Error("ExecuteCommand failed", "command", input.Command, "error", err, "output", string(output))
+ s.logger.Error("executing command", "command", input.Command, "error", err, "output", string(output))
return Result{
Status: StatusError,
Error: &Error{
Code: ErrCodeExecution,
- Message: fmt.Sprintf("command failed: %v", err),
+ Message: fmt.Sprintf("executing command: %v", err),
Details: map[string]any{
"command": input.Command,
"args": strings.Join(input.Args, " "),
@@ -174,11 +177,11 @@ func (s *SystemTools) ExecuteCommand(ctx *ai.ToolContext, input ExecuteCommandIn
// GetEnv reads an environment variable value with security protection.
// Sensitive variables containing KEY, SECRET, or TOKEN in the name are blocked.
// Business errors (sensitive variable blocked) are returned in Result.Error.
-func (s *SystemTools) GetEnv(_ *ai.ToolContext, input GetEnvInput) (Result, error) {
+func (s *System) GetEnv(_ *ai.ToolContext, input GetEnvInput) (Result, error) {
s.logger.Info("GetEnv called", "key", input.Key)
// Environment variable security validation (prevent sensitive information leakage)
- if err := s.envVal.ValidateEnvAccess(input.Key); err != nil {
+ if err := s.envVal.Validate(input.Key); err != nil {
s.logger.Error("GetEnv sensitive variable blocked", "key", input.Key, "error", err)
return Result{
Status: StatusError,
diff --git a/internal/tools/system_integration_test.go b/internal/tools/system_integration_test.go
index 5cc593c..8e73a9a 100644
--- a/internal/tools/system_integration_test.go
+++ b/internal/tools/system_integration_test.go
@@ -11,7 +11,7 @@ import (
"github.com/koopa0/koopa/internal/security"
)
-// systemTools provides test utilities for SystemTools.
+// systemTools provides test utilities for System.
type systemTools struct {
t *testing.T
}
@@ -21,13 +21,13 @@ func newsystemTools(t *testing.T) *systemTools {
return &systemTools{t: t}
}
-func (h *systemTools) createSystemTools() *SystemTools {
+func (h *systemTools) createSystem() *System {
h.t.Helper()
cmdVal := security.NewCommand()
envVal := security.NewEnv()
- st, err := NewSystemTools(cmdVal, envVal, testLogger())
+ st, err := NewSystem(cmdVal, envVal, testLogger())
if err != nil {
- h.t.Fatalf("failed to create system tools: %v", err)
+ h.t.Fatalf("creating system tools: %v", err)
}
return st
}
@@ -36,11 +36,7 @@ func (*systemTools) toolContext() *ai.ToolContext {
return &ai.ToolContext{Context: context.Background()}
}
-// ============================================================================
-// ExecuteCommand Integration Tests - Command Injection Prevention
-// ============================================================================
-
-func TestSystemTools_ExecuteCommand_WhitelistEnforcement(t *testing.T) {
+func TestSystem_ExecuteCommand_WhitelistEnforcement(t *testing.T) {
t.Parallel()
tests := []struct {
@@ -138,7 +134,7 @@ func TestSystemTools_ExecuteCommand_WhitelistEnforcement(t *testing.T) {
t.Parallel()
h := newsystemTools(t)
- st := h.createSystemTools()
+ st := h.createSystem()
ctx := h.toolContext()
result, err := st.ExecuteCommand(ctx, ExecuteCommandInput{
@@ -176,7 +172,7 @@ func TestSystemTools_ExecuteCommand_WhitelistEnforcement(t *testing.T) {
}
}
-func TestSystemTools_ExecuteCommand_DangerousPatterns(t *testing.T) {
+func TestSystem_ExecuteCommand_DangerousPatterns(t *testing.T) {
t.Parallel()
tests := []struct {
@@ -234,7 +230,7 @@ func TestSystemTools_ExecuteCommand_DangerousPatterns(t *testing.T) {
t.Parallel()
h := newsystemTools(t)
- st := h.createSystemTools()
+ st := h.createSystem()
ctx := h.toolContext()
result, err := st.ExecuteCommand(ctx, ExecuteCommandInput{
@@ -257,11 +253,11 @@ func TestSystemTools_ExecuteCommand_DangerousPatterns(t *testing.T) {
}
}
-func TestSystemTools_ExecuteCommand_Success(t *testing.T) {
+func TestSystem_ExecuteCommand_Success(t *testing.T) {
t.Parallel()
h := newsystemTools(t)
- st := h.createSystemTools()
+ st := h.createSystem()
ctx := h.toolContext()
result, err := st.ExecuteCommand(ctx, ExecuteCommandInput{
@@ -298,7 +294,7 @@ func TestSystemTools_ExecuteCommand_Success(t *testing.T) {
}
}
-func TestSystemTools_ExecuteCommand_ContextCancellation(t *testing.T) {
+func TestSystem_ExecuteCommand_ContextCancellation(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("sleep command not available on Windows")
}
@@ -306,7 +302,7 @@ func TestSystemTools_ExecuteCommand_ContextCancellation(t *testing.T) {
t.Parallel()
h := newsystemTools(t)
- st := h.createSystemTools()
+ st := h.createSystem()
// Create a context that's already canceled
ctx, cancel := context.WithCancel(context.Background())
@@ -328,11 +324,7 @@ func TestSystemTools_ExecuteCommand_ContextCancellation(t *testing.T) {
}
}
-// ============================================================================
-// GetEnv Integration Tests - Sensitive Variable Protection
-// ============================================================================
-
-func TestSystemTools_GetEnv_SensitiveVariableBlocked(t *testing.T) {
+func TestSystem_GetEnv_SensitiveVariableBlocked(t *testing.T) {
t.Parallel()
tests := []struct {
@@ -440,7 +432,7 @@ func TestSystemTools_GetEnv_SensitiveVariableBlocked(t *testing.T) {
t.Parallel()
h := newsystemTools(t)
- st := h.createSystemTools()
+ st := h.createSystem()
result, err := st.GetEnv(nil, GetEnvInput{Key: tt.envKey})
@@ -459,7 +451,7 @@ func TestSystemTools_GetEnv_SensitiveVariableBlocked(t *testing.T) {
}
}
-func TestSystemTools_GetEnv_SafeVariableAllowed(t *testing.T) {
+func TestSystem_GetEnv_SafeVariableAllowed(t *testing.T) {
t.Parallel()
tests := []struct {
@@ -481,7 +473,7 @@ func TestSystemTools_GetEnv_SafeVariableAllowed(t *testing.T) {
t.Parallel()
h := newsystemTools(t)
- st := h.createSystemTools()
+ st := h.createSystem()
result, err := st.GetEnv(nil, GetEnvInput{Key: tt.envKey})
@@ -507,7 +499,7 @@ func TestSystemTools_GetEnv_SafeVariableAllowed(t *testing.T) {
}
}
-func TestSystemTools_GetEnv_CaseInsensitiveBlocking(t *testing.T) {
+func TestSystem_GetEnv_CaseInsensitiveBlocking(t *testing.T) {
t.Parallel()
tests := []struct {
@@ -527,7 +519,7 @@ func TestSystemTools_GetEnv_CaseInsensitiveBlocking(t *testing.T) {
t.Parallel()
h := newsystemTools(t)
- st := h.createSystemTools()
+ st := h.createSystem()
result, err := st.GetEnv(nil, GetEnvInput{Key: tt.envKey})
@@ -546,15 +538,11 @@ func TestSystemTools_GetEnv_CaseInsensitiveBlocking(t *testing.T) {
}
}
-// ============================================================================
-// CurrentTime Integration Tests
-// ============================================================================
-
-func TestSystemTools_CurrentTime_Success(t *testing.T) {
+func TestSystem_CurrentTime_Success(t *testing.T) {
t.Parallel()
h := newsystemTools(t)
- st := h.createSystemTools()
+ st := h.createSystem()
result, err := st.CurrentTime(nil, CurrentTimeInput{})
diff --git a/internal/tools/system_test.go b/internal/tools/system_test.go
index 2950310..3677005 100644
--- a/internal/tools/system_test.go
+++ b/internal/tools/system_test.go
@@ -6,41 +6,41 @@ import (
"github.com/koopa0/koopa/internal/security"
)
-func TestSystemTools_Constructor(t *testing.T) {
+func TestSystem_Constructor(t *testing.T) {
t.Run("valid inputs", func(t *testing.T) {
cmdVal := security.NewCommand()
envVal := security.NewEnv()
- st, err := NewSystemTools(cmdVal, envVal, testLogger())
+ st, err := NewSystem(cmdVal, envVal, testLogger())
if err != nil {
- t.Errorf("NewSystemTools() error = %v, want nil", err)
+ t.Errorf("NewSystem() error = %v, want nil", err)
}
if st == nil {
- t.Error("NewSystemTools() returned nil, want non-nil")
+ t.Error("NewSystem() returned nil, want non-nil")
}
})
t.Run("nil command validator", func(t *testing.T) {
envVal := security.NewEnv()
- st, err := NewSystemTools(nil, envVal, testLogger())
+ st, err := NewSystem(nil, envVal, testLogger())
if err == nil {
- t.Error("NewSystemTools() error = nil, want error")
+ t.Error("NewSystem() error = nil, want error")
}
if st != nil {
- t.Error("NewSystemTools() returned non-nil, want nil")
+ t.Error("NewSystem() returned non-nil, want nil")
}
})
t.Run("nil env validator", func(t *testing.T) {
cmdVal := security.NewCommand()
- st, err := NewSystemTools(cmdVal, nil, testLogger())
+ st, err := NewSystem(cmdVal, nil, testLogger())
if err == nil {
- t.Error("NewSystemTools() error = nil, want error")
+ t.Error("NewSystem() error = nil, want error")
}
if st != nil {
- t.Error("NewSystemTools() returned non-nil, want nil")
+ t.Error("NewSystem() returned non-nil, want nil")
}
})
@@ -48,59 +48,30 @@ func TestSystemTools_Constructor(t *testing.T) {
cmdVal := security.NewCommand()
envVal := security.NewEnv()
- st, err := NewSystemTools(cmdVal, envVal, nil)
+ st, err := NewSystem(cmdVal, envVal, nil)
if err == nil {
- t.Error("NewSystemTools() error = nil, want error")
+ t.Error("NewSystem() error = nil, want error")
}
if st != nil {
- t.Error("NewSystemTools() returned non-nil, want nil")
+ t.Error("NewSystem() returned non-nil, want nil")
}
})
}
func TestSystemToolConstants(t *testing.T) {
expectedNames := map[string]string{
- "ToolCurrentTime": "current_time",
- "ToolExecuteCommand": "execute_command",
- "ToolGetEnv": "get_env",
+ "CurrentTimeName": "current_time",
+ "ExecuteCommandName": "execute_command",
+ "GetEnvName": "get_env",
}
- if ToolCurrentTime != expectedNames["ToolCurrentTime"] {
- t.Errorf("ToolCurrentTime = %q, want %q", ToolCurrentTime, expectedNames["ToolCurrentTime"])
+ if CurrentTimeName != expectedNames["CurrentTimeName"] {
+ t.Errorf("CurrentTimeName = %q, want %q", CurrentTimeName, expectedNames["CurrentTimeName"])
}
- if ToolExecuteCommand != expectedNames["ToolExecuteCommand"] {
- t.Errorf("ToolExecuteCommand = %q, want %q", ToolExecuteCommand, expectedNames["ToolExecuteCommand"])
+ if ExecuteCommandName != expectedNames["ExecuteCommandName"] {
+ t.Errorf("ExecuteCommandName = %q, want %q", ExecuteCommandName, expectedNames["ExecuteCommandName"])
}
- if ToolGetEnv != expectedNames["ToolGetEnv"] {
- t.Errorf("ToolGetEnv = %q, want %q", ToolGetEnv, expectedNames["ToolGetEnv"])
+ if GetEnvName != expectedNames["GetEnvName"] {
+ t.Errorf("GetEnvName = %q, want %q", GetEnvName, expectedNames["GetEnvName"])
}
}
-
-func TestExecuteCommandInput(t *testing.T) {
- input := ExecuteCommandInput{
- Command: "ls",
- Args: []string{"-la", "/tmp"},
- }
- if input.Command != "ls" {
- t.Errorf("ExecuteCommandInput.Command = %q, want %q", input.Command, "ls")
- }
- if len(input.Args) != 2 {
- t.Errorf("ExecuteCommandInput.Args length = %d, want 2", len(input.Args))
- }
- if input.Args[0] != "-la" {
- t.Errorf("ExecuteCommandInput.Args[0] = %q, want %q", input.Args[0], "-la")
- }
-}
-
-func TestGetEnvInput(t *testing.T) {
- input := GetEnvInput{Key: "PATH"}
- if input.Key != "PATH" {
- t.Errorf("GetEnvInput.Key = %q, want %q", input.Key, "PATH")
- }
-}
-
-func TestCurrentTimeInput(t *testing.T) {
- // CurrentTimeInput is an empty struct
- input := CurrentTimeInput{}
- _ = input // Just verify it can be created
-}
diff --git a/internal/tools/types.go b/internal/tools/types.go
index bcbe524..32e20cd 100644
--- a/internal/tools/types.go
+++ b/internal/tools/types.go
@@ -1,15 +1,5 @@
-// Package tools provides tool types and result helpers for agent tool operations.
-//
-// Error Handling:
-// - All tools return Result with structured error information
-// - Business errors (validation, not found, etc.) use Result.Error
-// - Only infrastructure errors (context cancellation) return Go error
package tools
-// ============================================================================
-// Status and ErrorCode (for JSON responses to LLM)
-// ============================================================================
-
// Status represents the execution status of a tool.
type Status string
@@ -35,10 +25,6 @@ const (
ErrCodeValidation ErrorCode = "ValidationError"
)
-// ============================================================================
-// Result Types (for structured JSON responses)
-// ============================================================================
-
// Result is the standard return format for all tools.
type Result struct {
Status Status `json:"status" jsonschema_description:"The execution status"`
diff --git a/internal/tools/types_test.go b/internal/tools/types_test.go
index 8f735c9..436a50b 100644
--- a/internal/tools/types_test.go
+++ b/internal/tools/types_test.go
@@ -10,17 +10,17 @@ func TestResult_Success(t *testing.T) {
result := Result{Status: StatusSuccess, Data: data}
if result.Status != StatusSuccess {
- t.Errorf("Status = %v, want %v", result.Status, StatusSuccess)
+ t.Errorf("Result{...}.Status = %v, want %v", result.Status, StatusSuccess)
}
if result.Data == nil {
- t.Fatal("Data is nil, want non-nil")
+ t.Fatal("Result{...}.Data is nil, want non-nil")
}
dataMap, ok := result.Data.(map[string]any)
if !ok {
- t.Fatalf("Data type = %T, want map[string]any", result.Data)
+ t.Fatalf("Result{...}.Data type = %T, want map[string]any", result.Data)
}
if dataMap["path"] != "/tmp/test" {
- t.Errorf("Data[path] = %v, want /tmp/test", dataMap["path"])
+ t.Errorf("Result{...}.Data[\"path\"] = %v, want %q", dataMap["path"], "/tmp/test")
}
})
@@ -28,10 +28,10 @@ func TestResult_Success(t *testing.T) {
result := Result{Status: StatusSuccess}
if result.Status != StatusSuccess {
- t.Errorf("Status = %v, want %v", result.Status, StatusSuccess)
+ t.Errorf("Result{...}.Status = %v, want %v", result.Status, StatusSuccess)
}
if result.Data != nil {
- t.Errorf("Data = %v, want nil", result.Data)
+ t.Errorf("Result{...}.Data = %v, want nil", result.Data)
}
})
}
@@ -42,9 +42,9 @@ func TestResult_Error(t *testing.T) {
code ErrorCode
message string
}{
- {"security error", ErrCodeSecurity, "access denied"},
- {"not found error", ErrCodeNotFound, "file not found"},
- {"execution error", ErrCodeExecution, "command failed"},
+ {name: "security error", code: ErrCodeSecurity, message: "access denied"},
+ {name: "not found error", code: ErrCodeNotFound, message: "file not found"},
+ {name: "execution error", code: ErrCodeExecution, message: "executing command"},
}
for _, tt := range tests {
@@ -55,19 +55,19 @@ func TestResult_Error(t *testing.T) {
}
if result.Status != StatusError {
- t.Errorf("Status = %v, want %v", result.Status, StatusError)
+ t.Errorf("Result{...}.Status = %v, want %v", result.Status, StatusError)
}
if result.Data != nil {
- t.Errorf("Data = %v, want nil", result.Data)
+ t.Errorf("Result{...}.Data = %v, want nil", result.Data)
}
if result.Error == nil {
- t.Fatal("Error is nil, want non-nil")
+ t.Fatal("Result{...}.Error is nil, want non-nil")
}
if result.Error.Code != tt.code {
- t.Errorf("Error.Code = %v, want %v", result.Error.Code, tt.code)
+ t.Errorf("Result{...}.Error.Code = %v, want %v", result.Error.Code, tt.code)
}
if result.Error.Message != tt.message {
- t.Errorf("Error.Message = %v, want %v", result.Error.Message, tt.message)
+ t.Errorf("Result{...}.Error.Message = %q, want %q", result.Error.Message, tt.message)
}
})
}
@@ -83,35 +83,35 @@ func TestResult_ErrorWithDetails(t *testing.T) {
Status: StatusError,
Error: &Error{
Code: ErrCodeExecution,
- Message: "command failed",
+ Message: "executing command",
Details: details,
},
}
if result.Status != StatusError {
- t.Errorf("Status = %v, want %v", result.Status, StatusError)
+ t.Errorf("Result{...}.Status = %v, want %v", result.Status, StatusError)
}
if result.Error == nil {
- t.Fatal("Error is nil, want non-nil")
+ t.Fatal("Result{...}.Error is nil, want non-nil")
}
if result.Error.Details == nil {
- t.Error("Error.Details is nil, want non-nil")
+ t.Error("Result{...}.Error.Details is nil, want non-nil")
}
detailsMap, ok := result.Error.Details.(map[string]any)
if !ok {
- t.Fatalf("Error.Details type = %T, want map[string]any", result.Error.Details)
+ t.Fatalf("Result{...}.Error.Details type = %T, want map[string]any", result.Error.Details)
}
if detailsMap["command"] != "ls" {
- t.Errorf("Error.Details[command] = %v, want ls", detailsMap["command"])
+ t.Errorf("Result{...}.Error.Details[\"command\"] = %v, want %q", detailsMap["command"], "ls")
}
}
func TestStatusConstants(t *testing.T) {
if StatusSuccess != "success" {
- t.Errorf("StatusSuccess = %v, want success", StatusSuccess)
+ t.Errorf("StatusSuccess = %q, want %q", StatusSuccess, "success")
}
if StatusError != "error" {
- t.Errorf("StatusError = %v, want error", StatusError)
+ t.Errorf("StatusError = %q, want %q", StatusError, "error")
}
}
@@ -127,9 +127,9 @@ func TestErrorCodeConstants(t *testing.T) {
ErrCodeValidation: "ValidationError",
}
- for code, expected := range codes {
- if string(code) != expected {
- t.Errorf("ErrorCode %v = %v, want %v", code, string(code), expected)
+ for code, want := range codes {
+ if string(code) != want {
+ t.Errorf("ErrorCode(%q) = %q, want %q", code, string(code), want)
}
}
}
diff --git a/internal/tui/benchmark_test.go b/internal/tui/benchmark_test.go
index 0e2569a..b898761 100644
--- a/internal/tui/benchmark_test.go
+++ b/internal/tui/benchmark_test.go
@@ -8,15 +8,15 @@ import (
"charm.land/bubbles/v2/textarea"
tea "charm.land/bubbletea/v2"
- "github.com/koopa0/koopa/internal/agent/chat"
+ "github.com/koopa0/koopa/internal/chat"
)
-// newBenchmarkTUI creates a TUI for benchmarking with minimal setup.
-func newBenchmarkTUI() *TUI {
+// newBenchmarkModel creates a Model for benchmarking with minimal setup.
+func newBenchmarkModel() *Model {
ta := textarea.New()
ta.SetHeight(3)
ta.ShowLineNumbers = false
- return &TUI{
+ return &Model{
state: StateInput,
input: ta,
history: make([]string, 0, maxHistory),
@@ -32,7 +32,7 @@ func newBenchmarkTUI() *TUI {
// BenchmarkTUI_View measures View rendering performance.
func BenchmarkTUI_View(b *testing.B) {
b.Run("empty", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
b.ResetTimer()
b.ReportAllocs()
for b.Loop() {
@@ -41,7 +41,7 @@ func BenchmarkTUI_View(b *testing.B) {
})
b.Run("10_messages", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
for i := 0; i < 10; i++ {
tui.addMessage(Message{Role: "user", Text: "Hello, this is a test message"})
tui.addMessage(Message{Role: "assistant", Text: "This is a response with some content"})
@@ -54,7 +54,7 @@ func BenchmarkTUI_View(b *testing.B) {
})
b.Run("50_messages", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
for i := 0; i < 50; i++ {
tui.addMessage(Message{Role: "user", Text: "Hello, this is a test message"})
tui.addMessage(Message{Role: "assistant", Text: "This is a response with some content"})
@@ -67,7 +67,7 @@ func BenchmarkTUI_View(b *testing.B) {
})
b.Run("max_messages", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
for i := 0; i < maxMessages; i++ {
tui.addMessage(Message{Role: "user", Text: "Hello, this is a test message with some longer content to simulate real usage"})
}
@@ -79,7 +79,7 @@ func BenchmarkTUI_View(b *testing.B) {
})
b.Run("streaming_state", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
tui.state = StateStreaming
tui.output.WriteString("This is streaming output that is being written in real-time...")
for i := 0; i < 10; i++ {
@@ -94,7 +94,7 @@ func BenchmarkTUI_View(b *testing.B) {
})
b.Run("thinking_state", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
tui.state = StateThinking
b.ResetTimer()
b.ReportAllocs()
@@ -104,7 +104,7 @@ func BenchmarkTUI_View(b *testing.B) {
})
b.Run("large_messages", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
largeText := strings.Repeat("This is a large message with lots of content. ", 100)
for i := 0; i < 20; i++ {
tui.addMessage(Message{Role: "assistant", Text: largeText})
@@ -120,7 +120,7 @@ func BenchmarkTUI_View(b *testing.B) {
// BenchmarkTUI_AddMessage measures message addition performance.
func BenchmarkTUI_AddMessage(b *testing.B) {
b.Run("single", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
msg := Message{Role: "user", Text: "Hello"}
b.ResetTimer()
b.ReportAllocs()
@@ -131,7 +131,7 @@ func BenchmarkTUI_AddMessage(b *testing.B) {
})
b.Run("with_bounds_check", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
// Pre-fill to near capacity
for i := 0; i < maxMessages-1; i++ {
tui.messages = append(tui.messages, Message{Role: "user", Text: "test"})
@@ -149,7 +149,7 @@ func BenchmarkTUI_AddMessage(b *testing.B) {
})
b.Run("at_capacity", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
// Fill to capacity
for i := 0; i < maxMessages; i++ {
tui.messages = append(tui.messages, Message{Role: "user", Text: "test"})
@@ -166,30 +166,30 @@ func BenchmarkTUI_AddMessage(b *testing.B) {
// BenchmarkTUI_Update measures Update loop performance.
func BenchmarkTUI_Update(b *testing.B) {
b.Run("key_press", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
key := tea.Key{Code: 'a'}
msg := tea.KeyPressMsg(key)
b.ResetTimer()
b.ReportAllocs()
for b.Loop() {
model, _ := tui.Update(msg)
- tui = model.(*TUI)
+ tui = model.(*Model)
}
})
b.Run("window_resize", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
msg := tea.WindowSizeMsg{Width: 120, Height: 40}
b.ResetTimer()
b.ReportAllocs()
for b.Loop() {
model, _ := tui.Update(msg)
- tui = model.(*TUI)
+ tui = model.(*Model)
}
})
b.Run("stream_text_msg", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
tui.state = StateStreaming
eventCh := make(chan streamEvent, 1)
tui.streamEventCh = eventCh
@@ -199,7 +199,7 @@ func BenchmarkTUI_Update(b *testing.B) {
b.ReportAllocs()
for b.Loop() {
model, _ := tui.Update(msg)
- tui = model.(*TUI)
+ tui = model.(*Model)
tui.output.Reset() // Reset to avoid unbounded growth
}
})
@@ -208,14 +208,14 @@ func BenchmarkTUI_Update(b *testing.B) {
// BenchmarkTUI_NavigateHistory measures history navigation performance.
func BenchmarkTUI_NavigateHistory(b *testing.B) {
b.Run("small_history", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
tui.history = []string{"one", "two", "three"}
tui.historyIdx = 1
b.ResetTimer()
b.ReportAllocs()
for b.Loop() {
model, _ := tui.navigateHistory(-1)
- tui = model.(*TUI)
+ tui = model.(*Model)
if tui.historyIdx == 0 {
tui.historyIdx = len(tui.history)
}
@@ -223,7 +223,7 @@ func BenchmarkTUI_NavigateHistory(b *testing.B) {
})
b.Run("large_history", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
for i := 0; i < maxHistory; i++ {
tui.history = append(tui.history, "history entry "+string(rune('a'+i%26)))
}
@@ -232,7 +232,7 @@ func BenchmarkTUI_NavigateHistory(b *testing.B) {
b.ReportAllocs()
for b.Loop() {
model, _ := tui.navigateHistory(-1)
- tui = model.(*TUI)
+ tui = model.(*Model)
if tui.historyIdx == 0 {
tui.historyIdx = len(tui.history)
}
@@ -376,18 +376,18 @@ func BenchmarkStyles(b *testing.B) {
// BenchmarkTUI_HandleSlashCommand measures slash command handling performance.
func BenchmarkTUI_HandleSlashCommand(b *testing.B) {
b.Run("help", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
b.ResetTimer()
b.ReportAllocs()
for b.Loop() {
tui.messages = tui.messages[:0] // Reset messages
model, _ := tui.handleSlashCommand("/help")
- tui = model.(*TUI)
+ tui = model.(*Model)
}
})
b.Run("clear", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
for i := 0; i < 10; i++ {
tui.addMessage(Message{Role: "user", Text: "test"})
}
@@ -396,18 +396,18 @@ func BenchmarkTUI_HandleSlashCommand(b *testing.B) {
for b.Loop() {
tui.messages = []Message{{Role: "user", Text: "test"}}
model, _ := tui.handleSlashCommand("/clear")
- tui = model.(*TUI)
+ tui = model.(*Model)
}
})
b.Run("unknown", func(b *testing.B) {
- tui := newBenchmarkTUI()
+ tui := newBenchmarkModel()
b.ResetTimer()
b.ReportAllocs()
for b.Loop() {
tui.messages = tui.messages[:0]
model, _ := tui.handleSlashCommand("/unknown")
- tui = model.(*TUI)
+ tui = model.(*Model)
}
})
}
diff --git a/internal/tui/fuzz_test.go b/internal/tui/fuzz_test.go
index 5cf1fda..d73e5ac 100644
--- a/internal/tui/fuzz_test.go
+++ b/internal/tui/fuzz_test.go
@@ -31,12 +31,12 @@ func FuzzTUI_HandleSlashCommand(f *testing.F) {
return
}
- tui := newTestTUI()
+ tui := newTestModel()
tui.messages = []Message{{Role: "user", Text: "hello"}}
// Should never panic
model, resultCmd := tui.handleSlashCommand(cmd)
- result := model.(*TUI)
+ result := model.(*Model)
// Basic invariants
if result == nil {
@@ -71,13 +71,13 @@ func FuzzTUI_NavigateHistory(f *testing.F) {
f.Add(-1000000)
f.Fuzz(func(t *testing.T, delta int) {
- tui := newTestTUI()
+ tui := newTestModel()
tui.history = []string{"first", "second", "third"}
tui.historyIdx = 1
// Should never panic
model, _ := tui.navigateHistory(delta)
- result := model.(*TUI)
+ result := model.(*Model)
// Index should be within bounds
if result.historyIdx < 0 {
@@ -104,7 +104,7 @@ func FuzzTUI_AddMessage(f *testing.F) {
f.Add("user", "\x00\x01\x02") // Binary
f.Fuzz(func(t *testing.T, role, text string) {
- tui := newTestTUI()
+ tui := newTestModel()
// Should never panic
tui.addMessage(Message{Role: role, Text: text})
@@ -138,7 +138,7 @@ func FuzzTUI_KeyPress(f *testing.F) {
f.Add(int32(tea.KeySpace), 0) //nolint:unconvert // f.Add requires exact types
f.Fuzz(func(t *testing.T, code int32, mod int) {
- tui := newTestTUI()
+ tui := newTestModel()
// Use background context to avoid nil pointer issues
tui.ctx = context.Background()
@@ -166,7 +166,7 @@ func FuzzTUI_View(f *testing.F) {
f.Add(0, 10000, 1) // Very wide
f.Fuzz(func(t *testing.T, state, width, height int) {
- tui := newTestTUI()
+ tui := newTestModel()
// Set state (bounded to valid values)
if state >= 0 && state <= 2 {
@@ -219,7 +219,7 @@ func FuzzMarkdownRenderer_Render(f *testing.F) {
f.Fuzz(func(t *testing.T, markdown string) {
mr := newMarkdownRenderer(80)
if mr == nil {
- t.Skip("Failed to create markdown renderer")
+ t.Skip("newMarkdownRenderer() returned nil")
}
// Should never panic
@@ -251,7 +251,7 @@ func FuzzMarkdownRenderer_UpdateWidth(f *testing.F) {
f.Fuzz(func(t *testing.T, width int) {
mr := newMarkdownRenderer(80)
if mr == nil {
- t.Skip("Failed to create markdown renderer")
+ t.Skip("newMarkdownRenderer() returned nil")
}
// Should never panic
diff --git a/internal/tui/i18n.go b/internal/tui/i18n.go
new file mode 100644
index 0000000..c443bc6
--- /dev/null
+++ b/internal/tui/i18n.go
@@ -0,0 +1,27 @@
+package tui
+
+// toolDisplayNames maps tool names to localized (Traditional Chinese) display names.
+var toolDisplayNames = map[string]string{
+ "web_search": "搜尋網路",
+ "web_fetch": "讀取網頁",
+ "read_file": "讀取檔案",
+ "write_file": "寫入檔案",
+ "list_files": "瀏覽目錄",
+ "delete_file": "刪除檔案",
+ "get_file_info": "取得檔案資訊",
+ "execute_command": "執行命令",
+ "current_time": "取得時間",
+ "get_env": "取得環境變數",
+ "search_history": "搜尋對話記錄",
+ "search_documents": "搜尋知識庫",
+ "search_system_knowledge": "搜尋系統知識",
+ "knowledge_store": "儲存知識",
+}
+
+// toolDisplayName returns a localized display name for a tool.
+func toolDisplayName(name string) string {
+ if display, ok := toolDisplayNames[name]; ok {
+ return display
+ }
+ return name
+}
diff --git a/internal/tui/integration_test.go b/internal/tui/integration_test.go
index 08f4bde..5a80ea9 100644
--- a/internal/tui/integration_test.go
+++ b/internal/tui/integration_test.go
@@ -10,7 +10,8 @@ import (
"go.uber.org/goleak"
- "github.com/koopa0/koopa/internal/agent/chat"
+ "github.com/google/uuid"
+ "github.com/koopa0/koopa/internal/chat"
)
func TestMain(m *testing.M) {
@@ -25,17 +26,17 @@ func TestMain(m *testing.M) {
}
// createTestSession creates a session in the database and returns its ID and cleanup function.
-func createTestSession(t *testing.T, setup *chatFlowSetup) (string, func()) {
+func createTestSession(t *testing.T, setup *chatFlowSetup) (uuid.UUID, func()) {
t.Helper()
- sess, err := setup.SessionStore.CreateSession(setup.Ctx, "test-session", "gemini-2.0-flash", "")
+ sess, err := setup.SessionStore.CreateSession(setup.Ctx, "test-session")
if err != nil {
- t.Fatalf("Failed to create test session: %v", err)
+ t.Fatalf("CreateSession() error: %v", err)
}
cleanup := func() {
// Use background context for cleanup since test context may be canceled
_ = setup.SessionStore.DeleteSession(context.Background(), sess.ID)
}
- return sess.ID.String(), cleanup
+ return sess.ID, cleanup
}
func TestTUI_Integration_StartStream_Success(t *testing.T) {
@@ -50,7 +51,7 @@ func TestTUI_Integration_StartStream_Success(t *testing.T) {
defer sessionCleanup()
tui, err := New(setup.Ctx, setup.Flow, sessionID)
if err != nil {
- t.Fatalf("Failed to create TUI: %v", err)
+ t.Fatalf("New() error: %v", err)
}
// Start a stream with a simple query
@@ -132,7 +133,7 @@ func TestTUI_Integration_StartStream_Cancellation(t *testing.T) {
defer sessionCleanup()
tui, err := New(setup.Ctx, setup.Flow, sessionID)
if err != nil {
- t.Fatalf("Failed to create TUI: %v", err)
+ t.Fatalf("New() error: %v", err)
}
// Start a stream with a long query
@@ -198,7 +199,7 @@ func TestTUI_Integration_HandleSubmit_StateTransition(t *testing.T) {
defer sessionCleanup()
tui, err := New(setup.Ctx, setup.Flow, sessionID)
if err != nil {
- t.Fatalf("Failed to create TUI: %v", err)
+ t.Fatalf("New() error: %v", err)
}
// Set input value
@@ -206,7 +207,7 @@ func TestTUI_Integration_HandleSubmit_StateTransition(t *testing.T) {
// Call handleSubmit
model, cmd := tui.handleSubmit()
- result := model.(*TUI)
+ result := model.(*Model)
// Verify state changed to thinking
if result.state != StateThinking {
@@ -355,13 +356,13 @@ func TestTUI_Integration_ViewDuringStreaming(t *testing.T) {
defer sessionCleanup()
tui, err := New(setup.Ctx, setup.Flow, sessionID)
if err != nil {
- t.Fatalf("Failed to create TUI: %v", err)
+ t.Fatalf("New() error: %v", err)
}
// Start streaming
tui.input.SetValue("Tell me about Go programming")
model, cmd := tui.handleSubmit()
- tui = model.(*TUI)
+ tui = model.(*Model)
// Call View during different states and verify no panic
_ = tui.View()
@@ -380,7 +381,7 @@ func TestTUI_Integration_ViewDuringStreaming(t *testing.T) {
break
}
model, cmd = tui.Update(msg)
- tui = model.(*TUI)
+ tui = model.(*Model)
// View should work in any state
_ = tui.View()
diff --git a/internal/tui/keys.go b/internal/tui/keys.go
index e281d8b..8ea5f07 100644
--- a/internal/tui/keys.go
+++ b/internal/tui/keys.go
@@ -42,193 +42,193 @@ func newKeyMap() keyMap {
}
//nolint:gocyclo // Keyboard handler requires branching for all key combinations
-func (t *TUI) handleKey(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) {
+func (m *Model) handleKey(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) {
k := msg.Key()
// Check for Ctrl modifier
if k.Mod&tea.ModCtrl != 0 {
switch k.Code {
case 'c':
- return t.handleCtrlC()
+ return m.handleCtrlC()
case 'd':
- cmd := t.cleanup()
- return t, cmd
+ cmd := m.cleanup()
+ return m, cmd
}
}
// Check special keys
switch k.Code {
case tea.KeyEnter:
- if t.state == StateInput {
+ if m.state == StateInput {
// Enter without Shift = submit
// Shift+Enter = newline (pass through to textarea)
if k.Mod&tea.ModShift == 0 {
- return t.handleSubmit()
+ return m.handleSubmit()
}
}
case tea.KeyUp:
// Up at first line navigates history, otherwise pass to textarea
- if t.state == StateInput && t.input.Line() == 0 {
- return t.navigateHistory(-1)
+ if m.state == StateInput && m.input.Line() == 0 {
+ return m.navigateHistory(-1)
}
case tea.KeyDown:
// Down at last line navigates history, otherwise pass to textarea
- if t.state == StateInput && t.input.Line() == t.input.LineCount()-1 {
- return t.navigateHistory(1)
+ if m.state == StateInput && m.input.Line() == m.input.LineCount()-1 {
+ return m.navigateHistory(1)
}
case tea.KeyEscape:
- if t.state == StateStreaming || t.state == StateThinking {
- t.cancelStream()
- t.state = StateInput
- t.output.Reset()
- return t, nil
+ if m.state == StateStreaming || m.state == StateThinking {
+ m.cancelStream()
+ m.state = StateInput
+ m.output.Reset()
+ return m, nil
}
case tea.KeyPgUp:
- t.viewport.PageUp()
- return t, nil
+ m.viewport.PageUp()
+ return m, nil
case tea.KeyPgDown:
- t.viewport.PageDown()
- return t, nil
+ m.viewport.PageDown()
+ return m, nil
}
// Pass keys to textarea for typing - ALWAYS allow typing even during streaming
// Better UX: users can prepare next message while LLM responds
var cmd tea.Cmd
- t.input, cmd = t.input.Update(msg)
- return t, cmd
+ m.input, cmd = m.input.Update(msg)
+ return m, cmd
}
-func (t *TUI) handleCtrlC() (tea.Model, tea.Cmd) {
+func (m *Model) handleCtrlC() (tea.Model, tea.Cmd) {
now := time.Now()
// Double Ctrl+C within 1 second = quit
- if now.Sub(t.lastCtrlC) < time.Second {
- cmd := t.cleanup()
- return t, cmd
+ if now.Sub(m.lastCtrlC) < time.Second {
+ cmd := m.cleanup()
+ return m, cmd
}
- t.lastCtrlC = now
+ m.lastCtrlC = now
- switch t.state {
+ switch m.state {
case StateInput:
- t.input.Reset()
- return t, nil
+ m.input.Reset()
+ return m, nil
case StateThinking, StateStreaming:
- t.cancelStream()
- t.state = StateInput
- t.output.Reset()
- t.addMessage(Message{Role: "system", Text: "(Canceled)"})
- return t, nil
+ m.cancelStream()
+ m.state = StateInput
+ m.output.Reset()
+ m.addMessage(Message{Role: roleSystem, Text: "(Canceled)"})
+ return m, nil
}
- return t, nil
+ return m, nil
}
-func (t *TUI) handleSubmit() (tea.Model, tea.Cmd) {
- query := strings.TrimSpace(t.input.Value())
+func (m *Model) handleSubmit() (tea.Model, tea.Cmd) {
+ query := strings.TrimSpace(m.input.Value())
if query == "" {
- return t, nil
+ return m, nil
}
// Handle slash commands
if strings.HasPrefix(query, "/") {
- return t.handleSlashCommand(query)
+ return m.handleSlashCommand(query)
}
// Add to history (enforce maxHistory cap)
- t.history = append(t.history, query)
- if len(t.history) > maxHistory {
+ m.history = append(m.history, query)
+ if len(m.history) > maxHistory {
// Remove oldest entries to stay within bounds
- t.history = t.history[len(t.history)-maxHistory:]
+ m.history = m.history[len(m.history)-maxHistory:]
}
- t.historyIdx = len(t.history)
+ m.historyIdx = len(m.history)
// Add user message
- t.addMessage(Message{Role: "user", Text: query})
+ m.addMessage(Message{Role: roleUser, Text: query})
// Clear input
- t.input.Reset()
+ m.input.Reset()
// Start thinking
- t.state = StateThinking
+ m.state = StateThinking
- return t, tea.Batch(
- t.spinner.Tick,
- t.startStream(query),
+ return m, tea.Batch(
+ m.spinner.Tick,
+ m.startStream(query),
)
}
-func (t *TUI) handleSlashCommand(cmd string) (tea.Model, tea.Cmd) {
+func (m *Model) handleSlashCommand(cmd string) (tea.Model, tea.Cmd) {
switch cmd {
case cmdHelp:
- t.addMessage(Message{
+ m.addMessage(Message{
Role: roleSystem,
Text: "Commands: " + cmdHelp + ", " + cmdClear + ", " + cmdExit + "\nShortcuts:\n Enter: send message\n Shift+Enter: new line\n Ctrl+C: cancel/clear\n Ctrl+D: exit\n Up/Down: history\n PgUp/PgDn: scroll",
})
case cmdClear:
- t.messages = nil
+ m.messages = nil
case cmdExit, cmdQuit:
- cleanupCmd := t.cleanup()
- return t, cleanupCmd
+ cleanupCmd := m.cleanup()
+ return m, cleanupCmd
default:
- t.addMessage(Message{
+ m.addMessage(Message{
Role: roleError,
Text: "Unknown command: " + cmd,
})
}
- t.input.Reset()
- return t, nil
+ m.input.Reset()
+ return m, nil
}
-func (t *TUI) navigateHistory(delta int) (tea.Model, tea.Cmd) {
- if len(t.history) == 0 {
- return t, nil
+func (m *Model) navigateHistory(delta int) (tea.Model, tea.Cmd) {
+ if len(m.history) == 0 {
+ return m, nil
}
- t.historyIdx += delta
+ m.historyIdx += delta
- if t.historyIdx < 0 {
- t.historyIdx = 0
+ if m.historyIdx < 0 {
+ m.historyIdx = 0
}
- if t.historyIdx > len(t.history) {
- t.historyIdx = len(t.history)
+ if m.historyIdx > len(m.history) {
+ m.historyIdx = len(m.history)
}
- if t.historyIdx == len(t.history) {
- t.input.SetValue("")
+ if m.historyIdx == len(m.history) {
+ m.input.SetValue("")
} else {
- t.input.SetValue(t.history[t.historyIdx])
+ m.input.SetValue(m.history[m.historyIdx])
// Move cursor to end of text
- t.input.CursorEnd()
+ m.input.CursorEnd()
}
- return t, nil
+ return m, nil
}
-func (t *TUI) cancelStream() {
- if t.streamCancel != nil {
- t.streamCancel()
- t.streamCancel = nil
+func (m *Model) cancelStream() {
+ if m.streamCancel != nil {
+ m.streamCancel()
+ m.streamCancel = nil
}
}
// cleanup cancels any active stream and returns the quit command.
// Waits for goroutine exit with timeout to prevent resource leaks.
-func (t *TUI) cleanup() tea.Cmd {
- // Cancel main context first - this triggers all goroutines using t.ctx
- if t.ctxCancel != nil {
- t.ctxCancel()
- t.ctxCancel = nil
+func (m *Model) cleanup() tea.Cmd {
+ // Cancel main context first - this triggers all goroutines using m.ctx
+ if m.ctxCancel != nil {
+ m.ctxCancel()
+ m.ctxCancel = nil
}
// Then cancel stream-specific context (may already be canceled via parent)
- t.cancelStream()
- t.streamEventCh = nil
+ m.cancelStream()
+ m.streamEventCh = nil
return tea.Quit
}
diff --git a/internal/tui/model.go b/internal/tui/model.go
new file mode 100644
index 0000000..467dabc
--- /dev/null
+++ b/internal/tui/model.go
@@ -0,0 +1,198 @@
+// Package tui provides Bubble Tea terminal interface for Koopa.
+package tui
+
+import (
+ "context"
+ "errors"
+ "strings"
+ "time"
+
+ "charm.land/bubbles/v2/help"
+ "charm.land/bubbles/v2/spinner"
+ "charm.land/bubbles/v2/textarea"
+ "charm.land/bubbles/v2/viewport"
+ tea "charm.land/bubbletea/v2"
+ "charm.land/lipgloss/v2"
+
+ "github.com/google/uuid"
+ "github.com/koopa0/koopa/internal/chat"
+)
+
+// State represents TUI state machine.
+type State int
+
+// TUI state machine states.
+const (
+ StateInput State = iota // Awaiting user input
+ StateThinking // Processing request
+ StateStreaming // Streaming response
+)
+
+// Memory bounds to prevent unbounded growth.
+const (
+ maxMessages = 100 // Maximum messages stored
+ maxHistory = 100 // Maximum command history entries
+)
+
+// Timeout constants for stream operations.
+const streamTimeout = 5 * time.Minute // Maximum time for a single stream
+
+// Message role constants for consistent display.
+const (
+ roleUser = "user"
+ roleAssistant = "assistant"
+ roleSystem = "system"
+ roleError = "error"
+)
+
+// Layout constants for viewport height calculation.
+const (
+ separatorLines = 2 // Two separator lines (above and below input)
+ helpLines = 1 // Help bar height
+ promptLines = 1 // Prompt prefix line
+ minViewport = 3 // Minimum viewport height
+)
+
+// Message represents a conversation message for display.
+type Message struct {
+ Role string // "user", "assistant", "system", "error"
+ Text string
+}
+
+// Model is the Bubble Tea model for Koopa terminal interface.
+type Model struct {
+ // Input (textarea for multi-line support, Shift+Enter for newline)
+ input textarea.Model
+ history []string
+ historyIdx int
+
+ // State
+ state State
+ lastCtrlC time.Time
+
+ // Output
+ spinner spinner.Model
+ output strings.Builder
+ viewBuf strings.Builder // Reusable buffer for View() to reduce allocations
+ messages []Message
+
+ // Scrollable message viewport
+ viewport viewport.Model
+
+ // Help bar for keyboard shortcuts
+ help help.Model
+ keys keyMap
+
+ // Stream management
+ // Note: No sync.WaitGroup - Bubble Tea's event loop provides synchronization.
+ // Single union channel with discriminated events simplifies select logic.
+ streamCancel context.CancelFunc
+ streamEventCh <-chan streamEvent
+ toolStatus string // Current tool status (e.g., "Searching the web..."), empty when idle
+
+ // Dependencies (direct, no interface)
+ chatFlow *chat.Flow
+ sessionID uuid.UUID
+ ctx context.Context
+ ctxCancel context.CancelFunc // For canceling all operations on exit
+
+ // Dimensions
+ width int
+ height int
+
+ // Styles
+ styles Styles
+
+ // Markdown rendering (nil = graceful degradation to plain text)
+ markdown *markdownRenderer
+}
+
+// addMessage appends a message and enforces maxMessages bound.
+func (m *Model) addMessage(msg Message) {
+ m.messages = append(m.messages, msg)
+ if len(m.messages) > maxMessages {
+ // Remove oldest messages to stay within bounds
+ m.messages = m.messages[len(m.messages)-maxMessages:]
+ }
+}
+
+// New creates a Model for chat interaction.
+// Returns error if required dependencies are nil.
+//
+// IMPORTANT: ctx MUST be the same context passed to tea.WithContext()
+// to ensure consistent cancellation behavior.
+func New(ctx context.Context, flow *chat.Flow, sessionID uuid.UUID) (*Model, error) {
+ if flow == nil {
+ return nil, errors.New("tui.New: flow is required")
+ }
+ if ctx == nil {
+ return nil, errors.New("tui.New: ctx is required")
+ }
+ if sessionID == uuid.Nil {
+ return nil, errors.New("tui.New: session ID is required")
+ }
+
+ // Create cancellable context for cleanup on exit
+ ctx, cancel := context.WithCancel(ctx)
+
+ // Create textarea for multi-line input
+ // Enter submits, Shift+Enter adds newline (default behavior)
+ ta := textarea.New()
+ ta.Placeholder = "Ask anything..."
+ ta.SetHeight(1) // Single line by default
+ ta.SetWidth(120) // Wide enough for long text, updated on WindowSizeMsg
+ ta.MaxWidth = 0 // No max width limit
+ ta.ShowLineNumbers = false
+
+ // Clean, minimal styling like Claude Code / Gemini CLI
+ // No background colors, just simple text
+ cleanStyle := textarea.StyleState{
+ Base: lipgloss.NewStyle(),
+ Text: lipgloss.NewStyle(),
+ Placeholder: lipgloss.NewStyle().Foreground(lipgloss.Color("240")), // Gray placeholder
+ Prompt: lipgloss.NewStyle(),
+ }
+ ta.SetStyles(textarea.Styles{
+ Focused: cleanStyle,
+ Blurred: cleanStyle,
+ })
+ ta.Focus()
+
+ sp := spinner.New()
+ sp.Spinner = spinner.Dot
+
+ // Create viewport for scrollable message history.
+ // Disable built-in keyboard handling — we route keys explicitly
+ // in handleKey to avoid conflicts with textarea/history navigation.
+ vp := viewport.New(viewport.WithWidth(80), viewport.WithHeight(20))
+ vp.MouseWheelEnabled = true
+ vp.SoftWrap = true
+ vp.KeyMap = viewport.KeyMap{} // Disable default key bindings
+
+ h := help.New()
+
+ return &Model{
+ chatFlow: flow,
+ sessionID: sessionID,
+ ctx: ctx,
+ ctxCancel: cancel,
+ input: ta,
+ spinner: sp,
+ viewport: vp,
+ help: h,
+ keys: newKeyMap(),
+ styles: DefaultStyles(),
+ history: make([]string, 0, maxHistory),
+ markdown: newMarkdownRenderer(80),
+ width: 80, // Default width until WindowSizeMsg arrives
+ }, nil
+}
+
+// Init implements tea.Model.
+func (m *Model) Init() tea.Cmd {
+ return tea.Batch(
+ textarea.Blink,
+ m.spinner.Tick,
+ m.input.Focus(), // Ensure textarea is focused on startup
+ )
+}
diff --git a/internal/tui/setup_test.go b/internal/tui/setup_test.go
index 2ee1653..9dbfb63 100644
--- a/internal/tui/setup_test.go
+++ b/internal/tui/setup_test.go
@@ -12,12 +12,10 @@ package tui
import (
"context"
- "fmt"
"io"
"log/slog"
"os"
"path/filepath"
- "runtime"
"testing"
"time"
@@ -26,12 +24,13 @@ import (
"github.com/firebase/genkit/go/plugins/postgresql"
"github.com/jackc/pgx/v5/pgxpool"
- "github.com/koopa0/koopa/internal/agent/chat"
+ "github.com/koopa0/koopa/internal/chat"
"github.com/koopa0/koopa/internal/config"
"github.com/koopa0/koopa/internal/rag"
"github.com/koopa0/koopa/internal/security"
"github.com/koopa0/koopa/internal/session"
"github.com/koopa0/koopa/internal/sqlc"
+ "github.com/koopa0/koopa/internal/testutil"
"github.com/koopa0/koopa/internal/tools"
)
@@ -44,27 +43,6 @@ type chatFlowSetup struct {
Cancel context.CancelFunc
}
-// findProjectRoot finds the project root directory by looking for go.mod.
-func findProjectRoot() (string, error) {
- _, filename, _, ok := runtime.Caller(0)
- if !ok {
- return "", fmt.Errorf("runtime.Caller failed to get caller info")
- }
-
- dir := filepath.Dir(filename)
- for {
- goModPath := filepath.Join(dir, "go.mod")
- if _, err := os.Stat(goModPath); err == nil {
- return dir, nil
- }
- parent := filepath.Dir(dir)
- if parent == dir {
- return "", fmt.Errorf("go.mod not found in any parent directory of %s", filename)
- }
- dir = parent
- }
-}
-
// setupChatFlow creates a complete chat flow setup for integration testing.
//
// This function assembles all dependencies needed for TUI integration tests.
@@ -81,7 +59,9 @@ func findProjectRoot() (string, error) {
// setup, cleanup := setupChatFlow(t)
// defer cleanup()
//
-// tui := New(setup.Ctx, setup.Flow, "test-session-id")
+// sessionID, cleanup := createTestSession(t, setup)
+// defer cleanup()
+// tui := New(setup.Ctx, setup.Flow, sessionID)
// }
func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) {
t.Helper()
@@ -98,10 +78,10 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
- projectRoot, err := findProjectRoot()
+ projectRoot, err := testutil.FindProjectRoot()
if err != nil || projectRoot == "" {
cancel()
- t.Fatalf("Failed to find project root: %v", err)
+ t.Fatalf("FindProjectRoot() error: %v", err)
}
promptsDir := filepath.Join(projectRoot, "prompts")
@@ -109,7 +89,7 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) {
pool, err := pgxpool.New(ctx, dbURL)
if err != nil {
cancel()
- t.Fatalf("Failed to connect to database: %v", err)
+ t.Fatalf("pgxpool.New() error: %v", err)
}
// Create PostgreSQL engine for Genkit
@@ -120,7 +100,7 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) {
if err != nil {
pool.Close()
cancel()
- t.Fatalf("Failed to create PostgresEngine: %v", err)
+ t.Fatalf("NewPostgresEngine() error: %v", err)
}
postgres := &postgresql.Postgres{Engine: pEngine}
@@ -133,15 +113,14 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) {
if g == nil {
pool.Close()
cancel()
- t.Fatal("Failed to initialize Genkit")
+ t.Fatal("genkit.Init() returned nil")
}
logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelWarn}))
cfg := &config.Config{
ModelName: "gemini-2.0-flash",
- EmbedderModel: "text-embedding-004",
- RAGTopK: 5,
+ EmbedderModel: "gemini-embedding-001",
MaxTurns: 10,
}
@@ -150,7 +129,7 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) {
if embedder == nil {
pool.Close()
cancel()
- t.Fatalf("Failed to create embedder for model %q", cfg.EmbedderModel)
+ t.Fatalf("GoogleAIEmbedder(%q) returned nil", cfg.EmbedderModel)
}
// Create DocStore and Retriever using shared config factory
@@ -159,7 +138,7 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) {
if err != nil {
pool.Close()
cancel()
- t.Fatalf("Failed to define retriever: %v", err)
+ t.Fatalf("DefineRetriever() error: %v", err)
}
queries := sqlc.New(pool)
@@ -169,51 +148,51 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) {
if err != nil {
pool.Close()
cancel()
- t.Fatalf("Failed to create path validator: %v", err)
+ t.Fatalf("NewPath() error: %v", err)
}
// Create and register file tools
- ft, err := tools.NewFileTools(pathValidator, logger)
+ ft, err := tools.NewFile(pathValidator, logger)
if err != nil {
pool.Close()
cancel()
- t.Fatalf("Failed to create file tools: %v", err)
+ t.Fatalf("NewFile() error: %v", err)
}
- fileTools, err := tools.RegisterFileTools(g, ft)
+ fileTools, err := tools.RegisterFile(g, ft)
if err != nil {
pool.Close()
cancel()
- t.Fatalf("Failed to register file tools: %v", err)
+ t.Fatalf("RegisterFile() error: %v", err)
}
// Create and register system tools
cmdValidator := security.NewCommand()
envValidator := security.NewEnv()
- st, err := tools.NewSystemTools(cmdValidator, envValidator, logger)
+ st, err := tools.NewSystem(cmdValidator, envValidator, logger)
if err != nil {
pool.Close()
cancel()
- t.Fatalf("Failed to create system tools: %v", err)
+ t.Fatalf("NewSystem() error: %v", err)
}
- systemTools, err := tools.RegisterSystemTools(g, st)
+ systemTools, err := tools.RegisterSystem(g, st)
if err != nil {
pool.Close()
cancel()
- t.Fatalf("Failed to register system tools: %v", err)
+ t.Fatalf("RegisterSystem() error: %v", err)
}
// Create and register knowledge tools
- kt, err := tools.NewKnowledgeTools(retriever, nil, logger)
+ kt, err := tools.NewKnowledge(retriever, nil, logger)
if err != nil {
pool.Close()
cancel()
- t.Fatalf("Failed to create knowledge tools: %v", err)
+ t.Fatalf("NewKnowledge() error: %v", err)
}
- knowledgeTools, err := tools.RegisterKnowledgeTools(g, kt)
+ knowledgeTools, err := tools.RegisterKnowledge(g, kt)
if err != nil {
pool.Close()
cancel()
- t.Fatalf("Failed to register knowledge tools: %v", err)
+ t.Fatalf("RegisterKnowledge() error: %v", err)
}
// Combine all tools
@@ -221,27 +200,20 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) {
chatAgent, err := chat.New(chat.Config{
Genkit: g,
- Retriever: retriever,
SessionStore: sessionStore,
Logger: logger,
Tools: allTools,
MaxTurns: cfg.MaxTurns,
- RAGTopK: cfg.RAGTopK,
})
if err != nil {
pool.Close()
cancel()
- t.Fatalf("Failed to create chat agent: %v", err)
+ t.Fatalf("chat.New() error: %v", err)
}
// Initialize Flow singleton (reset first for test isolation)
chat.ResetFlowForTesting()
- flow, err := chat.InitFlow(g, chatAgent)
- if err != nil {
- pool.Close()
- cancel()
- t.Fatalf("Failed to init chat flow: %v", err)
- }
+ flow := chat.NewFlow(g, chatAgent)
setup := &chatFlowSetup{
Flow: flow,
diff --git a/internal/tui/commands.go b/internal/tui/stream.go
similarity index 69%
rename from internal/tui/commands.go
rename to internal/tui/stream.go
index a326079..38761c1 100644
--- a/internal/tui/commands.go
+++ b/internal/tui/stream.go
@@ -7,7 +7,8 @@ import (
tea "charm.land/bubbletea/v2"
- "github.com/koopa0/koopa/internal/agent/chat"
+ "github.com/koopa0/koopa/internal/chat"
+ "github.com/koopa0/koopa/internal/tools"
)
// streamBufferSize is sized for ~1.5s burst at 60 FPS refresh rate.
@@ -20,10 +21,11 @@ const streamBufferSize = 100
// and eliminates complex multi-channel closure handling.
type streamEvent struct {
// Exactly one of these fields is set per event
- text string // Text chunk (when non-empty)
- output chat.Output // Final output (when done is true)
- err error // Error (when non-nil)
- done bool // True when stream completed successfully
+ text string // Text chunk (when non-empty)
+ output chat.Output // Final output (when done is true)
+ err error // Error (when non-nil)
+ done bool // True when stream completed successfully
+ toolStatus string // Tool status message (when non-empty, e.g. "Searching the web...")
}
// Stream message types for Bubble Tea
@@ -44,6 +46,42 @@ type streamErrorMsg struct {
err error
}
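+// streamToolMsg reports the current tool execution status to Update.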
+type streamToolMsg struct {
+ status string
+}
+
+// tuiToolEmitter implements tools.Emitter for the TUI.
+// Sends tool status through the stream event channel so Bubble Tea
+// can display tool execution progress (e.g., "Searching the web...").
+type tuiToolEmitter struct {
+ eventCh chan<- streamEvent
+}
+
+func (e *tuiToolEmitter) OnToolStart(name string) {
+ display := toolDisplayName(name)
+ select {
+ case e.eventCh <- streamEvent{toolStatus: display + "..."}:
+ default: // best-effort: don't block if channel is full
+ }
+}
+
+func (e *tuiToolEmitter) OnToolComplete(_ string) {
+ select {
+ case e.eventCh <- streamEvent{toolStatus: ""}:
+ default:
+ }
+}
+
+func (e *tuiToolEmitter) OnToolError(_ string) {
+ select {
+ case e.eventCh <- streamEvent{toolStatus: ""}:
+ default:
+ }
+}
+
+// Compile-time interface verification.
+var _ tools.Emitter = (*tuiToolEmitter)(nil)
+
// startStream creates a command that initiates streaming.
// Directly uses *chat.Flow - no adapter needed.
//
@@ -53,12 +91,15 @@ type streamErrorMsg struct {
// 3. Error occurs
//
// Channel closure signals completion - no WaitGroup needed.
-func (t *TUI) startStream(query string) tea.Cmd {
+func (m *Model) startStream(query string) tea.Cmd {
return func() tea.Msg {
eventCh := make(chan streamEvent, streamBufferSize)
// Create context with timeout to prevent indefinite hangs
- ctx, cancel := context.WithTimeout(t.ctx, streamTimeout)
+ ctx, cancel := context.WithTimeout(m.ctx, streamTimeout)
+
+ // Inject tool event emitter so tool status is shown in TUI
+ ctx = tools.ContextWithEmitter(ctx, &tuiToolEmitter{eventCh: eventCh})
go func() {
// Ensure timer resources are released on all exit paths
@@ -81,9 +122,9 @@ func (t *TUI) startStream(query string) tea.Cmd {
// Directly use chat.Flow's iterator (Go 1.23+ range-over-func)
// Genkit's StreamingFlowValue has: {Stream.Text, Output, Done}
- for streamValue, err := range t.chatFlow.Stream(ctx, chat.Input{
+ for streamValue, err := range m.chatFlow.Stream(ctx, chat.Input{
Query: query,
- SessionID: t.sessionID,
+ SessionID: m.sessionID.String(),
}) {
if err != nil {
select {
@@ -154,6 +195,8 @@ func listenForStream(eventCh <-chan streamEvent) tea.Cmd {
return streamErrorMsg{err: event.err}
case event.done:
return streamDoneMsg{output: event.output}
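+ // Tool status events become streamToolMsg so Update can refresh the indicator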
+ case event.toolStatus != "":
+ return streamToolMsg{status: event.toolStatus}
case event.text != "":
return streamTextMsg{text: event.text}
default:
diff --git a/internal/tui/tui.go b/internal/tui/tui.go
deleted file mode 100644
index 44f31b5..0000000
--- a/internal/tui/tui.go
+++ /dev/null
@@ -1,413 +0,0 @@
-// Package tui provides Bubble Tea terminal interface for Koopa.
-package tui
-
-import (
- "context"
- "errors"
- "strings"
- "time"
-
- "charm.land/bubbles/v2/help"
- "charm.land/bubbles/v2/key"
- "charm.land/bubbles/v2/spinner"
- "charm.land/bubbles/v2/textarea"
- "charm.land/bubbles/v2/viewport"
- tea "charm.land/bubbletea/v2"
- "charm.land/lipgloss/v2"
-
- "github.com/koopa0/koopa/internal/agent/chat"
-)
-
-// State represents TUI state machine.
-type State int
-
-// TUI state machine states.
-const (
- StateInput State = iota // Awaiting user input
- StateThinking // Processing request
- StateStreaming // Streaming response
-)
-
-// Memory bounds to prevent unbounded growth.
-const (
- maxMessages = 100 // Maximum messages stored
- maxHistory = 100 // Maximum command history entries
-)
-
-// Timeout constants for stream operations.
-const streamTimeout = 5 * time.Minute // Maximum time for a single stream
-
-// Message role constants for consistent display.
-const (
- roleUser = "user"
- roleAssistant = "assistant"
- roleSystem = "system"
- roleError = "error"
-)
-
-// Layout constants for viewport height calculation.
-const (
- separatorLines = 2 // Two separator lines (above and below input)
- helpLines = 1 // Help bar height
- promptLines = 1 // Prompt prefix line
- minViewport = 3 // Minimum viewport height
-)
-
-// Message represents a conversation message for display.
-type Message struct {
- Role string // "user", "assistant", "system", "error"
- Text string
-}
-
-// TUI is the Bubble Tea model for Koopa terminal interface.
-type TUI struct {
- // Input (textarea for multi-line support, Shift+Enter for newline)
- input textarea.Model
- history []string
- historyIdx int
-
- // State
- state State
- lastCtrlC time.Time
-
- // Output
- spinner spinner.Model
- output strings.Builder
- viewBuf strings.Builder // Reusable buffer for View() to reduce allocations
- messages []Message
-
- // Scrollable message viewport
- viewport viewport.Model
-
- // Help bar for keyboard shortcuts
- help help.Model
- keys keyMap
-
- // Stream management
- // Note: No sync.WaitGroup - Bubble Tea's event loop provides synchronization.
- // Single union channel with discriminated events simplifies select logic.
- streamCancel context.CancelFunc
- streamEventCh <-chan streamEvent
-
- // Dependencies (direct, no interface)
- chatFlow *chat.Flow
- sessionID string
- ctx context.Context
- ctxCancel context.CancelFunc // For canceling all operations on exit
-
- // Dimensions
- width int
- height int
-
- // Styles
- styles Styles
-
- // Markdown rendering (nil = graceful degradation to plain text)
- markdown *markdownRenderer
-}
-
-// addMessage appends a message and enforces maxMessages bound.
-func (t *TUI) addMessage(msg Message) {
- t.messages = append(t.messages, msg)
- if len(t.messages) > maxMessages {
- // Remove oldest messages to stay within bounds
- t.messages = t.messages[len(t.messages)-maxMessages:]
- }
-}
-
-// New creates a TUI model for chat interaction.
-// Returns error if required dependencies are nil.
-//
-// IMPORTANT: ctx MUST be the same context passed to tea.WithContext()
-// to ensure consistent cancellation behavior.
-func New(ctx context.Context, flow *chat.Flow, sessionID string) (*TUI, error) {
- if flow == nil {
- return nil, errors.New("tui.New: flow is required")
- }
- if ctx == nil {
- return nil, errors.New("tui.New: ctx is required")
- }
- if sessionID == "" {
- return nil, errors.New("tui.New: session ID is required")
- }
-
- // Create cancellable context for cleanup on exit
- ctx, cancel := context.WithCancel(ctx)
-
- // Create textarea for multi-line input
- // Enter submits, Shift+Enter adds newline (default behavior)
- ta := textarea.New()
- ta.Placeholder = "Ask anything..."
- ta.SetHeight(1) // Single line by default
- ta.SetWidth(120) // Wide enough for long text, updated on WindowSizeMsg
- ta.MaxWidth = 0 // No max width limit
- ta.ShowLineNumbers = false
-
- // Clean, minimal styling like Claude Code / Gemini CLI
- // No background colors, just simple text
- cleanStyle := textarea.StyleState{
- Base: lipgloss.NewStyle(),
- Text: lipgloss.NewStyle(),
- Placeholder: lipgloss.NewStyle().Foreground(lipgloss.Color("240")), // Gray placeholder
- Prompt: lipgloss.NewStyle(),
- }
- ta.SetStyles(textarea.Styles{
- Focused: cleanStyle,
- Blurred: cleanStyle,
- })
- ta.Focus()
-
- sp := spinner.New()
- sp.Spinner = spinner.Dot
-
- // Create viewport for scrollable message history.
- // Disable built-in keyboard handling — we route keys explicitly
- // in handleKey to avoid conflicts with textarea/history navigation.
- vp := viewport.New(viewport.WithWidth(80), viewport.WithHeight(20))
- vp.MouseWheelEnabled = true
- vp.SoftWrap = true
- vp.KeyMap = viewport.KeyMap{} // Disable default key bindings
-
- h := help.New()
-
- return &TUI{
- chatFlow: flow,
- sessionID: sessionID,
- ctx: ctx,
- ctxCancel: cancel,
- input: ta,
- spinner: sp,
- viewport: vp,
- help: h,
- keys: newKeyMap(),
- styles: DefaultStyles(),
- history: make([]string, 0, maxHistory),
- markdown: newMarkdownRenderer(80),
- width: 80, // Default width until WindowSizeMsg arrives
- }, nil
-}
-
-// Init implements tea.Model.
-func (t *TUI) Init() tea.Cmd {
- return tea.Batch(
- textarea.Blink,
- t.spinner.Tick,
- t.input.Focus(), // Ensure textarea is focused on startup
- )
-}
-
-// Update implements tea.Model.
-//
-//nolint:gocognit,gocyclo // Bubble Tea Update requires type switch on all message types
-func (t *TUI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
- switch msg := msg.(type) {
- case tea.KeyPressMsg:
- return t.handleKey(msg)
-
- case tea.WindowSizeMsg:
- t.width = msg.Width
- t.height = msg.Height
-
- // Calculate viewport height: total - input - separators - help
- inputHeight := t.input.Height() + promptLines
- fixedHeight := separatorLines + inputHeight + helpLines
- vpHeight := max(msg.Height-fixedHeight, minViewport)
-
- t.viewport.SetWidth(msg.Width)
- t.viewport.SetHeight(vpHeight)
- t.input.SetWidth(msg.Width - 4) // Room for "> " prompt
- t.help.SetWidth(msg.Width)
- t.markdown.UpdateWidth(msg.Width)
-
- // Rebuild viewport content with new dimensions
- t.rebuildViewportContent()
- return t, nil
-
- case tea.MouseWheelMsg:
- // Forward mouse wheel to viewport for scrolling
- var cmd tea.Cmd
- t.viewport, cmd = t.viewport.Update(msg)
- return t, cmd
-
- case spinner.TickMsg:
- var cmd tea.Cmd
- t.spinner, cmd = t.spinner.Update(msg)
- // Rebuild viewport to update spinner animation
- if t.state == StateThinking {
- t.rebuildViewportContent()
- }
- return t, cmd
-
- case streamStartedMsg:
- t.streamCancel = msg.cancel
- t.streamEventCh = msg.eventCh
- t.state = StateStreaming
- t.rebuildViewportContent()
- t.viewport.GotoBottom()
- return t, listenForStream(msg.eventCh)
-
- case streamTextMsg:
- t.output.WriteString(msg.text)
- t.rebuildViewportContent()
- t.viewport.GotoBottom()
- return t, listenForStream(t.streamEventCh)
-
- case streamDoneMsg:
- t.state = StateInput
-
- // Cancel context to release timer resources
- if t.streamCancel != nil {
- t.streamCancel()
- t.streamCancel = nil
- }
- t.streamEventCh = nil
-
- // Prefer msg.output.Response (complete response from Genkit) over accumulated chunks.
- // This handles models that don't stream or send final content only in Output.
- finalText := msg.output.Response
- if finalText == "" {
- // Fallback to accumulated chunks if Response is empty
- finalText = t.output.String()
- }
-
- t.addMessage(Message{
- Role: roleAssistant,
- Text: finalText,
- })
- t.output.Reset()
- t.rebuildViewportContent()
- t.viewport.GotoBottom()
- // Re-focus textarea after stream completes
- return t, t.input.Focus()
-
- case streamErrorMsg:
- t.state = StateInput
-
- // Cancel context to release timer resources
- if t.streamCancel != nil {
- t.streamCancel()
- t.streamCancel = nil
- }
- t.streamEventCh = nil
-
- switch {
- case errors.Is(msg.err, context.Canceled):
- t.addMessage(Message{Role: roleSystem, Text: "(Canceled)"})
- case errors.Is(msg.err, context.DeadlineExceeded):
- t.addMessage(Message{Role: roleError, Text: "Query timeout (>5 min). Try a simpler query or break it into steps."})
- default:
- t.addMessage(Message{Role: roleError, Text: msg.err.Error()})
- }
- t.output.Reset()
- t.rebuildViewportContent()
- t.viewport.GotoBottom()
- // Re-focus textarea after error
- return t, t.input.Focus()
- }
-
- var cmd tea.Cmd
- t.input, cmd = t.input.Update(msg)
- return t, cmd
-}
-
-// View implements tea.Model.
-// Uses AltScreen with viewport for scrollable message history.
-func (t *TUI) View() tea.View {
- t.viewBuf.Reset()
-
- // Viewport (scrollable message area)
- _, _ = t.viewBuf.WriteString(t.viewport.View())
- _, _ = t.viewBuf.WriteString("\n")
-
- // Separator line above input
- _, _ = t.viewBuf.WriteString(t.renderSeparator())
- _, _ = t.viewBuf.WriteString("\n")
-
- // Input prompt - always show and always accept input
- // Users can type while LLM is thinking/streaming (better UX)
- _, _ = t.viewBuf.WriteString(t.styles.Prompt.Render("> "))
- _, _ = t.viewBuf.WriteString(t.input.View())
- _, _ = t.viewBuf.WriteString("\n")
-
- // Separator line below input
- _, _ = t.viewBuf.WriteString(t.renderSeparator())
- _, _ = t.viewBuf.WriteString("\n")
-
- // Help bar (keyboard shortcuts)
- _, _ = t.viewBuf.WriteString(t.renderStatusBar())
-
- v := tea.NewView(t.viewBuf.String())
- v.AltScreen = true
- return v
-}
-
-// rebuildViewportContent reconstructs the viewport content from messages and state.
-// Called when messages, streaming output, or state changes.
-func (t *TUI) rebuildViewportContent() {
- var b strings.Builder
-
- // Banner (ASCII art) and tips
- _, _ = b.WriteString(t.styles.RenderBanner())
- _, _ = b.WriteString("\n")
- _, _ = b.WriteString(t.styles.RenderWelcomeTips())
- _, _ = b.WriteString("\n")
-
- // Messages (already bounded by addMessage)
- for _, msg := range t.messages {
- switch msg.Role {
- case roleUser:
- _, _ = b.WriteString(t.styles.User.Render("You> "))
- _, _ = b.WriteString(msg.Text)
- case roleAssistant:
- _, _ = b.WriteString(t.styles.Assistant.Render("Koopa> "))
- _, _ = b.WriteString(t.markdown.Render(msg.Text))
- case roleSystem:
- _, _ = b.WriteString(t.styles.System.Render(msg.Text))
- case roleError:
- _, _ = b.WriteString(t.styles.Error.Render("Error: " + msg.Text))
- }
- _, _ = b.WriteString("\n\n")
- }
-
- // Current streaming output
- if t.state == StateStreaming && t.output.Len() > 0 {
- _, _ = b.WriteString(t.styles.Assistant.Render("Koopa> "))
- _, _ = b.WriteString(t.output.String())
- _, _ = b.WriteString("\n\n")
- }
-
- // Thinking indicator
- if t.state == StateThinking {
- _, _ = b.WriteString(t.spinner.View())
- _, _ = b.WriteString(" Thinking...\n\n")
- }
-
- t.viewport.SetContent(b.String())
-}
-
-// renderSeparator returns a horizontal line separator.
-func (t *TUI) renderSeparator() string {
- width := t.width
- if width <= 0 {
- width = 80 // Default width
- }
- return t.styles.Separator.Render(strings.Repeat("─", width))
-}
-
-// renderStatusBar returns state-appropriate keyboard shortcut help.
-func (t *TUI) renderStatusBar() string {
- var bindings []key.Binding
- switch t.state {
- case StateInput:
- bindings = []key.Binding{
- t.keys.Submit, t.keys.NewLine, t.keys.History,
- t.keys.Cancel, t.keys.Quit, t.keys.ScrollUp,
- }
- case StateThinking, StateStreaming:
- bindings = []key.Binding{
- t.keys.EscCancel, t.keys.Cancel,
- t.keys.ScrollUp, t.keys.ScrollDown,
- }
- }
- return t.help.ShortHelpView(bindings)
-}
diff --git a/internal/tui/tui_test.go b/internal/tui/tui_test.go
index d362468..a1d7ebd 100644
--- a/internal/tui/tui_test.go
+++ b/internal/tui/tui_test.go
@@ -12,7 +12,8 @@ import (
tea "charm.land/bubbletea/v2"
"go.uber.org/goleak"
- "github.com/koopa0/koopa/internal/agent/chat"
+ "github.com/google/uuid"
+ "github.com/koopa0/koopa/internal/chat"
)
// goleakOptions returns standard goleak options for all TUI tests.
@@ -27,8 +28,8 @@ func goleakOptions() []goleak.Option {
}
}
-// newTestTUI creates a TUI with properly initialized components for testing.
-func newTestTUI() *TUI {
+// newTestModel creates a Model with properly initialized components for testing.
+func newTestModel() *Model {
ta := textarea.New()
ta.SetHeight(3)
ta.ShowLineNumbers = false
@@ -38,7 +39,7 @@ func newTestTUI() *TUI {
vp.SoftWrap = true
vp.KeyMap = viewport.KeyMap{}
- return &TUI{
+ return &Model{
state: StateInput,
input: ta,
viewport: vp,
@@ -53,7 +54,7 @@ func newTestTUI() *TUI {
}
func TestNew_ErrorOnNilFlow(t *testing.T) {
- _, err := New(context.Background(), nil, "test")
+ _, err := New(context.Background(), nil, uuid.New())
if err == nil {
t.Error("Expected error for nil flow")
}
@@ -64,23 +65,23 @@ func TestNew_ErrorOnNilContext(t *testing.T) {
// so we're testing that error is returned for nil context
var flow *chat.Flow
//lint:ignore SA1012 intentionally testing nil context handling
- _, err := New(nil, flow, "test") //nolint:staticcheck
+ _, err := New(nil, flow, uuid.New()) //nolint:staticcheck
if err == nil {
t.Error("Expected error for nil context")
}
}
-func TestNew_ErrorOnEmptySessionID(t *testing.T) {
- _, err := New(context.Background(), nil, "")
+func TestNew_ErrorOnNilSessionID(t *testing.T) {
+ _, err := New(context.Background(), nil, uuid.Nil)
if err == nil {
- t.Error("Expected error for empty session ID")
+ t.Error("Expected error for nil session ID")
}
}
func TestTUI_Init(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
cmd := tui.Init()
if cmd == nil {
t.Error("Init should return a command (blink + spinner tick)")
@@ -96,22 +97,22 @@ func TestTUI_HandleSlashCommands(t *testing.T) {
wantExit bool
wantMsgs int // number of messages added
}{
- {"help", "/help", false, 1},
- {"clear", "/clear", false, 0}, // clears messages
- {"exit", "/exit", true, 0},
- {"quit", "/quit", true, 0},
- {"unknown", "/unknown", false, 1}, // error message
+ {name: "help", cmd: "/help", wantExit: false, wantMsgs: 1},
+ {name: "clear", cmd: "/clear", wantExit: false, wantMsgs: 0}, // clears messages
+ {name: "exit", cmd: "/exit", wantExit: true, wantMsgs: 0},
+ {name: "quit", cmd: "/quit", wantExit: true, wantMsgs: 0},
+ {name: "unknown", cmd: "/unknown", wantExit: false, wantMsgs: 1}, // error message
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- tui := newTestTUI()
+ tui := newTestModel()
// Pre-populate with a message for /clear test
tui.messages = []Message{{Role: "user", Text: "hello"}}
model, cmd := tui.handleSlashCommand(tt.cmd)
- result := model.(*TUI)
+ result := model.(*Model)
if tt.wantExit {
if cmd == nil {
@@ -135,29 +136,29 @@ func TestTUI_HandleSlashCommands(t *testing.T) {
func TestTUI_HistoryNavigation(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
tui.history = []string{"first", "second", "third"}
tui.historyIdx = 3
tests := []struct {
- delta int
- expected string
+ delta int
+ want string
}{
- {-1, "third"},
- {-1, "second"},
- {-1, "first"},
- {-1, "first"}, // Should stay at first
- {1, "second"},
- {1, "third"},
- {1, ""}, // Past end = empty
- {1, ""}, // Should stay empty
+ {delta: -1, want: "third"},
+ {delta: -1, want: "second"},
+ {delta: -1, want: "first"},
+ {delta: -1, want: "first"}, // Should stay at first
+ {delta: 1, want: "second"},
+ {delta: 1, want: "third"},
+ {delta: 1, want: ""}, // Past end = empty
+ {delta: 1, want: ""}, // Should stay empty
}
for i, tt := range tests {
model, _ := tui.navigateHistory(tt.delta)
- tui = model.(*TUI)
- if tui.input.Value() != tt.expected {
- t.Errorf("Step %d: got %q, want %q", i, tui.input.Value(), tt.expected)
+ tui = model.(*Model)
+ if tui.input.Value() != tt.want {
+ t.Errorf("navigateHistory(%d) step %d: got %q, want %q", tt.delta, i, tui.input.Value(), tt.want)
}
}
}
@@ -165,11 +166,11 @@ func TestTUI_HistoryNavigation(t *testing.T) {
func TestTUI_CtrlC_ClearsInput(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
tui.input.SetValue("some input")
model, _ := tui.handleCtrlC()
- result := model.(*TUI)
+ result := model.(*Model)
if result.input.Value() != "" {
t.Error("First Ctrl+C should clear input")
@@ -179,7 +180,7 @@ func TestTUI_CtrlC_ClearsInput(t *testing.T) {
func TestTUI_DoubleCtrlC_Exits(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
tui.lastCtrlC = time.Now()
_, cmd := tui.handleCtrlC()
@@ -192,7 +193,7 @@ func TestTUI_DoubleCtrlC_Exits(t *testing.T) {
func TestTUI_Update_KeyPress(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
tui.input.SetValue("test")
// Simulate Ctrl+C (should clear input)
@@ -200,7 +201,7 @@ func TestTUI_Update_KeyPress(t *testing.T) {
msg := tea.KeyPressMsg(key)
model, _ := tui.Update(msg)
- result := model.(*TUI)
+ result := model.(*Model)
if result.input.Value() != "" {
t.Error("Ctrl+C should clear input")
@@ -210,7 +211,7 @@ func TestTUI_Update_KeyPress(t *testing.T) {
func TestTUI_View_ReturnsAltScreen(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
view := tui.View()
if !view.AltScreen {
@@ -224,7 +225,7 @@ func TestTUI_View_ReturnsAltScreen(t *testing.T) {
func TestTUI_RebuildViewportContent(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
tui.addMessage(Message{Role: roleUser, Text: "hello"})
tui.addMessage(Message{Role: roleAssistant, Text: "world"})
@@ -239,7 +240,7 @@ func TestTUI_RebuildViewportContent(t *testing.T) {
func TestTUI_ViewportScrolling(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
// Add enough messages to exceed viewport height
for i := 0; i < 50; i++ {
@@ -262,12 +263,12 @@ func TestTUI_StreamMessageTypes(t *testing.T) {
t.Run("streamTextMsg", func(t *testing.T) {
eventCh := make(chan streamEvent, 1)
- tui := newTestTUI()
+ tui := newTestModel()
tui.state = StateStreaming
tui.streamEventCh = eventCh
model, _ := tui.Update(streamTextMsg{text: "Hello"})
- result := model.(*TUI)
+ result := model.(*Model)
if result.output.String() != "Hello" {
t.Errorf("Expected 'Hello', got %q", result.output.String())
@@ -275,12 +276,12 @@ func TestTUI_StreamMessageTypes(t *testing.T) {
})
t.Run("streamDoneMsg", func(t *testing.T) {
- tui := newTestTUI()
+ tui := newTestModel()
tui.state = StateStreaming
_, _ = tui.output.WriteString("Hello World")
model, _ := tui.Update(streamDoneMsg{output: chat.Output{Response: "Hello World"}})
- result := model.(*TUI)
+ result := model.(*Model)
if result.state != StateInput {
t.Error("Should return to StateInput after stream done")
@@ -294,11 +295,11 @@ func TestTUI_StreamMessageTypes(t *testing.T) {
})
t.Run("streamErrorMsg", func(t *testing.T) {
- tui := newTestTUI()
+ tui := newTestModel()
tui.state = StateStreaming
model, _ := tui.Update(streamErrorMsg{err: context.Canceled})
- result := model.(*TUI)
+ result := model.(*Model)
if result.state != StateInput {
t.Error("Should return to StateInput after error")
@@ -399,7 +400,7 @@ func TestListenForStream_UnionChannel(t *testing.T) {
func TestTUI_AddMessage_BoundsEnforcement(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
// Add more than maxMessages
for i := 0; i < maxMessages+50; i++ {
@@ -418,7 +419,7 @@ func TestTUI_AddMessage_BoundsEnforcement(t *testing.T) {
func TestTUI_HandleSubmit_AddsToHistory(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
tui.ctx = context.Background()
tui.input.SetValue("test query")
@@ -445,7 +446,7 @@ func TestTUI_HandleSubmit_AddsToHistory(t *testing.T) {
func TestTUI_HandleSubmit_HistoryBounds(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
// Pre-fill history to max
for i := 0; i < maxHistory; i++ {
@@ -474,17 +475,17 @@ func TestMarkdownRenderer_UpdateWidth(t *testing.T) {
t.Run("creates renderer with correct width", func(t *testing.T) {
mr := newMarkdownRenderer(100)
if mr == nil {
- t.Fatal("Failed to create markdown renderer")
+ t.Fatal("newMarkdownRenderer() returned nil")
}
if mr.width != 100 {
- t.Errorf("Expected width 100, got %d", mr.width)
+ t.Errorf("newMarkdownRenderer(100).width = %d, want 100", mr.width)
}
})
t.Run("UpdateWidth changes width", func(t *testing.T) {
mr := newMarkdownRenderer(80)
if mr == nil {
- t.Fatal("Failed to create markdown renderer")
+ t.Fatal("newMarkdownRenderer() returned nil")
}
updated := mr.UpdateWidth(120)
@@ -492,14 +493,14 @@ func TestMarkdownRenderer_UpdateWidth(t *testing.T) {
t.Error("UpdateWidth should return true when width changes")
}
if mr.width != 120 {
- t.Errorf("Expected width 120, got %d", mr.width)
+ t.Errorf("UpdateWidth(120) width = %d, want 120", mr.width)
}
})
t.Run("UpdateWidth no-op for same width", func(t *testing.T) {
mr := newMarkdownRenderer(80)
if mr == nil {
- t.Fatal("Failed to create markdown renderer")
+ t.Fatal("newMarkdownRenderer() returned nil")
}
updated := mr.UpdateWidth(80)
@@ -519,7 +520,7 @@ func TestMarkdownRenderer_UpdateWidth(t *testing.T) {
t.Run("UpdateWidth handles invalid width", func(t *testing.T) {
mr := newMarkdownRenderer(80)
if mr == nil {
- t.Fatal("Failed to create markdown renderer")
+ t.Fatal("newMarkdownRenderer() returned nil")
}
updated := mr.UpdateWidth(0)
@@ -540,7 +541,7 @@ func TestMarkdownRenderer_Render(t *testing.T) {
t.Run("renders markdown", func(t *testing.T) {
mr := newMarkdownRenderer(80)
if mr == nil {
- t.Fatal("Failed to create markdown renderer")
+ t.Fatal("newMarkdownRenderer() returned nil")
}
result := mr.Render("**bold**")
@@ -562,7 +563,7 @@ func TestMarkdownRenderer_Render(t *testing.T) {
func TestTUI_Cleanup(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
// Setup stream state
eventCh := make(chan streamEvent, 1)
@@ -582,7 +583,7 @@ func TestTUI_Cleanup(t *testing.T) {
func TestTUI_CancelStream(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
canceled := false
tui.streamCancel = func() { canceled = true }
@@ -600,14 +601,14 @@ func TestTUI_CancelStream(t *testing.T) {
func TestTUI_CtrlC_CancelsStream(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
tui.state = StateStreaming
canceled := false
tui.streamCancel = func() { canceled = true }
model, _ := tui.handleCtrlC()
- result := model.(*TUI)
+ result := model.(*Model)
if !canceled {
t.Error("Ctrl+C during streaming should cancel")
@@ -623,7 +624,7 @@ func TestTUI_CtrlC_CancelsStream(t *testing.T) {
func TestTUI_RenderStatusBar_StateInput(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
tui.state = StateInput
bar := tui.renderStatusBar()
@@ -635,7 +636,7 @@ func TestTUI_RenderStatusBar_StateInput(t *testing.T) {
func TestTUI_RenderStatusBar_StateStreaming(t *testing.T) {
defer goleak.VerifyNone(t, goleakOptions()...)
- tui := newTestTUI()
+ tui := newTestModel()
tui.state = StateStreaming
bar := tui.renderStatusBar()
diff --git a/internal/tui/update.go b/internal/tui/update.go
new file mode 100644
index 0000000..1f65575
--- /dev/null
+++ b/internal/tui/update.go
@@ -0,0 +1,132 @@
+package tui
+
+import (
+ "context"
+ "errors"
+
+ "charm.land/bubbles/v2/spinner"
+ tea "charm.land/bubbletea/v2"
+)
+
+// Update implements tea.Model.
+//
+//nolint:gocognit,gocyclo // Bubble Tea Update requires type switch on all message types
+func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+ switch msg := msg.(type) {
+ case tea.KeyPressMsg:
+ return m.handleKey(msg)
+
+ case tea.WindowSizeMsg:
+ m.width = msg.Width
+ m.height = msg.Height
+
+ // Calculate viewport height: total - input - separators - help
+ inputHeight := m.input.Height() + promptLines
+ fixedHeight := separatorLines + inputHeight + helpLines
+ vpHeight := max(msg.Height-fixedHeight, minViewport)
+
+ m.viewport.SetWidth(msg.Width)
+ m.viewport.SetHeight(vpHeight)
+ m.input.SetWidth(msg.Width - 4) // Room for "> " prompt
+ m.help.SetWidth(msg.Width)
+ m.markdown.UpdateWidth(msg.Width)
+
+ // Rebuild viewport content with new dimensions
+ m.rebuildViewportContent()
+ return m, nil
+
+ case tea.MouseWheelMsg:
+ // Forward mouse wheel to viewport for scrolling
+ var cmd tea.Cmd
+ m.viewport, cmd = m.viewport.Update(msg)
+ return m, cmd
+
+ case spinner.TickMsg:
+ var cmd tea.Cmd
+ m.spinner, cmd = m.spinner.Update(msg)
+ // Rebuild viewport to update spinner animation during thinking or tool execution
+ if m.state == StateThinking || (m.state == StateStreaming && m.toolStatus != "") {
+ m.rebuildViewportContent()
+ }
+ return m, cmd
+
+ case streamStartedMsg:
+ m.streamCancel = msg.cancel
+ m.streamEventCh = msg.eventCh
+ m.state = StateStreaming
+ m.rebuildViewportContent()
+ m.viewport.GotoBottom()
+ return m, listenForStream(msg.eventCh)
+
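+ // Tool status changed: update the indicator and keep listening for events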
+ case streamToolMsg:
+ m.toolStatus = msg.status
+ m.rebuildViewportContent()
+ m.viewport.GotoBottom()
+ return m, listenForStream(m.streamEventCh)
+
+ case streamTextMsg:
+ m.toolStatus = "" // Clear tool status when text arrives
+ m.output.WriteString(msg.text)
+ m.rebuildViewportContent()
+ m.viewport.GotoBottom()
+ return m, listenForStream(m.streamEventCh)
+
+ case streamDoneMsg:
+ m.state = StateInput
+ m.toolStatus = ""
+
+ // Cancel context to release timer resources
+ if m.streamCancel != nil {
+ m.streamCancel()
+ m.streamCancel = nil
+ }
+ m.streamEventCh = nil
+
+ // Prefer msg.output.Response (complete response from Genkit) over accumulated chunks.
+ // This handles models that don't stream or send final content only in Output.
+ finalText := msg.output.Response
+ if finalText == "" {
+ // Fall back to accumulated chunks if Response is empty
+ finalText = m.output.String()
+ }
+
+ m.addMessage(Message{
+ Role: roleAssistant,
+ Text: finalText,
+ })
+ m.output.Reset()
+ m.rebuildViewportContent()
+ m.viewport.GotoBottom()
+ // Re-focus textarea after stream completes
+ return m, m.input.Focus()
+
+ case streamErrorMsg:
+ m.state = StateInput
+ m.toolStatus = ""
+
+ // Cancel context to release timer resources
+ if m.streamCancel != nil {
+ m.streamCancel()
+ m.streamCancel = nil
+ }
+ m.streamEventCh = nil
+
+ switch {
+ case errors.Is(msg.err, context.Canceled):
+ m.addMessage(Message{Role: roleSystem, Text: "(Canceled)"})
+ case errors.Is(msg.err, context.DeadlineExceeded):
+ m.addMessage(Message{Role: roleError, Text: "Query timeout (>5 min). Try a simpler query or break it into steps."})
+ default:
+ m.addMessage(Message{Role: roleError, Text: msg.err.Error()})
+ }
+ m.output.Reset()
+ m.rebuildViewportContent()
+ m.viewport.GotoBottom()
+ // Re-focus textarea after error
+ return m, m.input.Focus()
+ }
+
+ var cmd tea.Cmd
+ m.input, cmd = m.input.Update(msg)
+ return m, cmd
+}
diff --git a/internal/tui/view.go b/internal/tui/view.go
new file mode 100644
index 0000000..0567309
--- /dev/null
+++ b/internal/tui/view.go
@@ -0,0 +1,118 @@
+package tui
+
+import (
+ "strings"
+
+ "charm.land/bubbles/v2/key"
+ tea "charm.land/bubbletea/v2"
+)
+
+// View implements tea.Model.
+// Uses AltScreen with viewport for scrollable message history.
+func (m *Model) View() tea.View {
+ m.viewBuf.Reset()
+
+ // Viewport (scrollable message area)
+ _, _ = m.viewBuf.WriteString(m.viewport.View())
+ _, _ = m.viewBuf.WriteString("\n")
+
+ // Separator line above input
+ _, _ = m.viewBuf.WriteString(m.renderSeparator())
+ _, _ = m.viewBuf.WriteString("\n")
+
+ // Input prompt - always show and always accept input
+ // Users can type while LLM is thinking/streaming (better UX)
+ _, _ = m.viewBuf.WriteString(m.styles.Prompt.Render("> "))
+ _, _ = m.viewBuf.WriteString(m.input.View())
+ _, _ = m.viewBuf.WriteString("\n")
+
+ // Separator line below input
+ _, _ = m.viewBuf.WriteString(m.renderSeparator())
+ _, _ = m.viewBuf.WriteString("\n")
+
+ // Help bar (keyboard shortcuts)
+ _, _ = m.viewBuf.WriteString(m.renderStatusBar())
+
+ v := tea.NewView(m.viewBuf.String())
+ v.AltScreen = true
+ return v
+}
+
+// rebuildViewportContent reconstructs the viewport content from messages and state.
+// Called when messages, streaming output, or state changes.
+func (m *Model) rebuildViewportContent() {
+ var b strings.Builder
+
+ // Banner (ASCII art) and tips
+ _, _ = b.WriteString(m.styles.RenderBanner())
+ _, _ = b.WriteString("\n")
+ _, _ = b.WriteString(m.styles.RenderWelcomeTips())
+ _, _ = b.WriteString("\n")
+
+ // Messages (already bounded by addMessage)
+ for _, msg := range m.messages {
+ switch msg.Role {
+ case roleUser:
+ _, _ = b.WriteString(m.styles.User.Render("You> "))
+ _, _ = b.WriteString(msg.Text)
+ case roleAssistant:
+ _, _ = b.WriteString(m.styles.Assistant.Render("Koopa> "))
+ _, _ = b.WriteString(m.markdown.Render(msg.Text))
+ case roleSystem:
+ _, _ = b.WriteString(m.styles.System.Render(msg.Text))
+ case roleError:
+ _, _ = b.WriteString(m.styles.Error.Render("Error: " + msg.Text))
+ }
+ _, _ = b.WriteString("\n\n")
+ }
+
+ // Current streaming output
+ if m.state == StateStreaming && m.output.Len() > 0 {
+ _, _ = b.WriteString(m.styles.Assistant.Render("Koopa> "))
+ _, _ = b.WriteString(m.output.String())
+ _, _ = b.WriteString("\n\n")
+ }
+
+ // Tool status indicator (shown during streaming when a tool is executing)
+ if m.state == StateStreaming && m.toolStatus != "" {
+ _, _ = b.WriteString(m.spinner.View())
+ _, _ = b.WriteString(" ")
+ _, _ = b.WriteString(m.styles.System.Render(m.toolStatus))
+ _, _ = b.WriteString("\n\n")
+ }
+
+ // Thinking indicator
+ if m.state == StateThinking {
+ _, _ = b.WriteString(m.spinner.View())
+ _, _ = b.WriteString(" Thinking...\n\n")
+ }
+
+ m.viewport.SetContent(b.String())
+}
+
+// renderSeparator returns a horizontal line separator.
+func (m *Model) renderSeparator() string {
+ width := m.width
+ if width <= 0 {
+ width = 80 // Default width
+ }
+ return m.styles.Separator.Render(strings.Repeat("─", width))
+}
+
+// renderStatusBar returns state-appropriate keyboard shortcut help.
+func (m *Model) renderStatusBar() string {
+ var bindings []key.Binding
+ switch m.state {
+ case StateInput:
+ bindings = []key.Binding{
+ m.keys.Submit, m.keys.NewLine, m.keys.History,
+ m.keys.Cancel, m.keys.Quit, m.keys.ScrollUp,
+ }
+ case StateThinking, StateStreaming:
+ bindings = []key.Binding{
+ m.keys.EscCancel, m.keys.Cancel,
+ m.keys.ScrollUp, m.keys.ScrollDown,
+ }
+ }
+ return m.help.ShortHelpView(bindings)
+}
diff --git a/prompts/koopa.prompt b/prompts/koopa.prompt
index 8dfcdfc..08e565e 100644
--- a/prompts/koopa.prompt
+++ b/prompts/koopa.prompt
@@ -1,4 +1,5 @@
---
+# Model below is a dotprompt default; overridden at runtime by chat.Config.ModelName.
model: googleai/gemini-2.5-flash
config:
temperature: 0.7