diff --git a/db/migrations/000004_trigram_search.up.sql b/db/migrations/000004_trigram_search.up.sql index 1f98026..794881f 100644 --- a/db/migrations/000004_trigram_search.up.sql +++ b/db/migrations/000004_trigram_search.up.sql @@ -1,17 +1,17 @@ -- Enable pg_trgm extension for trigram-based text search (CJK support). CREATE EXTENSION IF NOT EXISTS pg_trgm; --- CONCURRENTLY avoids locking the table during index creation. --- NOTE: CONCURRENTLY cannot run inside a transaction block. --- golang-migrate runs each file in a transaction by default; --- the operator must run this migration manually with: --- psql -f 000004_trigram_search.up.sql --- or disable transactions in the migration tool. - -- GIN trigram index on messages.text_content for ILIKE fallback search. -CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_messages_text_content_trgm +-- NOTE: For large production tables with existing data, create these indexes +-- manually with CONCURRENTLY before running this migration (the CREATE INDEX +-- statements below then become no-ops due to IF NOT EXISTS): +-- CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_messages_text_content_trgm +-- ON messages USING gin (text_content gin_trgm_ops); +-- CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_memories_content_trgm +-- ON memories USING gin (content gin_trgm_ops); +CREATE INDEX IF NOT EXISTS idx_messages_text_content_trgm ON messages USING gin (text_content gin_trgm_ops); -- GIN trigram index on memories.content for similarity() scoring. 
-CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_memories_content_trgm +CREATE INDEX IF NOT EXISTS idx_memories_content_trgm ON memories USING gin (content gin_trgm_ops); diff --git a/db/migrations/000005_owner_id_check.down.sql b/db/migrations/000005_owner_id_check.down.sql new file mode 100644 index 0000000..ef86117 --- /dev/null +++ b/db/migrations/000005_owner_id_check.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE sessions ALTER COLUMN owner_id SET DEFAULT ''; +ALTER TABLE sessions DROP CONSTRAINT IF EXISTS sessions_owner_id_not_empty; diff --git a/db/migrations/000005_owner_id_check.up.sql b/db/migrations/000005_owner_id_check.up.sql new file mode 100644 index 0000000..7dbad08 --- /dev/null +++ b/db/migrations/000005_owner_id_check.up.sql @@ -0,0 +1,7 @@ +-- Backfill empty owner_id before adding constraint. +-- Sessions created before ownership was enforced may have owner_id = ''. +UPDATE sessions SET owner_id = 'anonymous' WHERE owner_id = ''; + +-- Prevent empty owner_id in sessions. A CHECK passes when its expression is NULL, so this rejects only ''; NULL must be rejected by a NOT NULL constraint on the column (verify it exists). 
+ALTER TABLE sessions ADD CONSTRAINT sessions_owner_id_not_empty CHECK (owner_id != ''); +ALTER TABLE sessions ALTER COLUMN owner_id DROP DEFAULT; diff --git a/internal/api/chat.go b/internal/api/chat.go index c3aafb5..d707ae4 100644 --- a/internal/api/chat.go +++ b/internal/api/chat.go @@ -14,6 +14,7 @@ import ( "time" "github.com/google/uuid" + "github.com/koopa0/koopa/internal/chat" "github.com/koopa0/koopa/internal/session" "github.com/koopa0/koopa/internal/tools" diff --git a/internal/api/doc.go b/internal/api/doc.go index b6ae844..6f672f6 100644 --- a/internal/api/doc.go +++ b/internal/api/doc.go @@ -4,7 +4,7 @@ // // The API server uses Go 1.22+ routing with a layered middleware stack: // -// Recovery → Logging → RateLimit → CORS → Session → CSRF → Routes +// Recovery → RequestID → Logging → CORS → RateLimit → User → Session → CSRF → Routes // // Health probes (/health, /ready) bypass the middleware stack via a // top-level mux, ensuring they remain fast and unauthenticated. @@ -24,11 +24,24 @@ // - GET /api/v1/sessions/{id} — get session by ID // - GET /api/v1/sessions/{id}/messages — get session messages // - DELETE /api/v1/sessions/{id} — delete session +// - GET /api/v1/sessions/{id}/export — export session // // Chat (ownership-enforced): // - POST /api/v1/chat — initiate chat, returns stream URL // - GET /api/v1/chat/stream — SSE endpoint for streaming responses // +// Search: +// - GET /api/v1/search — full-text search across messages +// +// Stats: +// - GET /api/v1/stats — usage statistics (sessions, messages, memories) +// +// Memory (ownership-enforced): +// - GET /api/v1/memories — list memories +// - POST /api/v1/memories — create memory +// - DELETE /api/v1/memories/{id} — delete memory +// - GET /api/v1/memories/search — search memories +// // # CSRF Token Model // // Two token types prevent cross-site request forgery: @@ -39,7 +52,7 @@ // - Session-bound tokens ("timestamp:signature"): bound to a specific // session via HMAC-SHA256, verified 
with constant-time comparison. // -// Both expire after 24 hours with 5 minutes of clock skew tolerance. +// Both expire after 1 hour with 5 minutes of clock skew tolerance. // // # Session Ownership // diff --git a/internal/api/search.go b/internal/api/search.go index 7bca919..ec3f532 100644 --- a/internal/api/search.go +++ b/internal/api/search.go @@ -45,7 +45,7 @@ func (h *searchHandler) searchMessages(w http.ResponseWriter, r *http.Request) { results, total, err := h.store.SearchMessages(r.Context(), userID, query, limit, offset) if err != nil { - h.logger.Error("searching messages", "error", err, "user_id", userID, "query", query) + h.logger.Error("searching messages", "error", err, "user_id", userID, "query_len", len(query)) WriteError(w, http.StatusInternalServerError, "search_failed", "failed to search messages", h.logger) return } diff --git a/internal/api/session.go b/internal/api/session.go index 723fe44..8962cac 100644 --- a/internal/api/session.go +++ b/internal/api/session.go @@ -16,23 +16,17 @@ import ( "time" "github.com/google/uuid" + "github.com/koopa0/koopa/internal/session" ) -// Sentinel errors for session/CSRF operations. var ( - // ErrSessionCookieNotFound is returned when the session cookie is absent from the request. - ErrSessionCookieNotFound = errors.New("session cookie not found") - // ErrSessionInvalid is returned when the session cookie value is not a valid UUID. - ErrSessionInvalid = errors.New("session ID invalid") - // ErrCSRFRequired is returned when a state-changing request has no CSRF token. - ErrCSRFRequired = errors.New("csrf token required") - // ErrCSRFInvalid is returned when the CSRF token signature does not match. - ErrCSRFInvalid = errors.New("csrf token invalid") - // ErrCSRFExpired is returned when the CSRF token timestamp exceeds csrfTokenTTL. - ErrCSRFExpired = errors.New("csrf token expired") - // ErrCSRFMalformed is returned when the CSRF token format cannot be parsed. 
- ErrCSRFMalformed = errors.New("csrf token malformed") + errSessionCookieNotFound = errors.New("session cookie not found") + errSessionInvalid = errors.New("session ID invalid") + errCSRFRequired = errors.New("csrf token required") + errCSRFInvalid = errors.New("csrf token invalid") + errCSRFExpired = errors.New("csrf token expired") + errCSRFMalformed = errors.New("csrf token malformed") ) // Pre-session CSRF token prefix to distinguish from user-bound tokens. @@ -61,12 +55,12 @@ type sessionManager struct { func (*sessionManager) SessionID(r *http.Request) (uuid.UUID, error) { cookie, err := r.Cookie(sessionCookieName) if err != nil { - return uuid.Nil, ErrSessionCookieNotFound + return uuid.Nil, errSessionCookieNotFound } sessionID, err := uuid.Parse(cookie.Value) if err != nil { - return uuid.Nil, ErrSessionInvalid + return uuid.Nil, errSessionInvalid } return sessionID, nil @@ -109,17 +103,17 @@ func (sm *sessionManager) NewCSRFToken(userID string) string { // CheckCSRF verifies a user-bound CSRF token. 
func (sm *sessionManager) CheckCSRF(userID, token string) error { if token == "" { - return ErrCSRFRequired + return errCSRFRequired } parts := strings.SplitN(token, ":", 2) if len(parts) != 2 { - return ErrCSRFMalformed + return errCSRFMalformed } timestamp, err := strconv.ParseInt(parts[0], 10, 64) if err != nil { - return ErrCSRFMalformed + return errCSRFMalformed } // SECURITY: Compute and verify HMAC BEFORE timestamp checks to prevent @@ -133,19 +127,19 @@ func (sm *sessionManager) CheckCSRF(userID, token string) error { actualSig, err := base64.URLEncoding.DecodeString(parts[1]) if err != nil { - return ErrCSRFMalformed + return errCSRFMalformed } if subtle.ConstantTimeCompare(actualSig, expectedSig) != 1 { - return ErrCSRFInvalid + return errCSRFInvalid } age := time.Since(time.Unix(timestamp, 0)) if age > csrfTokenTTL { - return ErrCSRFExpired + return errCSRFExpired } if age < -csrfClockSkew { - return ErrCSRFInvalid + return errCSRFInvalid } return nil @@ -168,23 +162,23 @@ func (sm *sessionManager) NewPreSessionCSRFToken() string { // CheckPreSessionCSRF verifies a pre-session CSRF token. 
func (sm *sessionManager) CheckPreSessionCSRF(token string) error { if token == "" { - return ErrCSRFRequired + return errCSRFRequired } if !strings.HasPrefix(token, preSessionPrefix) { - return ErrCSRFMalformed + return errCSRFMalformed } tokenBody := strings.TrimPrefix(token, preSessionPrefix) parts := strings.SplitN(tokenBody, ":", 3) if len(parts) != 3 { - return ErrCSRFMalformed + return errCSRFMalformed } nonce := parts[0] timestamp, err := strconv.ParseInt(parts[1], 10, 64) if err != nil { - return ErrCSRFMalformed + return errCSRFMalformed } // SECURITY: Compute and verify HMAC BEFORE timestamp checks to prevent @@ -196,19 +190,19 @@ func (sm *sessionManager) CheckPreSessionCSRF(token string) error { actualSig, err := base64.URLEncoding.DecodeString(parts[2]) if err != nil { - return ErrCSRFMalformed + return errCSRFMalformed } if subtle.ConstantTimeCompare(actualSig, expectedSig) != 1 { - return ErrCSRFInvalid + return errCSRFInvalid } age := time.Since(time.Unix(timestamp, 0)) if age > csrfTokenTTL { - return ErrCSRFExpired + return errCSRFExpired } if age < -csrfClockSkew { - return ErrCSRFInvalid + return errCSRFInvalid } return nil diff --git a/internal/chat/chat.go b/internal/chat/chat.go index a7fb86b..8b7e25c 100644 --- a/internal/chat/chat.go +++ b/internal/chat/chat.go @@ -103,7 +103,7 @@ func (cfg Config) validate() error { return errors.New("at least one tool is required") } if cfg.MemoryStore != nil && cfg.WG == nil { - return errors.New("WG is required when MemoryStore is set") + return errors.New("wg is required when memory store is set") } return nil } diff --git a/internal/config/config.go b/internal/config/config.go index 0a7db4a..c713a35 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -24,6 +24,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "log/slog" "os" "path/filepath" @@ -51,9 +52,6 @@ var ( // ErrInvalidEmbedderModel indicates the embedder model is invalid. 
ErrInvalidEmbedderModel = errors.New("invalid embedder model") - // ErrInvalidEmbedderDimension indicates the embedder produces incompatible vector dimensions. - ErrInvalidEmbedderDimension = errors.New("incompatible embedder dimension") - // ErrInvalidPostgresHost indicates the PostgreSQL host is invalid. ErrInvalidPostgresHost = errors.New("invalid PostgreSQL host") @@ -158,6 +156,11 @@ type Config struct { CORSOrigins []string `mapstructure:"cors_origins" json:"cors_origins"` TrustProxy bool `mapstructure:"trust_proxy" json:"trust_proxy"` // Trust X-Real-IP/X-Forwarded-For headers (set true behind reverse proxy) DevMode bool `mapstructure:"dev_mode" json:"dev_mode"` // Dev mode: disables Secure flag on cookies (decoupled from DB SSL) + + // API keys read from environment at Load() time. + // Unexported to avoid accidental serialization; validated in validateProviderAPIKey(). + geminiAPIKey string + openaiAPIKey string } // Load loads configuration. @@ -213,6 +216,12 @@ func Load() (*Config, error) { return nil, fmt.Errorf("parsing DATABASE_URL: %w", err) } + // Capture API keys from environment at load time. + // These are read by Genkit plugins directly, but we also need them + // for fail-fast validation (no os.Getenv in business logic). + cfg.geminiAPIKey = os.Getenv("GEMINI_API_KEY") + cfg.openaiAPIKey = os.Getenv("OPENAI_API_KEY") + // CRITICAL: Validate immediately (fail-fast) if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("validating configuration: %w", err) @@ -280,11 +289,11 @@ func setDefaults(v *viper.Viper) { // 2. DD_API_KEY - Datadog API key (optional, for observability) // 3. 
HMAC_SECRET - HMAC secret for CSRF protection (serve mode only) func bindEnvVariables(v *viper.Viper) { - // Helper to panic on unexpected bind errors (hardcoded strings can't fail) - // If this panics, it's a BUG in our code, not a runtime error + // Helper that logs and exits the process on unexpected bind errors (hardcoded strings can't fail) + // If this fails, it's a BUG in our code, not a runtime error mustBind := func(key, envVar string) { if err := v.BindEnv(key, envVar); err != nil { - panic(fmt.Sprintf("BUG: failed to bind %q to %q: %v", key, envVar, err)) + log.Fatalf("BUG: failed to bind %q to %q: %v", key, envVar, err) } } diff --git a/internal/config/validation.go b/internal/config/validation.go index 667a19c..4b6376f 100644 --- a/internal/config/validation.go +++ b/internal/config/validation.go @@ -3,8 +3,9 @@ package config import ( "fmt" "log/slog" - "os" "slices" + + "github.com/koopa0/koopa/internal/rag" ) // supportedProviders lists all valid AI provider values. @@ -127,7 +128,8 @@ var knownEmbedderDimensions = map[string]map[string]int{ } // requiredVectorDimension must match the pgvector schema: embedding vector(768). -const requiredVectorDimension = 768 +// Canonical source: rag.VectorDimension. +var requiredVectorDimension = int(rag.VectorDimension) // validateEmbedder checks that the configured embedder model produces vectors // compatible with the database schema. For known models whose native dimension @@ -200,16 +202,17 @@ func (c *Config) validateRetention() error { } // validateProviderAPIKey checks that the required API key is set for the configured provider. +// API keys are captured from environment in Load() and stored as unexported fields. 
func (c *Config) validateProviderAPIKey() error { switch c.resolvedProvider() { case ProviderGemini: - if os.Getenv("GEMINI_API_KEY") == "" { + if c.geminiAPIKey == "" { return fmt.Errorf("%w: GEMINI_API_KEY environment variable is required for provider %q\n"+ "Get your API key at: https://ai.google.dev/gemini-api/docs/api-key", ErrMissingAPIKey, c.resolvedProvider()) } case ProviderOpenAI: - if os.Getenv("OPENAI_API_KEY") == "" { + if c.openaiAPIKey == "" { return fmt.Errorf("%w: OPENAI_API_KEY environment variable is required for provider %q\n"+ "Get your API key at: https://platform.openai.com/api-keys", ErrMissingAPIKey, c.resolvedProvider()) diff --git a/internal/config/validation_test.go b/internal/config/validation_test.go index f45f7e9..75e6f87 100644 --- a/internal/config/validation_test.go +++ b/internal/config/validation_test.go @@ -2,7 +2,6 @@ package config import ( "errors" - "os" "strings" "testing" ) @@ -27,30 +26,12 @@ func validBaseConfig(provider string) *Config { cfg.OllamaHost = "http://localhost:11434" case "openai": cfg.ModelName = "gpt-4o" - } - return cfg -} - -// setEnvForProvider sets the required API key for the given provider. -// Returns a cleanup function. -func setEnvForProvider(t *testing.T, provider string) func() { - t.Helper() - switch provider { - case "gemini", "": - if err := os.Setenv("GEMINI_API_KEY", "test-api-key"); err != nil { - t.Fatalf("setting GEMINI_API_KEY: %v", err) - } - return func() { os.Unsetenv("GEMINI_API_KEY") } - case "openai": - if err := os.Setenv("OPENAI_API_KEY", "test-openai-key"); err != nil { - t.Fatalf("setting OPENAI_API_KEY: %v", err) - } - return func() { os.Unsetenv("OPENAI_API_KEY") } - case "ollama": - return func() {} // no key needed + cfg.openaiAPIKey = "test-openai-key" default: - return func() {} + // gemini is the default provider (including when provider is "") + cfg.geminiAPIKey = "test-gemini-key" } + return cfg } // TestValidateSuccess tests successful validation for each provider. 
@@ -63,9 +44,6 @@ func TestValidateSuccess(t *testing.T) { name = "default" } t.Run(name, func(t *testing.T) { - cleanup := setEnvForProvider(t, provider) - defer cleanup() - cfg := validBaseConfig(provider) if err := cfg.Validate(); err != nil { t.Errorf("Validate() unexpected error with valid config (provider %q): %v", provider, err) @@ -93,21 +71,20 @@ func TestValidateProviderAPIKey(t *testing.T) { tests := []struct { name string provider string - envKey string wantErr bool }{ - {name: "gemini missing key", provider: "gemini", envKey: "GEMINI_API_KEY", wantErr: true}, - {name: "openai missing key", provider: "openai", envKey: "OPENAI_API_KEY", wantErr: true}, + {name: "gemini missing key", provider: "gemini", wantErr: true}, + {name: "openai missing key", provider: "openai", wantErr: true}, {name: "ollama no key needed", provider: "ollama", wantErr: false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Clear all API keys - os.Unsetenv("GEMINI_API_KEY") - os.Unsetenv("OPENAI_API_KEY") - cfg := validBaseConfig(tt.provider) + // Clear API keys to test missing key scenario + cfg.geminiAPIKey = "" + cfg.openaiAPIKey = "" + err := cfg.Validate() if tt.wantErr && err == nil { @@ -125,9 +102,6 @@ func TestValidateProviderAPIKey(t *testing.T) { // TestValidateModelName tests model name validation. func TestValidateModelName(t *testing.T) { - cleanup := setEnvForProvider(t, "gemini") - defer cleanup() - cfg := validBaseConfig("gemini") cfg.ModelName = "" @@ -142,9 +116,6 @@ func TestValidateModelName(t *testing.T) { // TestValidateTemperature tests temperature range validation. func TestValidateTemperature(t *testing.T) { - cleanup := setEnvForProvider(t, "gemini") - defer cleanup() - tests := []struct { name string temperature float32 @@ -180,9 +151,6 @@ func TestValidateTemperature(t *testing.T) { // TestValidateMaxTokens tests max tokens range validation. 
func TestValidateMaxTokens(t *testing.T) { - cleanup := setEnvForProvider(t, "gemini") - defer cleanup() - tests := []struct { name string maxTokens int @@ -231,9 +199,6 @@ func TestValidateOllamaHost(t *testing.T) { // TestValidateEmbedderModel tests embedder model validation. func TestValidateEmbedderModel(t *testing.T) { - cleanup := setEnvForProvider(t, "gemini") - defer cleanup() - cfg := validBaseConfig("gemini") cfg.EmbedderModel = "" @@ -248,9 +213,6 @@ func TestValidateEmbedderModel(t *testing.T) { // TestValidatePostgresHost tests PostgreSQL host validation. func TestValidatePostgresHost(t *testing.T) { - cleanup := setEnvForProvider(t, "gemini") - defer cleanup() - cfg := validBaseConfig("gemini") cfg.PostgresHost = "" @@ -265,9 +227,6 @@ func TestValidatePostgresHost(t *testing.T) { // TestValidatePostgresPort tests PostgreSQL port validation. func TestValidatePostgresPort(t *testing.T) { - cleanup := setEnvForProvider(t, "gemini") - defer cleanup() - tests := []struct { name string port int @@ -302,9 +261,6 @@ func TestValidatePostgresPort(t *testing.T) { // TestValidatePostgresDBName tests PostgreSQL database name validation. func TestValidatePostgresDBName(t *testing.T) { - cleanup := setEnvForProvider(t, "gemini") - defer cleanup() - cfg := validBaseConfig("gemini") cfg.PostgresDBName = "" @@ -319,9 +275,6 @@ func TestValidatePostgresDBName(t *testing.T) { // TestValidatePostgresPassword tests PostgreSQL password validation. func TestValidatePostgresPassword(t *testing.T) { - cleanup := setEnvForProvider(t, "gemini") - defer cleanup() - tests := []struct { name string password string @@ -363,9 +316,6 @@ func TestValidatePostgresPassword(t *testing.T) { // TestValidatePostgresSSLMode tests PostgreSQL SSL mode validation. 
func TestValidatePostgresSSLMode(t *testing.T) { - cleanup := setEnvForProvider(t, "gemini") - defer cleanup() - tests := []struct { name string sslMode string @@ -405,9 +355,6 @@ func TestValidatePostgresSSLMode(t *testing.T) { // default development password (koopa_dev_password) as a hard error. // The same password passes Validate() with only a warning for CLI/MCP modes. func TestValidateServe_DefaultPassword(t *testing.T) { - cleanup := setEnvForProvider(t, "gemini") - defer cleanup() - cfg := validBaseConfig("gemini") cfg.PostgresPassword = "koopa_dev_password" cfg.HMACSecret = "test-hmac-secret-that-is-at-least-32-characters-long" @@ -429,9 +376,6 @@ func TestValidateServe_DefaultPassword(t *testing.T) { // TestValidateServe_NonDefaultPassword verifies that ValidateServe succeeds // when the password is not the default value. func TestValidateServe_NonDefaultPassword(t *testing.T) { - cleanup := setEnvForProvider(t, "gemini") - defer cleanup() - cfg := validBaseConfig("gemini") cfg.PostgresPassword = "production_secure_password" cfg.HMACSecret = "test-hmac-secret-that-is-at-least-32-characters-long" @@ -443,9 +387,6 @@ func TestValidateServe_NonDefaultPassword(t *testing.T) { // TestValidateRetentionDays tests retention days range validation. func TestValidateRetentionDays(t *testing.T) { - cleanup := setEnvForProvider(t, "gemini") - defer cleanup() - tests := []struct { name string retentionDays int @@ -482,11 +423,6 @@ func TestValidateRetentionDays(t *testing.T) { // BenchmarkValidate benchmarks configuration validation. 
func BenchmarkValidate(b *testing.B) { - if err := os.Setenv("GEMINI_API_KEY", "test-key"); err != nil { - b.Fatalf("setting GEMINI_API_KEY: %v", err) - } - defer os.Unsetenv("GEMINI_API_KEY") - cfg := validBaseConfig("gemini") if err := cfg.Validate(); err != nil { diff --git a/internal/mcp/file.go b/internal/mcp/file.go index f86a96a..de7755b 100644 --- a/internal/mcp/file.go +++ b/internal/mcp/file.go @@ -6,8 +6,9 @@ import ( "github.com/firebase/genkit/go/ai" "github.com/google/jsonschema-go/jsonschema" - "github.com/koopa0/koopa/internal/tools" "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/koopa0/koopa/internal/tools" ) // registerFile registers all file operation tools to the MCP server. @@ -68,7 +69,7 @@ func (s *Server) registerFile() error { }, s.DeleteFile) // get_file_info - getFileInfoSchema, err := jsonschema.For[tools.GetFileInfoInput](nil) + getFileInfoSchema, err := jsonschema.For[tools.FileInfoInput](nil) if err != nil { return fmt.Errorf("schema for %s: %w", tools.FileInfoName, err) } @@ -128,7 +129,7 @@ func (s *Server) DeleteFile(ctx context.Context, _ *mcp.CallToolRequest, input t } // FileInfo handles the getFileInfo MCP tool call. 
-func (s *Server) FileInfo(ctx context.Context, _ *mcp.CallToolRequest, input tools.GetFileInfoInput) (*mcp.CallToolResult, any, error) { +func (s *Server) FileInfo(ctx context.Context, _ *mcp.CallToolRequest, input tools.FileInfoInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} result, err := s.file.FileInfo(toolCtx, input) if err != nil { diff --git a/internal/mcp/file_test.go b/internal/mcp/file_test.go index ab0045e..5126c5a 100644 --- a/internal/mcp/file_test.go +++ b/internal/mcp/file_test.go @@ -285,7 +285,7 @@ func TestGetFileInfo_Success(t *testing.T) { t.Fatalf("creating test file: %v", err) } - result, _, err := server.FileInfo(context.Background(), &mcp.CallToolRequest{}, tools.GetFileInfoInput{ + result, _, err := server.FileInfo(context.Background(), &mcp.CallToolRequest{}, tools.FileInfoInput{ Path: testFile, }) @@ -307,7 +307,7 @@ func TestGetFileInfo_FileNotFound(t *testing.T) { t.Fatalf("NewServer(): %v", err) } - result, _, err := server.FileInfo(context.Background(), &mcp.CallToolRequest{}, tools.GetFileInfoInput{ + result, _, err := server.FileInfo(context.Background(), &mcp.CallToolRequest{}, tools.FileInfoInput{ Path: filepath.Join(h.tempDir, "nonexistent.txt"), }) diff --git a/internal/mcp/knowledge.go b/internal/mcp/knowledge.go index f1e1428..ec3a968 100644 --- a/internal/mcp/knowledge.go +++ b/internal/mcp/knowledge.go @@ -6,8 +6,9 @@ import ( "github.com/firebase/genkit/go/ai" "github.com/google/jsonschema-go/jsonschema" - "github.com/koopa0/koopa/internal/tools" "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/koopa0/koopa/internal/tools" ) // registerKnowledge registers all knowledge tools to the MCP server. 
diff --git a/internal/mcp/network.go b/internal/mcp/network.go index dcd8276..db752bc 100644 --- a/internal/mcp/network.go +++ b/internal/mcp/network.go @@ -6,8 +6,9 @@ import ( "github.com/firebase/genkit/go/ai" "github.com/google/jsonschema-go/jsonschema" - "github.com/koopa0/koopa/internal/tools" "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/koopa0/koopa/internal/tools" ) // registerNetwork registers all network operation tools to the MCP server. diff --git a/internal/mcp/system.go b/internal/mcp/system.go index 7f75f19..dfc71e0 100644 --- a/internal/mcp/system.go +++ b/internal/mcp/system.go @@ -6,8 +6,9 @@ import ( "github.com/firebase/genkit/go/ai" "github.com/google/jsonschema-go/jsonschema" - "github.com/koopa0/koopa/internal/tools" "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/koopa0/koopa/internal/tools" ) // registerSystem registers all system operation tools to the MCP server. @@ -43,7 +44,7 @@ func (s *Server) registerSystem() error { }, s.ExecuteCommand) // get_env - getEnvSchema, err := jsonschema.For[tools.GetEnvInput](nil) + getEnvSchema, err := jsonschema.For[tools.EnvInput](nil) if err != nil { return fmt.Errorf("schema for %s: %w", tools.GetEnvName, err) } @@ -81,7 +82,7 @@ func (s *Server) ExecuteCommand(ctx context.Context, _ *mcp.CallToolRequest, inp } // Env handles the getEnv MCP tool call. 
-func (s *Server) Env(ctx context.Context, _ *mcp.CallToolRequest, input tools.GetEnvInput) (*mcp.CallToolResult, any, error) { +func (s *Server) Env(ctx context.Context, _ *mcp.CallToolRequest, input tools.EnvInput) (*mcp.CallToolResult, any, error) { toolCtx := &ai.ToolContext{Context: ctx} result, err := s.system.Env(toolCtx, input) if err != nil { diff --git a/internal/mcp/system_test.go b/internal/mcp/system_test.go index 8166631..f8103bc 100644 --- a/internal/mcp/system_test.go +++ b/internal/mcp/system_test.go @@ -154,7 +154,7 @@ func TestGetEnv_Success(t *testing.T) { testValue := "test_value_123" t.Setenv(testKey, testValue) - result, _, err := server.Env(context.Background(), &mcp.CallToolRequest{}, tools.GetEnvInput{ + result, _, err := server.Env(context.Background(), &mcp.CallToolRequest{}, tools.EnvInput{ Key: testKey, }) @@ -190,7 +190,7 @@ func TestGetEnv_NotSet(t *testing.T) { t.Fatalf("NewServer(): %v", err) } - result, _, err := server.Env(context.Background(), &mcp.CallToolRequest{}, tools.GetEnvInput{ + result, _, err := server.Env(context.Background(), &mcp.CallToolRequest{}, tools.EnvInput{ Key: "NONEXISTENT_VAR_12345", }) @@ -232,7 +232,7 @@ func TestGetEnv_SensitiveVariableBlocked(t *testing.T) { for _, key := range sensitiveKeys { t.Run(key, func(t *testing.T) { - result, _, err := server.Env(context.Background(), &mcp.CallToolRequest{}, tools.GetEnvInput{ + result, _, err := server.Env(context.Background(), &mcp.CallToolRequest{}, tools.EnvInput{ Key: key, }) diff --git a/internal/mcp/util.go b/internal/mcp/util.go index 2356ecd..a92c57b 100644 --- a/internal/mcp/util.go +++ b/internal/mcp/util.go @@ -5,8 +5,9 @@ import ( "fmt" "log/slog" - "github.com/koopa0/koopa/internal/tools" "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/koopa0/koopa/internal/tools" ) // MCP Error Detail Whitelist Policy: diff --git a/internal/memory/memory.go b/internal/memory/memory.go index edea26f..861c9c2 100644 --- a/internal/memory/memory.go +++ 
b/internal/memory/memory.go @@ -15,6 +15,8 @@ import ( "time" "github.com/google/uuid" + + "github.com/koopa0/koopa/internal/rag" ) // Sentinel errors for memory operations. @@ -95,8 +97,10 @@ func (c Category) DecayLambda() float64 { return math.Log(2) / halfLife } -// VectorDimension matches the embedding column size (768). -const VectorDimension int32 = 768 +// VectorDimension matches the embedding column size. +// Canonical source: rag.VectorDimension. Aliased here to avoid changing +// 20+ references in memory package tests. +const VectorDimension = rag.VectorDimension // Two-threshold dedup constants. const ( diff --git a/internal/memory/store.go b/internal/memory/store.go index 015d429..9e64493 100644 --- a/internal/memory/store.go +++ b/internal/memory/store.go @@ -604,45 +604,48 @@ func (s *Store) UpdateAccess(ctx context.Context, ids []uuid.UUID) error { return nil } -// UpdateDecayScores recalculates decay_score for all active memories. -// Processes per-category with batched UPDATEs to avoid large locks. +// UpdateDecayScores recalculates decay_score for all active memories in a single query. +// Uses a VALUES join to apply per-category lambda rates atomically. // Does NOT update updated_at to preserve the decay index. // Returns total number of rows updated. // // The Go-side formula must stay in sync with the SQL expression: // // Go: math.Exp(-lambda * hours) -// SQL: exp(-$1 * extract(epoch from (now() - updated_at)) / 3600.0) +// SQL: exp(-v.lambda * extract(epoch from (now() - m.updated_at)) / 3600.0) // -// NOTE: The explicit $1::float8 cast is required because pgx v5 sends +// NOTE: The explicit ::float8 cast is required because pgx v5 sends // Go float64 as an untyped parameter. When PostgreSQL sees `$1 = 0`, // it infers the parameter as integer, silently truncating 0.001925 → 0. // The cast forces float8 inference. 
See: github.com/jackc/pgx/issues/2125 func (s *Store) UpdateDecayScores(ctx context.Context) (int, error) { categories := AllCategories() - var total int - for _, cat := range categories { - lambda := cat.DecayLambda() - - tag, err := s.pool.Exec(ctx, - `UPDATE memories - SET decay_score = CASE - WHEN $1::float8 = 0.0 THEN 1.0 - ELSE LEAST(1.0, exp(-$1::float8 * extract(epoch from (now() - updated_at)) / 3600.0)) - END - WHERE active = true - AND superseded_by IS NULL - AND category = $2`, - lambda, string(cat), - ) - if err != nil { - return total, fmt.Errorf("updating decay scores for %s: %w", cat, err) - } - total += int(tag.RowsAffected()) + // Build VALUES clause and params for a single batched UPDATE. + // Each category contributes ($N::text, $N+1::float8) to the VALUES list. + params := make([]any, 0, len(categories)*2) + valueParts := make([]string, 0, len(categories)) + for i, cat := range categories { + paramIdx := i*2 + 1 // $1, $3, $5, $7 + valueParts = append(valueParts, fmt.Sprintf("($%d::text, $%d::float8)", paramIdx, paramIdx+1)) + params = append(params, string(cat), cat.DecayLambda()) + } + + query := `UPDATE memories m + SET decay_score = CASE + WHEN v.lambda = 0.0 THEN 1.0 + ELSE LEAST(1.0, exp(-v.lambda * extract(epoch from (now() - m.updated_at)) / 3600.0)) + END + FROM (VALUES ` + strings.Join(valueParts, ", ") + `) AS v(cat, lambda) + WHERE m.active = true + AND m.superseded_by IS NULL + AND m.category = v.cat` + + tag, err := s.pool.Exec(ctx, query, params...) + if err != nil { + return 0, fmt.Errorf("updating decay scores: %w", err) } - - return total, nil + return int(tag.RowsAffected()), nil } // DeleteStale soft-deletes memories past their expires_at timestamp. 
@@ -1049,7 +1052,7 @@ func (s *Store) Memories(ctx context.Context, ownerID string, limit, offset int) } defer rows.Close() - var memories []*Memory + memories := make([]*Memory, 0) var total int for rows.Next() { m := &Memory{} diff --git a/internal/security/command.go b/internal/security/command.go index df28d7d..6ad353f 100644 --- a/internal/security/command.go +++ b/internal/security/command.go @@ -64,16 +64,12 @@ func NewCommand() *Command { // Prevents RCE via: go test (F2), npm install, git grep/archive (F3). allowedSubcommands: map[string][]string{ "git": { - // Read-only metadata - "status", "branch", "tag", "remote", "rev-parse", "describe", - // History viewing (content scoped to repo) + // Read-only operations only — no write/mutate commands + // NOTE: "branch" and "tag" excluded — destructive flags (-D, -d, -m, -f) + // cannot be safely blocked without colliding with legitimate read-only + // flags (e.g., git diff -M). Use "git log --oneline --decorate" instead. + "status", "remote", "rev-parse", "describe", "log", "diff", "show", "blame", - // Working tree operations - "add", "commit", "push", "pull", "fetch", - "checkout", "switch", "merge", "rebase", - "stash", "restore", "reset", - // Maintenance - "clean", "gc", "prune", }, "go": { // Metadata (no code execution) diff --git a/internal/security/command_test.go b/internal/security/command_test.go index e43dd60..b6f1bc8 100644 --- a/internal/security/command_test.go +++ b/internal/security/command_test.go @@ -170,25 +170,27 @@ func TestAllowedSubcommands(t *testing.T) { args []string shouldErr bool }{ - // git: allowed subcommands + // git: allowed subcommands (read-only only) {name: "git status", command: "git", args: []string{"status"}, shouldErr: false}, {name: "git log", command: "git", args: []string{"log", "--oneline"}, shouldErr: false}, {name: "git diff", command: "git", args: []string{"diff"}, shouldErr: false}, {name: "git show", command: "git", args: []string{"show", "HEAD"}, shouldErr: 
false}, {name: "git blame", command: "git", args: []string{"blame", "file.go"}, shouldErr: false}, - {name: "git branch", command: "git", args: []string{"branch", "-a"}, shouldErr: false}, - {name: "git add", command: "git", args: []string{"add", "."}, shouldErr: false}, - {name: "git commit", command: "git", args: []string{"commit", "-m", "msg"}, shouldErr: false}, - {name: "git push", command: "git", args: []string{"push"}, shouldErr: false}, - {name: "git pull", command: "git", args: []string{"pull"}, shouldErr: false}, - {name: "git fetch", command: "git", args: []string{"fetch"}, shouldErr: false}, - {name: "git checkout", command: "git", args: []string{"checkout", "main"}, shouldErr: false}, - {name: "git merge", command: "git", args: []string{"merge", "feature"}, shouldErr: false}, - {name: "git rebase", command: "git", args: []string{"rebase", "main"}, shouldErr: false}, - {name: "git stash", command: "git", args: []string{"stash"}, shouldErr: false}, - {name: "git tag", command: "git", args: []string{"tag", "-l"}, shouldErr: false}, {name: "git remote", command: "git", args: []string{"remote", "-v"}, shouldErr: false}, - // git: blocked subcommands (not in allowlist) + {name: "git rev-parse", command: "git", args: []string{"rev-parse", "HEAD"}, shouldErr: false}, + {name: "git describe", command: "git", args: []string{"describe", "--tags"}, shouldErr: false}, + // git: blocked subcommands (write/mutate or have destructive flags) + {name: "git branch blocked", command: "git", args: []string{"branch", "-a"}, shouldErr: true}, + {name: "git tag blocked", command: "git", args: []string{"tag", "-l"}, shouldErr: true}, + {name: "git add blocked", command: "git", args: []string{"add", "."}, shouldErr: true}, + {name: "git commit blocked", command: "git", args: []string{"commit", "-m", "msg"}, shouldErr: true}, + {name: "git push blocked", command: "git", args: []string{"push"}, shouldErr: true}, + {name: "git pull blocked", command: "git", args: []string{"pull"}, 
shouldErr: true}, + {name: "git fetch blocked", command: "git", args: []string{"fetch"}, shouldErr: true}, + {name: "git checkout blocked", command: "git", args: []string{"checkout", "main"}, shouldErr: true}, + {name: "git merge blocked", command: "git", args: []string{"merge", "feature"}, shouldErr: true}, + {name: "git rebase blocked", command: "git", args: []string{"rebase", "main"}, shouldErr: true}, + {name: "git stash blocked", command: "git", args: []string{"stash"}, shouldErr: true}, {name: "git grep blocked (F3)", command: "git", args: []string{"grep", "pattern"}, shouldErr: true}, {name: "git archive blocked (F3)", command: "git", args: []string{"archive", "HEAD"}, shouldErr: true}, {name: "git filter-branch blocked", command: "git", args: []string{"filter-branch", "--tree-filter", "cmd"}, shouldErr: true}, diff --git a/internal/security/url.go b/internal/security/url.go index aef0122..d49b55a 100644 --- a/internal/security/url.go +++ b/internal/security/url.go @@ -146,9 +146,10 @@ func (v *URL) SafeTransport() *http.Transport { return &http.Transport{ DialContext: v.safeDialContext, // Reasonable defaults - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, } } diff --git a/internal/session/state.go b/internal/session/state.go index db30f7f..69b7e0b 100644 --- a/internal/session/state.go +++ b/internal/session/state.go @@ -13,19 +13,18 @@ import ( ) const ( - stateDir = ".koopa" - stateFile = "current_session" - lockTimeout = 5 * time.Second // Maximum time to wait for lock - lockFileName = "current_session.lock" + defaultStateDir = ".koopa" + stateFile = "current_session" + lockTimeout = 5 * time.Second // Maximum time to wait for lock + lockFileName = "current_session.lock" ) -// getStateDirPath returns the state directory path. 
-// Checks KOOPA_STATE_DIR environment variable first (for testing), -// then falls back to ~/.koopa (for production). -func getStateDirPath() (string, error) { - // Check for test override - if testDir := os.Getenv("KOOPA_STATE_DIR"); testDir != "" { - return testDir, nil +// resolveStateDir returns the state directory path. +// If overrideDir is non-empty, it is used directly (for testing or custom config). +// Otherwise falls back to ~/.koopa (production default). +func resolveStateDir(overrideDir string) (string, error) { + if overrideDir != "" { + return overrideDir, nil } // Production: use ~/.koopa @@ -34,22 +33,13 @@ func getStateDirPath() (string, error) { return "", fmt.Errorf("getting home directory: %w", err) } - return filepath.Join(homeDir, stateDir), nil + return filepath.Join(homeDir, defaultStateDir), nil } -// getStateFilePath returns the full path to the current session state file. -// Creates the state directory (~/.koopa) if it doesn't exist. -// -// For testing, you can override the state directory by setting KOOPA_STATE_DIR -// environment variable to a temporary directory (e.g., t.TempDir()). -// -// Returns: -// - string: Path to ~/.koopa/current_session (or $KOOPA_STATE_DIR/current_session if set) -// - error: If unable to determine home directory or create state directory -// -// Note: This is a private function as it's only used within the session package. -func getStateFilePath() (string, error) { - stateDirPath, err := getStateDirPath() +// stateFilePath returns the full path to the current session state file. +// Creates the state directory if it doesn't exist. +func stateFilePath(overrideDir string) (string, error) { + stateDirPath, err := resolveStateDir(overrideDir) if err != nil { return "", err } @@ -64,27 +54,24 @@ func getStateFilePath() (string, error) { // LoadCurrentSessionID loads the currently active session ID from local state file. // -// Acquires exclusive file lock to prevent concurrent access during read. 
-// -// Returns: -// - *uuid.UUID: Current session ID (nil if no current session) -// - error: If state file exists but is malformed or unreadable +// The stateDir parameter overrides the default ~/.koopa directory. +// Pass empty string for production default, or a temp directory for testing. // -// Note: Returns (nil, nil) if state file doesn't exist - this is not an error. -func LoadCurrentSessionID() (*uuid.UUID, error) { - filePath, err := getStateFilePath() +// Returns (nil, nil) if state file doesn't exist - this is not an error. +func LoadCurrentSessionID(stateDir string) (*uuid.UUID, error) { + filePath, err := stateFilePath(stateDir) if err != nil { return nil, err } // Acquire file lock to prevent concurrent writes - lock, err := acquireStateLock() + lock, err := acquireStateLock(stateDir) if err != nil { return nil, fmt.Errorf("acquiring lock: %w", err) } defer func() { _ = lock.Unlock() }() - // #nosec G304 -- filePath is constructed internally via getStateFilePath() to ~/.koopa/current_session, not from user input + // #nosec G304 -- filePath is constructed internally via stateFilePath() to ~/.koopa/current_session, not from user input data, err := os.ReadFile(filePath) if err != nil { if os.IsNotExist(err) { @@ -107,22 +94,18 @@ func LoadCurrentSessionID() (*uuid.UUID, error) { } // cleanupOrphanedTempFiles removes any stale .tmp files from previous crashed sessions. -// This prevents accumulation of temporary files when the process crashes after writing -// the temp file but before renaming it. 
-func cleanupOrphanedTempFiles() error { - stateDirPath, err := getStateDirPath() +func cleanupOrphanedTempFiles(stateDir string) error { + stateDirPath, err := resolveStateDir(stateDir) if err != nil { return fmt.Errorf("getting state directory: %w", err) } - // Find all .tmp files in state directory pattern := filepath.Join(stateDirPath, "*.tmp") tmpFiles, err := filepath.Glob(pattern) if err != nil { return fmt.Errorf("finding temp files: %w", err) } - // Remove each temp file (ignore errors as they may not exist) for _, tmpFile := range tmpFiles { _ = os.Remove(tmpFile) } @@ -133,38 +116,27 @@ func cleanupOrphanedTempFiles() error { // SaveCurrentSessionID saves the current session ID to local state file. // // Uses atomic write (temp file + rename) to ensure file is never partially written. -// Acquires exclusive file lock to prevent concurrent access. -// -// Parameters: -// - sessionID: UUID of the session to mark as current -// -// Returns: -// - error: If unable to write state file -func SaveCurrentSessionID(sessionID uuid.UUID) error { - filePath, err := getStateFilePath() +// The stateDir parameter overrides the default ~/.koopa directory. 
+func SaveCurrentSessionID(stateDir string, sessionID uuid.UUID) error { + filePath, err := stateFilePath(stateDir) if err != nil { return fmt.Errorf("saving session: %w", err) } - // Acquire file lock to prevent concurrent access - lock, err := acquireStateLock() + lock, err := acquireStateLock(stateDir) if err != nil { return fmt.Errorf("acquiring lock: %w", err) } defer func() { _ = lock.Unlock() }() - // Clean up any orphaned temp files from previous crashed sessions (under lock) - _ = cleanupOrphanedTempFiles() + _ = cleanupOrphanedTempFiles(stateDir) - // Write to temporary file first (atomic write pattern) tmpFile := filePath + ".tmp" if err := os.WriteFile(tmpFile, []byte(sessionID.String()), 0o600); err != nil { return fmt.Errorf("writing temp state file: %w", err) } - // Atomically rename temp file to final file if err := os.Rename(tmpFile, filePath); err != nil { - // Clean up temp file on error _ = os.Remove(tmpFile) return fmt.Errorf("updating state file: %w", err) } @@ -173,19 +145,15 @@ func SaveCurrentSessionID(sessionID uuid.UUID) error { } // ClearCurrentSessionID removes the current session state file. -// -// Returns: -// - error: If unable to remove state file (ignores "file not found" errors) -// -// Note: This is idempotent - calling it when no current session exists is not an error. -func ClearCurrentSessionID() error { - filePath, err := getStateFilePath() +// Idempotent - calling when no current session exists is not an error. +// The stateDir parameter overrides the default ~/.koopa directory. 
+func ClearCurrentSessionID(stateDir string) error { + filePath, err := stateFilePath(stateDir) if err != nil { return fmt.Errorf("clearing session: %w", err) } - // Acquire file lock to prevent concurrent access - lock, err := acquireStateLock() + lock, err := acquireStateLock(stateDir) if err != nil { return fmt.Errorf("acquiring lock: %w", err) } @@ -200,9 +168,8 @@ func ClearCurrentSessionID() error { } // acquireStateLock acquires an exclusive lock on the state file. -// Returns a locked flock.Flock instance that should be unlocked by the caller. -func acquireStateLock() (*flock.Flock, error) { - stateDirPath, err := getStateDirPath() +func acquireStateLock(stateDir string) (*flock.Flock, error) { + stateDirPath, err := resolveStateDir(stateDir) if err != nil { return nil, err } diff --git a/internal/session/state_test.go b/internal/session/state_test.go index 8007259..cc0458b 100644 --- a/internal/session/state_test.go +++ b/internal/session/state_test.go @@ -9,55 +9,50 @@ import ( "github.com/google/uuid" ) -func TestGetStateFilePath(t *testing.T) { - // Note: Testing private function getStateFilePath (accessible within same package) - // Use isolated temp directory for this test +func TestStateFilePath(t *testing.T) { tempDir := t.TempDir() - t.Setenv("KOOPA_STATE_DIR", tempDir) - path, err := getStateFilePath() + path, err := stateFilePath(tempDir) if err != nil { - t.Fatalf("getStateFilePath() error = %v", err) + t.Fatalf("stateFilePath(%q) error = %v", tempDir, err) } if path == "" { - t.Error("getStateFilePath() returned empty path") + t.Error("stateFilePath() returned empty path") } // Verify path is absolute if !filepath.IsAbs(path) { - t.Errorf("getStateFilePath() returned relative path: %q", path) + t.Errorf("stateFilePath() returned relative path: %q", path) } // Verify path uses temp directory rel, err := filepath.Rel(tempDir, path) if err != nil || strings.HasPrefix(rel, "..") { - t.Errorf("getStateFilePath() = %q, want within %q", path, tempDir) 
+ t.Errorf("stateFilePath() = %q, want within %q", path, tempDir) } // Verify directory was created dir := filepath.Dir(path) if _, err := os.Stat(dir); os.IsNotExist(err) { - t.Errorf("getStateFilePath() did not create directory: %q", dir) + t.Errorf("stateFilePath() did not create directory: %q", dir) } } func TestSaveAndLoadCurrentSessionID(t *testing.T) { - // Use isolated temp directory for all sub-tests tempDir := t.TempDir() - t.Setenv("KOOPA_STATE_DIR", tempDir) t.Run("save and load session ID", func(t *testing.T) { testID := uuid.New() // Save session ID - err := SaveCurrentSessionID(testID) + err := SaveCurrentSessionID(tempDir, testID) if err != nil { t.Fatalf("SaveCurrentSessionID() error = %v", err) } // Load session ID - loadedID, err := LoadCurrentSessionID() + loadedID, err := LoadCurrentSessionID(tempDir) if err != nil { t.Fatalf("LoadCurrentSessionID() error = %v", err) } @@ -72,10 +67,9 @@ func TestSaveAndLoadCurrentSessionID(t *testing.T) { }) t.Run("load returns nil when file doesn't exist", func(t *testing.T) { - // Ensure file doesn't exist - _ = ClearCurrentSessionID() + emptyDir := t.TempDir() - loadedID, err := LoadCurrentSessionID() + loadedID, err := LoadCurrentSessionID(emptyDir) if err != nil { t.Errorf("LoadCurrentSessionID() error = %v, want nil", err) } @@ -90,19 +84,19 @@ func TestSaveAndLoadCurrentSessionID(t *testing.T) { secondID := uuid.New() // Save first ID - err := SaveCurrentSessionID(firstID) + err := SaveCurrentSessionID(tempDir, firstID) if err != nil { t.Fatalf("SaveCurrentSessionID() first save error = %v", err) } // Overwrite with second ID - err = SaveCurrentSessionID(secondID) + err = SaveCurrentSessionID(tempDir, secondID) if err != nil { t.Fatalf("SaveCurrentSessionID() second save error = %v", err) } // Load and verify second ID - loadedID, err := LoadCurrentSessionID() + loadedID, err := LoadCurrentSessionID(tempDir) if err != nil { t.Fatalf("LoadCurrentSessionID() error = %v", err) } @@ -118,26 +112,24 @@ func 
TestSaveAndLoadCurrentSessionID(t *testing.T) { } func TestClearCurrentSessionID(t *testing.T) { - // Use isolated temp directory for all sub-tests - tempDir := t.TempDir() - t.Setenv("KOOPA_STATE_DIR", tempDir) - t.Run("clear existing session ID", func(t *testing.T) { + tempDir := t.TempDir() + // Set up - save a session ID first testID := uuid.New() - err := SaveCurrentSessionID(testID) + err := SaveCurrentSessionID(tempDir, testID) if err != nil { t.Fatalf("SaveCurrentSessionID() setup error = %v", err) } // Clear session ID - err = ClearCurrentSessionID() + err = ClearCurrentSessionID(tempDir) if err != nil { t.Errorf("ClearCurrentSessionID() error = %v", err) } // Verify file was deleted - loadedID, err := LoadCurrentSessionID() + loadedID, err := LoadCurrentSessionID(tempDir) if err != nil { t.Errorf("LoadCurrentSessionID() error = %v", err) } @@ -148,11 +140,10 @@ func TestClearCurrentSessionID(t *testing.T) { }) t.Run("clear when file doesn't exist is not an error", func(t *testing.T) { - // Ensure file doesn't exist - _ = ClearCurrentSessionID() + tempDir := t.TempDir() - // Clear again should not error - err := ClearCurrentSessionID() + // Clear on empty dir should not error + err := ClearCurrentSessionID(tempDir) if err != nil { t.Errorf("ClearCurrentSessionID() on non-existent file error = %v, want nil", err) } @@ -160,10 +151,6 @@ func TestClearCurrentSessionID(t *testing.T) { } func TestLoadCurrentSessionID_InvalidContent(t *testing.T) { - // Use isolated temp directory for all sub-tests - tempDir := t.TempDir() - t.Setenv("KOOPA_STATE_DIR", tempDir) - tests := []struct { name string content string @@ -202,10 +189,12 @@ func TestLoadCurrentSessionID_InvalidContent(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + // Write test content directly to state file - filePath, err := getStateFilePath() + filePath, err := stateFilePath(tempDir) if err != nil { - t.Fatalf("getStateFilePath() error = %v", 
err) + t.Fatalf("stateFilePath(%q) error = %v", tempDir, err) } err = os.WriteFile(filePath, []byte(tt.content), 0o600) @@ -214,7 +203,7 @@ func TestLoadCurrentSessionID_InvalidContent(t *testing.T) { } // Try to load - loadedID, err := LoadCurrentSessionID() + loadedID, err := LoadCurrentSessionID(tempDir) if (err != nil) != tt.wantErr { t.Errorf("LoadCurrentSessionID() error = %v, wantErr %v", err, tt.wantErr) diff --git a/internal/session/store.go b/internal/session/store.go index b3e6028..9e0a5eb 100644 --- a/internal/session/store.go +++ b/internal/session/store.go @@ -117,9 +117,13 @@ func (s *Store) Sessions(ctx context.Context, ownerID string, limit, offset int) const listSQL = ` SELECT s.id, s.title, s.owner_id, s.created_at, s.updated_at, - (SELECT COUNT(*) FROM messages m WHERE m.session_id = s.id) AS message_count, + COALESCE(mc.cnt, 0) AS message_count, COUNT(*) OVER() AS total FROM sessions s + LEFT JOIN ( + SELECT session_id, COUNT(*) AS cnt + FROM messages GROUP BY session_id + ) mc ON mc.session_id = s.id WHERE s.owner_id = $1 ORDER BY s.updated_at DESC LIMIT $2 OFFSET $3 @@ -131,7 +135,7 @@ func (s *Store) Sessions(ctx context.Context, ownerID string, limit, offset int) } defer rows.Close() - var sessions []*Session + sessions := make([]*Session, 0) var total int for rows.Next() { var ss Session @@ -495,7 +499,7 @@ func (s *Store) History(ctx context.Context, sessionID uuid.UUID) ([]*ai.Message // Returns the session ID. 
func (s *Store) ResolveCurrentSession(ctx context.Context) (uuid.UUID, error) { //nolint:contextcheck // LoadCurrentSessionID manages its own lock timeout context - savedID, err := LoadCurrentSessionID() + savedID, err := LoadCurrentSessionID("") if err != nil { return uuid.Nil, fmt.Errorf("loading current session: %w", err) } @@ -516,7 +520,7 @@ func (s *Store) ResolveCurrentSession(ctx context.Context) (uuid.UUID, error) { // best-effort: state file is non-critical, session already created in DB //nolint:contextcheck // SaveCurrentSessionID manages its own lock timeout context - if saveErr := SaveCurrentSessionID(newSess.ID); saveErr != nil { + if saveErr := SaveCurrentSessionID("", newSess.ID); saveErr != nil { s.logger.Warn("saving session state", "error", saveErr) } diff --git a/internal/tools/file.go b/internal/tools/file.go index 6ccc178..9c1d62f 100644 --- a/internal/tools/file.go +++ b/internal/tools/file.go @@ -72,8 +72,8 @@ type DeleteFileInput struct { Path string `json:"path" jsonschema_description:"The file path to delete"` } -// GetFileInfoInput defines input for get_file_info tool. -type GetFileInfoInput struct { +// FileInfoInput defines input for get_file_info tool. +type FileInfoInput struct { Path string `json:"path" jsonschema_description:"The file path to get info for"` } @@ -427,7 +427,7 @@ func (f *File) DeleteFile(_ *ai.ToolContext, input DeleteFileInput) (Result, err } // FileInfo gets file metadata. 
-func (f *File) FileInfo(_ *ai.ToolContext, input GetFileInfoInput) (Result, error) { +func (f *File) FileInfo(_ *ai.ToolContext, input FileInfoInput) (Result, error) { f.logger.Debug("FileInfo called", "path", input.Path) if len(input.Path) > MaxPathLength { diff --git a/internal/tools/file_integration_test.go b/internal/tools/file_integration_test.go index 474e0f7..6ae0e9b 100644 --- a/internal/tools/file_integration_test.go +++ b/internal/tools/file_integration_test.go @@ -594,7 +594,7 @@ func TestFile_GetFileInfo_PathSecurity(t *testing.T) { h := newfileTools(t) ft := h.createFile() - result, err := ft.FileInfo(nil, GetFileInfoInput{Path: "/etc/passwd"}) + result, err := ft.FileInfo(nil, FileInfoInput{Path: "/etc/passwd"}) if err != nil { t.Fatalf("GetFileInfo(%q) unexpected Go error: %v (should not return Go error)", "/etc/passwd", err) @@ -619,7 +619,7 @@ func TestFile_GetFileInfo_Success(t *testing.T) { // Create a test file testPath := h.createTestFile("info.txt", "test content") - result, err := ft.FileInfo(nil, GetFileInfoInput{Path: testPath}) + result, err := ft.FileInfo(nil, FileInfoInput{Path: testPath}) if err != nil { t.Fatalf("GetFileInfo(%q) unexpected error: %v", testPath, err) @@ -649,7 +649,7 @@ func TestFile_GetFileInfo_NotFound(t *testing.T) { nonExistentPath := filepath.Join(h.tempDir, "does-not-exist.txt") - result, err := ft.FileInfo(nil, GetFileInfoInput{Path: nonExistentPath}) + result, err := ft.FileInfo(nil, FileInfoInput{Path: nonExistentPath}) if err != nil { t.Fatalf("GetFileInfo(%q) unexpected Go error: %v (should not return Go error)", nonExistentPath, err) diff --git a/internal/tools/knowledge.go b/internal/tools/knowledge.go index b66b82a..677bcb1 100644 --- a/internal/tools/knowledge.go +++ b/internal/tools/knowledge.go @@ -257,6 +257,7 @@ func stripInjectionMarkers(content string) string { } // countUserDocs returns the number of user-created documents owned by the given owner. +// Raw SQL: sqlc does not manage this query. 
Keep in sync with documents table schema. func (k *Knowledge) countUserDocs(ctx context.Context, ownerID string) (int64, error) { var count int64 err := k.pool.QueryRow(ctx, diff --git a/internal/tools/network.go b/internal/tools/network.go index 03bec24..d88b725 100644 --- a/internal/tools/network.go +++ b/internal/tools/network.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "log/slog" "net/http" "net/url" @@ -235,9 +236,9 @@ func (n *Network) Search(ctx *ai.ToolContext, input SearchInput) (SearchOutput, return SearchOutput{Query: input.Query, Error: fmt.Sprintf("Search request error: HTTP %d", resp.StatusCode)}, nil } - // Parse response + // Parse response (limit body to 10 MB to prevent resource exhaustion) var searxResp searxngResponse - if err := json.NewDecoder(resp.Body).Decode(&searxResp); err != nil { + if err := json.NewDecoder(io.LimitReader(resp.Body, 10<<20)).Decode(&searxResp); err != nil { return SearchOutput{}, fmt.Errorf("parse response: %w", err) } diff --git a/internal/tools/system.go b/internal/tools/system.go index 4509741..d3445c0 100644 --- a/internal/tools/system.go +++ b/internal/tools/system.go @@ -38,8 +38,8 @@ const MaxEnvKeyLength = 256 // Prevents abuse via extremely long command strings. const MaxCommandArgLength = 10000 -// GetEnvInput defines input for get_env tool. -type GetEnvInput struct { +// EnvInput defines input for get_env tool. +type EnvInput struct { Key string `json:"key" jsonschema_description:"The environment variable name"` } @@ -201,7 +201,7 @@ func (s *System) ExecuteCommand(ctx *ai.ToolContext, input ExecuteCommandInput) // Env reads an environment variable value with security protection. // Sensitive variables containing KEY, SECRET, or TOKEN in the name are blocked. // Business errors (sensitive variable blocked) are returned in Result.Error. 
-func (s *System) Env(_ *ai.ToolContext, input GetEnvInput) (Result, error) { +func (s *System) Env(_ *ai.ToolContext, input EnvInput) (Result, error) { s.logger.Debug("Env called", "key", input.Key) if len(input.Key) > MaxEnvKeyLength { diff --git a/internal/tools/system_integration_test.go b/internal/tools/system_integration_test.go index f501bc9..20d5e43 100644 --- a/internal/tools/system_integration_test.go +++ b/internal/tools/system_integration_test.go @@ -518,7 +518,7 @@ func TestSystem_GetEnv_SensitiveVariableBlocked(t *testing.T) { h := newsystemTools(t) st := h.createSystem() - result, err := st.Env(nil, GetEnvInput{Key: tt.envKey}) + result, err := st.Env(nil, EnvInput{Key: tt.envKey}) // Go error only for infrastructure errors if err != nil { @@ -559,7 +559,7 @@ func TestSystem_GetEnv_SafeVariableAllowed(t *testing.T) { h := newsystemTools(t) st := h.createSystem() - result, err := st.Env(nil, GetEnvInput{Key: tt.envKey}) + result, err := st.Env(nil, EnvInput{Key: tt.envKey}) if err != nil { t.Fatalf("GetEnv(%q) unexpected error: %v (safe variable should not be blocked)", tt.envKey, err) @@ -605,7 +605,7 @@ func TestSystem_GetEnv_CaseInsensitiveBlocking(t *testing.T) { h := newsystemTools(t) st := h.createSystem() - result, err := st.Env(nil, GetEnvInput{Key: tt.envKey}) + result, err := st.Env(nil, EnvInput{Key: tt.envKey}) // Go error only for infrastructure errors if err != nil {