Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ linters:
- contextcheck

# contextcheck: flow.go uses InvocationContext which wraps context
- path: internal/agent/chat/flow\.go
- path: internal/chat/flow\.go
linters:
- contextcheck

Expand Down
74 changes: 74 additions & 0 deletions cmd/addr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package cmd

import (
"flag"
"fmt"
"net"
"os"
"strconv"
"strings"
)

// parseServeAddr parses and validates the server address from command line arguments.
// Uses flag.FlagSet for standard Go flag parsing, supporting:
// - koopa serve :8080 (positional)
// - koopa serve --addr :8080 (flag)
// - koopa serve -addr :8080 (single dash)
func parseServeAddr() (string, error) {
const defaultAddr = "127.0.0.1:3400"

serveFlags := flag.NewFlagSet("serve", flag.ContinueOnError)
serveFlags.SetOutput(os.Stderr)

addr := serveFlags.String("addr", defaultAddr, "Server address (host:port)")

args := []string{}
if len(os.Args) > 2 {
args = os.Args[2:]
}

// Check for positional argument first (koopa serve :8080)
if len(args) > 0 && !strings.HasPrefix(args[0], "-") {
*addr = args[0]
args = args[1:]
}
Comment on lines +31 to +34

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential for unexpected flag precedence

Assigning the positional argument directly to *addr before parsing flags can lead to confusion if both a positional argument and an --addr flag are provided. The flag parsing will override the positional argument, which may not be the user's intent. Consider explicitly detecting and rejecting cases where both are provided, or clearly documenting the precedence.

Recommended solution:

if len(args) > 0 && !strings.HasPrefix(args[0], "-") {
    if containsAddrFlag(args[1:]) {
        return "", errors.New("cannot specify address both positionally and with --addr flag")
    }
    *addr = args[0]
    args = args[1:]
}

Implement a helper to check for the presence of the addr flag in the remaining args.


if err := serveFlags.Parse(args); err != nil {
return "", fmt.Errorf("parsing serve flags: %w", err)
}

if err := validateAddr(*addr); err != nil {
return "", fmt.Errorf("invalid address %q: %w", *addr, err)
}

return *addr, nil
}

// validateAddr validates the server address format.
func validateAddr(addr string) error {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return fmt.Errorf("must be in host:port format: %w", err)
}

if host != "" && host != "localhost" {
if ip := net.ParseIP(host); ip == nil {
if strings.ContainsAny(host, " \t\n") {
return fmt.Errorf("invalid host: %s", host)
}
}
Comment on lines +54 to +59

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Insufficient host validation

The current host validation only checks for whitespace and whether the host is 'localhost' or a valid IP. It does not validate the host against RFC-compliant hostname rules, nor does it prevent reserved or malformed hostnames. This could allow invalid or potentially unsafe hostnames to be accepted.

Recommended solution:
Consider using a stricter hostname validation, such as using a regular expression or leveraging net package utilities to ensure the host is a valid domain name or IP address.

}

if port == "" {
return fmt.Errorf("port is required")
}
portNum, err := strconv.Atoi(port)
if err != nil {
return fmt.Errorf("port must be numeric: %w", err)
}
if portNum < 0 || portNum > 65535 {
return fmt.Errorf("port must be 0-65535 (0 = auto-assign), got %d", portNum)
}

return nil
}
68 changes: 68 additions & 0 deletions cmd/addr_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package cmd

import "testing"

func TestValidateAddr(t *testing.T) {
t.Parallel()

tests := []struct {
name string
addr string
wantErr bool
}{
// Valid addresses
{name: "port only", addr: ":8080", wantErr: false},
{name: "localhost", addr: "localhost:3400", wantErr: false},
{name: "loopback", addr: "127.0.0.1:3400", wantErr: false},
{name: "all interfaces", addr: "0.0.0.0:80", wantErr: false},
{name: "ipv6 loopback", addr: "[::1]:8080", wantErr: false},
{name: "port zero", addr: ":0", wantErr: false},
{name: "port max", addr: ":65535", wantErr: false},
{name: "hostname", addr: "myhost:9090", wantErr: false},

// Invalid: bad format
{name: "no port", addr: "localhost", wantErr: true},
{name: "port alone", addr: "8080", wantErr: true},
{name: "empty string", addr: "", wantErr: true},

// Invalid: bad port
{name: "port non-numeric", addr: ":abc", wantErr: true},
{name: "port negative", addr: ":-1", wantErr: true},
{name: "port too high", addr: ":65536", wantErr: true},
{name: "port empty after colon", addr: "localhost:", wantErr: true},

// Invalid: bad host
{name: "host with space", addr: "my host:8080", wantErr: true},
{name: "host with tab", addr: "my\thost:8080", wantErr: true},
{name: "host with newline", addr: "my\nhost:8080", wantErr: true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := validateAddr(tt.addr)
if tt.wantErr && err == nil {
t.Errorf("validateAddr(%q) = nil, want error", tt.addr)
}
if !tt.wantErr && err != nil {
t.Errorf("validateAddr(%q) = %v, want nil", tt.addr, err)
}
})
Comment on lines +41 to +50

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The subtest closure in the table-driven test captures the loop variable tt by reference, which can cause data races and unpredictable test results when running subtests in parallel. To avoid this, assign tt to a new variable within the loop before the closure:

for _, tt := range tests {
    tt := tt // capture by value
    t.Run(tt.name, func(t *testing.T) {
        t.Parallel()
        // ...
    })
}

This ensures each subtest gets its own copy of tt, preventing concurrency issues.

}
}

func FuzzValidateAddr(f *testing.F) {
f.Add(":8080")
f.Add("localhost:3400")
f.Add("127.0.0.1:80")
f.Add("")
f.Add("abc")
f.Add(":0")
f.Add(":99999")
f.Add("[::1]:8080")
f.Add("host with space:80")

f.Fuzz(func(t *testing.T, addr string) {
_ = validateAddr(addr) // must not panic
})
}
56 changes: 17 additions & 39 deletions cmd/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package cmd

import (
"context"
"errors"
"fmt"
"log/slog"
"os/signal"
Expand All @@ -11,72 +10,51 @@ import (
tea "charm.land/bubbletea/v2"

"github.com/koopa0/koopa/internal/app"
"github.com/koopa0/koopa/internal/chat"
"github.com/koopa0/koopa/internal/config"
"github.com/koopa0/koopa/internal/session"
"github.com/koopa0/koopa/internal/tui"
)

// runCLI initializes and starts the interactive CLI with Bubble Tea TUI.
func runCLI() error {
cfg, err := config.Load()
if err != nil {
return err
return fmt.Errorf("loading config: %w", err)
}

ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of signal.NotifyContext to create a context that is canceled on SIGINT or SIGTERM is appropriate for graceful shutdown. However, passing this context directly to the Bubble Tea program (tea.NewProgram(model, tea.WithContext(ctx))) may result in the TUI being abruptly terminated when the signal is received, potentially preventing proper cleanup or user feedback. Consider handling the signal separately and coordinating a graceful shutdown sequence for the TUI.

defer cancel()

runtime, err := app.NewChatRuntime(ctx, cfg)
a, err := app.Setup(ctx, cfg)
if err != nil {
return fmt.Errorf("failed to initialize runtime: %w", err)
return fmt.Errorf("initializing application: %w", err)
}
defer func() {
if closeErr := runtime.Close(); closeErr != nil {
slog.Warn("runtime close error", "error", closeErr)
if closeErr := a.Close(); closeErr != nil {
slog.Warn("shutdown error", "error", closeErr)
}
}()

sessionID, err := getOrCreateSessionID(ctx, runtime.App.SessionStore, cfg)
agent, err := a.CreateAgent()
if err != nil {
return fmt.Errorf("failed to get session: %w", err)
return fmt.Errorf("creating agent: %w", err)
}

model, err := tui.New(ctx, runtime.Flow, sessionID)
if err != nil {
return fmt.Errorf("failed to create TUI: %w", err)
}
program := tea.NewProgram(model, tea.WithContext(ctx))

if _, err = program.Run(); err != nil {
return fmt.Errorf("TUI exited: %w", err)
}
return nil
}
flow := chat.NewFlow(a.Genkit, agent)

// getOrCreateSessionID returns a valid session ID, creating a new session if needed.
func getOrCreateSessionID(ctx context.Context, store *session.Store, cfg *config.Config) (string, error) {
currentID, err := session.LoadCurrentSessionID()
sessionID, err := a.SessionStore.ResolveCurrentSession(ctx)
if err != nil {
return "", fmt.Errorf("failed to load session: %w", err)
return fmt.Errorf("resolving session: %w", err)
}

if currentID != nil {
if _, err = store.Session(ctx, *currentID); err == nil {
return currentID.String(), nil
}
if !errors.Is(err, session.ErrSessionNotFound) {
return "", fmt.Errorf("failed to validate session: %w", err)
}
}

newSess, err := store.CreateSession(ctx, "New Session", cfg.ModelName, "You are a helpful assistant.")
model, err := tui.New(ctx, flow, sessionID)
if err != nil {
return "", fmt.Errorf("failed to create session: %w", err)
return fmt.Errorf("creating TUI: %w", err)
}
program := tea.NewProgram(model, tea.WithContext(ctx))

if err := session.SaveCurrentSessionID(newSess.ID); err != nil {
slog.Warn("failed to save session state", "error", err)
if _, err = program.Run(); err != nil {
return fmt.Errorf("running TUI: %w", err)
}

return newSess.ID.String(), nil
return nil
}
76 changes: 76 additions & 0 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Package cmd provides CLI commands for Koopa.
//
// Commands:
// - cli: Interactive terminal chat with Bubble Tea TUI
// - serve: HTTP API server with SSE streaming
// - mcp: Model Context Protocol server for IDE integration
//
// Signal handling and graceful shutdown are implemented
// for all commands via context cancellation.
package cmd

import (
"fmt"
"log/slog"
"os"
)

// Execute is the main entry point for the Koopa CLI application.
func Execute() error {
// Initialize logger once at entry point
level := slog.LevelInfo
if os.Getenv("DEBUG") != "" {
level = slog.LevelDebug
}
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})))

if len(os.Args) < 2 {
runHelp()
return nil
}

switch os.Args[1] {
case "cli":
return runCLI()
case "serve":
return runServe()
case "mcp":
return runMCP()
case "version", "--version", "-v":
runVersion()
return nil
case "help", "--help", "-h":
runHelp()
return nil
default:
return fmt.Errorf("unknown command: %s", os.Args[1])
}
Comment on lines +32 to +47

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maintainability Concern: Monolithic Command Dispatch

The current implementation uses a large switch statement on os.Args[1] to dispatch commands. As the number of commands increases, this approach will become harder to maintain and extend. Consider refactoring to use a map of command names to handler functions, which will improve modularity and make it easier to add, remove, or modify commands.

Recommended refactor:

var commands = map[string]func() error{
    "cli":    runCLI,
    "serve":  runServe,
    "mcp":    runMCP,
    // ...
}

if cmd, ok := commands[os.Args[1]]; ok {
    return cmd()
}
return fmt.Errorf("unknown command: %s", os.Args[1])

This approach centralizes command registration and reduces the risk of errors when modifying the command set.

}

// runHelp displays the help message.
func runHelp() {
fmt.Println("Koopa - Your terminal AI personal assistant")
fmt.Println()
fmt.Println("Usage:")
fmt.Println(" koopa cli Start interactive chat mode")
fmt.Println(" koopa serve [addr] Start HTTP API server (default: 127.0.0.1:3400)")
fmt.Println(" koopa mcp Start MCP server (for Claude Desktop/Cursor)")
fmt.Println(" koopa --version Show version information")
fmt.Println(" koopa --help Show this help")
fmt.Println()
fmt.Println("CLI Commands (in interactive mode):")
fmt.Println(" /help Show available commands")
fmt.Println(" /version Show version")
fmt.Println(" /clear Clear conversation history")
fmt.Println(" /exit, /quit Exit Koopa")
fmt.Println()
fmt.Println("Shortcuts:")
fmt.Println(" Ctrl+D Exit Koopa")
fmt.Println(" Ctrl+C Cancel current input")
fmt.Println()
fmt.Println("Environment Variables:")
fmt.Println(" GEMINI_API_KEY Required: Gemini API key")
fmt.Println(" DEBUG Optional: Enable debug logging")
fmt.Println()
fmt.Println("Learn more: https://github.com/koopa0/koopa")
}
Loading
Loading