From b193da60515e6770b1fc55e38fd3f61901e2afad Mon Sep 17 00:00:00 2001 From: Koopa Date: Sun, 15 Feb 2026 16:10:53 +0800 Subject: [PATCH] feat: add runtime hardening - Add http.MaxBytesReader (1MB) and content length cap (32KB) to API chat handlers - Add ErrPathDenied sentinel and deniedPrefixes to security.Path for blocking prompt directory access - Add MaxKnowledgeContentSize (50KB) validation to knowledge_store tool - Update NewPath signature across all callers to support denied prefixes - Add unit tests for all new validation paths --- internal/api/chat.go | 25 ++++++ internal/api/chat_test.go | 67 ++++++++++++++ internal/api/server.go | 6 +- internal/api/session.go | 18 +++- internal/app/app_test.go | 2 +- internal/app/setup.go | 3 +- internal/chat/setup_test.go | 2 +- internal/config/observability.go | 2 +- internal/mcp/benchmark_test.go | 4 +- internal/mcp/integration_test.go | 2 +- internal/mcp/server_test.go | 2 +- internal/security/doc.go | 4 +- internal/security/fuzz_test.go | 4 +- internal/security/path.go | 50 +++++++++-- internal/security/path_test.go | 113 ++++++++++++++++++++++-- internal/tools/doc.go | 2 +- internal/tools/file_integration_test.go | 2 +- internal/tools/file_test.go | 4 +- internal/tools/fuzz_test.go | 2 +- internal/tools/knowledge.go | 13 +++ internal/tools/knowledge_test.go | 79 +++++++++++++++++ internal/tools/register_test.go | 8 +- internal/tui/setup_test.go | 2 +- 23 files changed, 375 insertions(+), 41 deletions(-) diff --git a/internal/api/chat.go b/internal/api/chat.go index 0e5a805..7b59646 100644 --- a/internal/api/chat.go +++ b/internal/api/chat.go @@ -22,6 +22,12 @@ const sseTimeout = 5 * time.Minute // titleMaxLength is the maximum rune length for a fallback session title. const titleMaxLength = 50 +// maxRequestBodySize is the maximum allowed HTTP request body size (1 MB). +const maxRequestBodySize = 1 << 20 + +// maxChatContentLength is the maximum allowed chat message length in bytes (~8K tokens). +const maxChatContentLength = 32_000 + // Tool display info for JSON SSE events. type toolDisplayInfo struct { StartMsg string @@ -69,11 +75,18 @@ type chatHandler struct { // send handles POST /api/v1/chat — accepts JSON, sends message to chat flow. func (h *chatHandler) send(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodySize) + var req struct { Content string `json:"content"` SessionID string `json:"sessionId"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + WriteError(w, http.StatusRequestEntityTooLarge, "body_too_large", "request body too large", h.logger) + return + } WriteError(w, http.StatusBadRequest, "invalid_json", "invalid request body", h.logger) return } @@ -84,6 +97,11 @@ func (h *chatHandler) send(w http.ResponseWriter, r *http.Request) { return } + if len(content) > maxChatContentLength { + WriteError(w, http.StatusRequestEntityTooLarge, "content_too_long", "message content exceeds maximum length", h.logger) + return + } + if req.SessionID == "" { WriteError(w, http.StatusBadRequest, "session_required", "sessionId is required", h.logger) return @@ -96,6 +114,7 @@ func (h *chatHandler) send(w http.ResponseWriter, r *http.Request) { } if !h.sessionAccessAllowed(r, sessionID) { + h.logger.Warn("session access denied", "sessionId", req.SessionID, "path", r.URL.Path) WriteError(w, http.StatusForbidden, "forbidden", "session access denied", h.logger) return } @@ -149,6 +168,11 @@ func (h *chatHandler) stream(w http.ResponseWriter, r *http.Request) { return } + if len(query) > maxChatContentLength { + WriteError(w, http.StatusRequestEntityTooLarge, "content_too_long", "query exceeds maximum length", h.logger) + return + } + parsedID, err := uuid.Parse(sessionID) if err != nil { WriteError(w, http.StatusBadRequest, "invalid_session", "invalid session ID", h.logger) @@ -156,6 +180,7 @@ func (h *chatHandler) stream(w http.ResponseWriter, r *http.Request) { } if !h.sessionAccessAllowed(r, parsedID) { + h.logger.Warn("session access denied", "sessionId", sessionID, "msgId", msgID, "path", r.URL.Path) WriteError(w, http.StatusForbidden, "forbidden", "session access denied", h.logger) return } diff --git a/internal/api/chat_test.go b/internal/api/chat_test.go index 667810c..e6f053f 100644 --- a/internal/api/chat_test.go +++ b/internal/api/chat_test.go @@ -196,6 +196,73 @@ func TestChatSend_InvalidJSON(t *testing.T) { } } +func TestChatSend_BodyTooLarge(t *testing.T) { + // Create a valid JSON body larger than maxRequestBodySize (1 MB). + // The content field must be large enough so the whole JSON exceeds the limit. + largeContent := strings.Repeat("x", maxRequestBodySize) + body, _ := json.Marshal(map[string]string{ + "content": largeContent, + "sessionId": uuid.New().String(), + }) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) + + newTestChatHandler().send(w, r) + + if w.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("send(>1MB body) status = %d, want %d\nbody: %s", w.Code, http.StatusRequestEntityTooLarge, w.Body.String()) + } + + errResp := decodeErrorEnvelope(t, w) + if errResp.Code != "body_too_large" { + t.Errorf("send(>1MB body) code = %q, want %q", errResp.Code, "body_too_large") + } +} + +func TestChatSend_ContentTooLong(t *testing.T) { + // Create content that exceeds maxChatContentLength (32K) + longContent := strings.Repeat("x", maxChatContentLength+1) + body, _ := json.Marshal(map[string]string{ + "content": longContent, + "sessionId": uuid.New().String(), + }) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) + + newTestChatHandler().send(w, r) + + if w.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("send(>32K content) status = %d, want %d\nbody: %s", w.Code, http.StatusRequestEntityTooLarge, w.Body.String()) + } + + errResp := decodeErrorEnvelope(t, w) + if errResp.Code != "content_too_long" { + t.Errorf("send(>32K content) code = %q, want %q", errResp.Code, "content_too_long") + } +} + +func TestStream_QueryTooLong(t *testing.T) { + ch := newTestChatHandler() + sessionID := uuid.New() + longQuery := strings.Repeat("x", maxChatContentLength+1) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query="+url.QueryEscape(longQuery), nil) + + ch.stream(w, r) + + if w.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("stream(>32K query) status = %d, want %d\nbody: %s", w.Code, http.StatusRequestEntityTooLarge, w.Body.String()) + } + + errResp := decodeErrorEnvelope(t, w) + if errResp.Code != "content_too_long" { + t.Errorf("stream(>32K query) code = %q, want %q", errResp.Code, "content_too_long") + } +} + func TestChatSend_OwnershipDenied(t *testing.T) { body, _ := json.Marshal(map[string]string{ "content": "hello", diff --git a/internal/api/server.go b/internal/api/server.go index b246e58..7c9e805 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -78,13 +78,15 @@ func NewServer(cfg ServerConfig) (*Server, error) { } rl := newRateLimiter(1.0, burst) - // Build middleware stack: Recovery → Logging → RateLimit → CORS → User → Session → CSRF → Routes + // Build middleware stack (outermost first): + // Recovery → Logging → CORS → RateLimit → User → Session → CSRF → Routes + // CORS must be before RateLimit so preflight OPTIONS gets proper CORS headers. var handler http.Handler = mux handler = csrfMiddleware(sm, logger)(handler) handler = sessionMiddleware(sm)(handler) handler = userMiddleware(sm)(handler) - handler = corsMiddleware(cfg.CORSOrigins)(handler) handler = rateLimitMiddleware(rl, cfg.TrustProxy, logger)(handler) + handler = corsMiddleware(cfg.CORSOrigins)(handler) handler = loggingMiddleware(logger)(handler) handler = recoveryMiddleware(logger)(handler) diff --git a/internal/api/session.go b/internal/api/session.go index 62c7ad8..ea38130 100644 --- a/internal/api/session.go +++ b/internal/api/session.go @@ -120,9 +120,14 @@ func (sm *sessionManager) CheckCSRF(userID, token string) error { message := fmt.Sprintf("%s:%d", userID, timestamp) h := hmac.New(sha256.New, sm.hmacSecret) h.Write([]byte(message)) - expectedSig := base64.URLEncoding.EncodeToString(h.Sum(nil)) + expectedSig := h.Sum(nil) - if subtle.ConstantTimeCompare([]byte(parts[1]), []byte(expectedSig)) != 1 { + actualSig, err := base64.URLEncoding.DecodeString(parts[1]) + if err != nil { + return ErrCSRFMalformed + } + + if subtle.ConstantTimeCompare(actualSig, expectedSig) != 1 { return ErrCSRFInvalid } @@ -176,9 +181,14 @@ func (sm *sessionManager) CheckPreSessionCSRF(token string) error { message := fmt.Sprintf("%s:%d", nonce, timestamp) h := hmac.New(sha256.New, sm.hmacSecret) h.Write([]byte(message)) - expectedSig := base64.URLEncoding.EncodeToString(h.Sum(nil)) + expectedSig := h.Sum(nil) + + actualSig, err := base64.URLEncoding.DecodeString(parts[2]) + if err != nil { + return ErrCSRFMalformed + } - if subtle.ConstantTimeCompare([]byte(parts[2]), []byte(expectedSig)) != 1 { + if subtle.ConstantTimeCompare(actualSig, expectedSig) != 1 { return ErrCSRFInvalid } diff --git a/internal/app/app_test.go b/internal/app/app_test.go index dd8049b..d21c09b 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -71,7 +71,7 @@ func TestApp_Fields(t *testing.T) { defer cancel() g := genkit.Init(ctx) - pathValidator, err := security.NewPath([]string{"."}) + pathValidator, err := security.NewPath([]string{"."}, nil) if err != nil { t.Fatalf("security.NewPath() error: %v", err) } diff --git a/internal/app/setup.go b/internal/app/setup.go index fc1e071..038dd5d 100644 --- a/internal/app/setup.go +++ b/internal/app/setup.go @@ -300,8 +300,9 @@ func provideSessionStore(pool *pgxpool.Pool) *session.Store { } // providePathValidator creates a path validator instance. +// Denies access to prompts/ to protect system prompt files from tool-based access. func providePathValidator() (*security.Path, error) { - return security.NewPath([]string{"."}) + return security.NewPath([]string{"."}, []string{"prompts"}) } // provideTools creates toolsets, registers them with Genkit, and stores both diff --git a/internal/chat/setup_test.go b/internal/chat/setup_test.go index da5997e..9fd039a 100644 --- a/internal/chat/setup_test.go +++ b/internal/chat/setup_test.go @@ -112,7 +112,7 @@ func SetupTest(t *testing.T) *TestFramework { } // Create toolsets - pathValidator, err := security.NewPath([]string{os.TempDir()}) + pathValidator, err := security.NewPath([]string{os.TempDir()}, nil) if err != nil { t.Fatalf("creating path validator: %v", err) } diff --git a/internal/config/observability.go b/internal/config/observability.go index c542795..dc6676e 100644 --- a/internal/config/observability.go +++ b/internal/config/observability.go @@ -11,7 +11,7 @@ import ( // Setup is inlined in internal/app/setup.go (provideOtelShutdown). type DatadogConfig struct { // APIKey is the Datadog API key (optional, for observability) - APIKey string `mapstructure:"api_key" json:"api_key"` + APIKey string `mapstructure:"api_key" json:"api_key"` // #nosec G117 -- masked in MarshalJSON, never serialized in plain text // AgentHost is the Datadog Agent OTLP endpoint (default: localhost:4318) AgentHost string `mapstructure:"agent_host" json:"agent_host"` // Environment is the deployment environment tag (default: dev) diff --git a/internal/mcp/benchmark_test.go b/internal/mcp/benchmark_test.go index cf0791d..44b3024 100644 --- a/internal/mcp/benchmark_test.go +++ b/internal/mcp/benchmark_test.go @@ -18,7 +18,7 @@ import ( func BenchmarkServer_Creation(b *testing.B) { // Setup toolsets once — benchmark only NewServer + tool registration. tmpDir := b.TempDir() - pathVal, err := security.NewPath([]string{tmpDir}) + pathVal, err := security.NewPath([]string{tmpDir}, nil) if err != nil { b.Fatalf("creating path validator: %v", err) } @@ -167,7 +167,7 @@ func BenchmarkReadFileInput_Parse(b *testing.B) { // BenchmarkConfig_Validation benchmarks Config validation. func BenchmarkConfig_Validation(b *testing.B) { tmpDir := b.TempDir() - pathVal, err := security.NewPath([]string{tmpDir}) + pathVal, err := security.NewPath([]string{tmpDir}, nil) if err != nil { b.Fatalf("creating path validator: %v", err) } diff --git a/internal/mcp/integration_test.go b/internal/mcp/integration_test.go index 3d64418..1c71551 100644 --- a/internal/mcp/integration_test.go +++ b/internal/mcp/integration_test.go @@ -26,7 +26,7 @@ func createIntegrationTestConfig(t *testing.T, name string) Config { t.Fatalf("EvalSymlinks(%q) unexpected error: %v", tmpDir, err) } - pathVal, err := security.NewPath([]string{realTmpDir}) + pathVal, err := security.NewPath([]string{realTmpDir}, nil) if err != nil { t.Fatalf("security.NewPath(%q) unexpected error: %v", realTmpDir, err) } diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index d066f45..aa7008d 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -53,7 +53,7 @@ func newTestHelper(t *testing.T) *testHelper { func (h *testHelper) createFile() *tools.File { h.t.Helper() - pathVal, err := security.NewPath([]string{h.tempDir}) + pathVal, err := security.NewPath([]string{h.tempDir}, nil) if err != nil { h.t.Fatalf("creating path validator: %v", err) } diff --git a/internal/security/doc.go b/internal/security/doc.go index 8e17ca3..188ddf2 100644 --- a/internal/security/doc.go +++ b/internal/security/doc.go @@ -16,7 +16,7 @@ // Path Validator: Prevents directory traversal and ensures file operations // stay within allowed boundaries. // -// pathValidator, err := security.NewPath([]string{"/safe/dir"}) +// pathValidator, err := security.NewPath([]string{"/safe/dir"}, nil) // if _, err := pathValidator.Validate(userInput); err != nil { // return fmt.Errorf("invalid path: %w", err) // } @@ -67,7 +67,7 @@ // # Integration Example // // // Create validators -// pathVal, _ := security.NewPath([]string{workDir}) +// pathVal, _ := security.NewPath([]string{workDir}, nil) // cmdVal := security.NewCommand() // urlVal := security.NewURL() // envVal := security.NewEnv() diff --git a/internal/security/fuzz_test.go b/internal/security/fuzz_test.go index 226487f..9e410b3 100644 --- a/internal/security/fuzz_test.go +++ b/internal/security/fuzz_test.go @@ -70,7 +70,7 @@ func FuzzPathValidation(f *testing.F) { // Create a validator with a safe temporary directory tmpDir := f.TempDir() - validator, err := NewPath([]string{tmpDir}) + validator, err := NewPath([]string{tmpDir}, nil) if err != nil { f.Fatalf("creating validator: %v", err) } @@ -130,7 +130,7 @@ func FuzzPathValidationWithSymlinks(f *testing.F) { } tmpDir := t.TempDir() - validator, err := NewPath([]string{tmpDir}) + validator, err := NewPath([]string{tmpDir}, nil) if err != nil { t.Skipf("creating validator: %v", err) } diff --git a/internal/security/path.go b/internal/security/path.go index 8909056..3437465 100644 --- a/internal/security/path.go +++ b/internal/security/path.go @@ -20,18 +20,25 @@ var ( // ErrPathNullByte indicates the path contains a null byte (CWE-626). ErrPathNullByte = errors.New("path contains null byte") + + // ErrPathDenied indicates the path matches a denied prefix (e.g. prompts/). + ErrPathDenied = errors.New("path is denied") ) // Path validates and sanitizes file paths to prevent traversal attacks. // Used to prevent path traversal attacks (CWE-22). type Path struct { - allowedDirs []string - workDir string + allowedDirs []string + deniedPrefixes []string // absolute paths that are always denied (case-insensitive on macOS HFS+) + workDir string } // NewPath creates a new Path validator. -// allowedDirs: list of allowed directories (empty list means only working directory is allowed) -func NewPath(allowedDirs []string) (*Path, error) { +// allowedDirs: list of allowed directories (empty list means only working directory is allowed). +// deniedPrefixes: list of directory prefixes that are always denied even if inside allowed dirs +// (e.g. "prompts/" to protect system prompt files). Compared case-insensitively on +// case-insensitive filesystems (macOS HFS+). +func NewPath(allowedDirs, deniedPrefixes []string) (*Path, error) { workDir, err := os.Getwd() if err != nil { return nil, fmt.Errorf("unable to get working directory: %w", err) @@ -47,12 +54,35 @@ func NewPath(allowedDirs []string) (*Path, error) { absAllowedDirs = append(absAllowedDirs, absDir) } + // Convert denied prefixes to absolute paths + absDenied := make([]string, 0, len(deniedPrefixes)) + for _, prefix := range deniedPrefixes { + absPrefix, err := filepath.Abs(prefix) + if err != nil { + return nil, fmt.Errorf("unable to resolve denied prefix %s: %w", prefix, err) + } + absDenied = append(absDenied, absPrefix) + } + return &Path{ - allowedDirs: absAllowedDirs, - workDir: workDir, + allowedDirs: absAllowedDirs, + deniedPrefixes: absDenied, + workDir: workDir, }, nil } +// isPathDenied checks if a path matches any denied prefix. +// Uses case-insensitive comparison to handle case-insensitive filesystems (macOS HFS+). +func (v *Path) isPathDenied(absPath string) bool { + for _, denied := range v.deniedPrefixes { + deniedWithSep := filepath.Clean(denied) + string(filepath.Separator) + if strings.EqualFold(absPath, denied) || strings.HasPrefix(strings.ToLower(absPath+string(filepath.Separator)), strings.ToLower(deniedWithSep)) { + return true + } + } + return false +} + // isPathInAllowedDirs checks if a path is within allowed directories // Returns true if path is in working directory or any allowed directory func (v *Path) isPathInAllowedDirs(absPath string) bool { @@ -126,6 +156,14 @@ func (v *Path) Validate(path string) (string, error) { return "", fmt.Errorf("%w: access denied", ErrPathOutsideAllowed) } + // 3b. Check if path matches a denied prefix (e.g. prompts/) + if v.isPathDenied(absPath) { + slog.Warn("path denied by prefix rule", + "path", absPath, + "security_event", "denied_prefix_access_attempt") + return "", fmt.Errorf("%w: access denied", ErrPathDenied) + } + // 4. Resolve symbolic links (prevent bypassing restrictions through symlinks) realPath, err := filepath.EvalSymlinks(absPath) if err != nil { diff --git a/internal/security/path_test.go b/internal/security/path_test.go index be8e75f..52b4b71 100644 --- a/internal/security/path_test.go +++ b/internal/security/path_test.go @@ -23,7 +23,7 @@ func TestPathValidation(t *testing.T) { } defer func() { _ = os.Chdir(workDir) }() // Restore original directory - validator, err := NewPath([]string{tmpDir}) + validator, err := NewPath([]string{tmpDir}, nil) if err != nil { t.Fatalf("creating path validator: %v", err) } @@ -75,7 +75,7 @@ func TestPathValidation(t *testing.T) { // TestPathErrorSanitization tests that error messages don't leak sensitive paths func TestPathErrorSanitization(t *testing.T) { - validator, err := NewPath([]string{}) + validator, err := NewPath([]string{}, nil) if err != nil { t.Fatalf("creating path validator: %v", err) } @@ -112,7 +112,7 @@ func TestSymlinkValidation(t *testing.T) { } defer func() { _ = os.Chdir(workDir) }() // Restore original directory - validator, err := NewPath([]string{tmpDir}) + validator, err := NewPath([]string{tmpDir}, nil) if err != nil { t.Fatalf("creating path validator: %v", err) } @@ -158,7 +158,7 @@ func TestPathValidationWithNonExistentFile(t *testing.T) { } defer func() { _ = os.Chdir(workDir) }() - validator, err := NewPath([]string{tmpDir}) + validator, err := NewPath([]string{tmpDir}, nil) if err != nil { t.Fatalf("creating path validator: %v", err) } @@ -194,7 +194,7 @@ func TestSymlinkBypassAttempt(t *testing.T) { } defer func() { _ = os.Chdir(workDir) }() - validator, err := NewPath([]string{tmpDir}) + validator, err := NewPath([]string{tmpDir}, nil) if err != nil { t.Fatalf("creating path validator: %v", err) } @@ -229,7 +229,7 @@ func TestPathValidationErrors(t *testing.T) { } defer func() { _ = os.Chdir(workDir) }() - validator, err := NewPath([]string{tmpDir}) + validator, err := NewPath([]string{tmpDir}, nil) if err != nil { t.Fatalf("creating path validator: %v", err) } @@ -242,9 +242,108 @@ func TestPathValidationErrors(t *testing.T) { } } +// TestDeniedPrefixes tests that paths matching denied prefixes are blocked. +func TestDeniedPrefixes(t *testing.T) { + tmpDir := t.TempDir() + workDir, err := os.Getwd() + if err != nil { + t.Fatalf("getting working directory: %v", err) + } + + if err := os.Chdir(tmpDir); err != nil { + t.Fatalf("changing to temp directory: %v", err) + } + defer func() { _ = os.Chdir(workDir) }() + + // Create a "prompts" directory inside tmpDir + promptsDir := filepath.Join(tmpDir, "prompts") + if err := os.MkdirAll(promptsDir, 0o750); err != nil { + t.Fatalf("creating prompts directory: %v", err) + } + if err := os.WriteFile(filepath.Join(promptsDir, "system.prompt"), []byte("secret"), 0o600); err != nil { + t.Fatalf("creating prompt file: %v", err) + } + + validator, err := NewPath([]string{tmpDir}, []string{filepath.Join(tmpDir, "prompts")}) + if err != nil { + t.Fatalf("creating path validator: %v", err) + } + + tests := []struct { + name string + path string + wantErr error + }{ + { + name: "allowed file outside denied prefix", + path: filepath.Join(tmpDir, "allowed.txt"), + wantErr: nil, + }, + { + name: "file inside denied prefix", + path: filepath.Join(promptsDir, "system.prompt"), + wantErr: ErrPathDenied, + }, + { + name: "denied directory itself", + path: promptsDir, + wantErr: ErrPathDenied, + }, + { + name: "case-insensitive denied prefix", + path: filepath.Join(tmpDir, "Prompts", "system.prompt"), + wantErr: ErrPathDenied, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := validator.Validate(tt.path) + if tt.wantErr == nil { + if err != nil { + t.Errorf("Validate(%q) unexpected error: %v", tt.path, err) + } + return + } + if err == nil { + t.Fatalf("Validate(%q) = nil, want %v", tt.path, tt.wantErr) + } + if !errors.Is(err, tt.wantErr) { + t.Errorf("Validate(%q) error = %v, want %v", tt.path, err, tt.wantErr) + } + }) + } +} + +// TestDeniedPrefixes_NilSlice verifies that nil deniedPrefixes works correctly. +func TestDeniedPrefixes_NilSlice(t *testing.T) { + tmpDir := t.TempDir() + workDir, err := os.Getwd() + if err != nil { + t.Fatalf("getting working directory: %v", err) + } + + if err := os.Chdir(tmpDir); err != nil { + t.Fatalf("changing to temp directory: %v", err) + } + defer func() { _ = os.Chdir(workDir) }() + + validator, err := NewPath([]string{tmpDir}, nil) + if err != nil { + t.Fatalf("creating path validator: %v", err) + } + + // With nil denied prefixes, all paths within allowed dirs should work + path := filepath.Join(tmpDir, "prompts", "system.prompt") + _, err = validator.Validate(path) + if err != nil { + t.Errorf("Validate(%q) with nil deniedPrefixes error = %v, want nil", path, err) + } +} + // BenchmarkPathValidation benchmarks path validation performance func BenchmarkPathValidation(b *testing.B) { - validator, err := NewPath([]string{}) + validator, err := NewPath([]string{}, nil) if err != nil { b.Fatalf("creating path validator: %v", err) } diff --git a/internal/tools/doc.go b/internal/tools/doc.go index b29b97b..6f5213f 100644 --- a/internal/tools/doc.go +++ b/internal/tools/doc.go @@ -99,7 +99,7 @@ // # Usage Example // // // Create tools with security validators -// pathVal, _ := security.NewPath([]string{"/allowed/path"}) +// pathVal, _ := security.NewPath([]string{"/allowed/path"}, nil) // fileTools, err := tools.NewFile(pathVal, logger) // if err != nil { // return err diff --git a/internal/tools/file_integration_test.go b/internal/tools/file_integration_test.go index ee0b204..b102402 100644 --- a/internal/tools/file_integration_test.go +++ b/internal/tools/file_integration_test.go @@ -28,7 +28,7 @@ func newfileTools(t *testing.T) *fileTools { func (h *fileTools) createFile() *File { h.t.Helper() - pathVal, err := security.NewPath([]string{h.tempDir}) + pathVal, err := security.NewPath([]string{h.tempDir}, nil) if err != nil { h.t.Fatalf("creating path validator: %v", err) } diff --git a/internal/tools/file_test.go b/internal/tools/file_test.go index 33ed506..7f95739 100644 --- a/internal/tools/file_test.go +++ b/internal/tools/file_test.go @@ -8,7 +8,7 @@ import ( func TestFile_Constructor(t *testing.T) { t.Run("valid inputs", func(t *testing.T) { - pathVal, err := security.NewPath([]string{"/tmp"}) + pathVal, err := security.NewPath([]string{"/tmp"}, nil) if err != nil { t.Fatalf("creating path validator: %v", err) } @@ -33,7 +33,7 @@ func TestFile_Constructor(t *testing.T) { }) t.Run("nil logger", func(t *testing.T) { - pathVal, err := security.NewPath([]string{"/tmp"}) + pathVal, err := security.NewPath([]string{"/tmp"}, nil) if err != nil { t.Fatalf("creating path validator: %v", err) } diff --git a/internal/tools/fuzz_test.go b/internal/tools/fuzz_test.go index 3fb4026..9f3a0cf 100644 --- a/internal/tools/fuzz_test.go +++ b/internal/tools/fuzz_test.go @@ -93,7 +93,7 @@ func FuzzPathTraversal(f *testing.F) { f.Fuzz(func(t *testing.T, path string) { // Create validator with /tmp as allowed base - validator, err := security.NewPath([]string{"/tmp"}) + validator, err := security.NewPath([]string{"/tmp"}, nil) if err != nil { t.Skip("could not create validator") } diff --git a/internal/tools/knowledge.go b/internal/tools/knowledge.go index f1bce84..361435e 100644 --- a/internal/tools/knowledge.go +++ b/internal/tools/knowledge.go @@ -39,6 +39,10 @@ const ( MaxTopK = 10 ) +// MaxKnowledgeContentSize is the maximum allowed content size for knowledge_store (50KB). +// Prevents DoS via large document ingestion and embedding computation. +const MaxKnowledgeContentSize = 50 * 1024 + // KnowledgeSearchInput defines input for all knowledge search tools. // The default TopK varies by tool: history=3, documents=5, system=3. type KnowledgeSearchInput struct { @@ -265,6 +269,15 @@ func (k *Knowledge) StoreKnowledge(ctx *ai.ToolContext, input KnowledgeStoreInpu }, }, nil } + if len(input.Content) > MaxKnowledgeContentSize { + return Result{ + Status: StatusError, + Error: &Error{ + Code: ErrCodeValidation, + Message: fmt.Sprintf("content size %d exceeds maximum %d bytes", len(input.Content), MaxKnowledgeContentSize), + }, + }, nil + } // Generate a deterministic document ID from the title using SHA-256. // Changing the title creates a new document; the old entry remains. diff --git a/internal/tools/knowledge_test.go b/internal/tools/knowledge_test.go index 4500edf..b3779ad 100644 --- a/internal/tools/knowledge_test.go +++ b/internal/tools/knowledge_test.go @@ -3,10 +3,12 @@ package tools import ( "context" "log/slog" + "strings" "testing" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/plugins/postgresql" ) // mockRetriever is a minimal ai.Retriever implementation for testing. @@ -110,3 +112,80 @@ func TestValidSourceTypes(t *testing.T) { } } } + +func TestStoreKnowledge_Validation(t *testing.T) { + // knowledgeWithDocStore creates a Knowledge instance with a non-nil docStore + // for testing validation paths. The zero-value DocStore is safe because + // all test cases trigger validation errors before docStore.Index is called. + knowledgeWithDocStore := &Knowledge{ + retriever: &mockRetriever{}, + docStore: &postgresql.DocStore{}, + logger: slog.New(slog.DiscardHandler), + } + + knowledgeNilDocStore, err := NewKnowledge(&mockRetriever{}, nil, slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatalf("NewKnowledge() unexpected error: %v", err) + } + + tests := []struct { + name string + kt *Knowledge + input KnowledgeStoreInput + wantCode ErrorCode + wantInMsg string + }{ + { + name: "nil docStore returns not available", + kt: knowledgeNilDocStore, + input: KnowledgeStoreInput{Title: "t", Content: "c"}, + wantCode: ErrCodeExecution, + wantInMsg: "not available", + }, + { + name: "empty title", + kt: knowledgeWithDocStore, + input: KnowledgeStoreInput{Title: "", Content: "c"}, + wantCode: ErrCodeValidation, + wantInMsg: "title is required", + }, + { + name: "empty content", + kt: knowledgeWithDocStore, + input: KnowledgeStoreInput{Title: "t", Content: ""}, + wantCode: ErrCodeValidation, + wantInMsg: "content is required", + }, + { + name: "content exceeds maximum size", + kt: knowledgeWithDocStore, + input: KnowledgeStoreInput{ + Title: "large doc", + Content: strings.Repeat("x", MaxKnowledgeContentSize+1), + }, + wantCode: ErrCodeValidation, + wantInMsg: "exceeds maximum", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tt.kt.StoreKnowledge(nil, tt.input) + if err != nil { + t.Fatalf("StoreKnowledge() unexpected error: %v", err) + } + if result.Status != StatusError { + t.Fatalf("StoreKnowledge() status = %q, want %q", result.Status, StatusError) + } + if result.Error == nil { + t.Fatal("StoreKnowledge() error field is nil, want non-nil") + } + if result.Error.Code != tt.wantCode { + t.Errorf("StoreKnowledge() error code = %q, want %q", result.Error.Code, tt.wantCode) + } + if !strings.Contains(result.Error.Message, tt.wantInMsg) { + t.Errorf("StoreKnowledge() error message = %q, want to contain %q", result.Error.Message, tt.wantInMsg) + } + }) + } +} diff --git a/internal/tools/register_test.go b/internal/tools/register_test.go index abb2bb7..fddd5ac 100644 --- a/internal/tools/register_test.go +++ b/internal/tools/register_test.go @@ -22,7 +22,7 @@ func TestNewFile(t *testing.T) { t.Run("successful creation", func(t *testing.T) { t.Parallel() - pathVal, err := security.NewPath([]string{}) + pathVal, err := security.NewPath([]string{}, nil) if err != nil { t.Fatalf("NewPath() unexpected error: %v", err) } @@ -53,7 +53,7 @@ func TestNewFile(t *testing.T) { t.Run("nil logger", func(t *testing.T) { t.Parallel() - pathVal, err := security.NewPath([]string{}) + pathVal, err := security.NewPath([]string{}, nil) if err != nil { t.Fatalf("NewPath() unexpected error: %v", err) } @@ -77,7 +77,7 @@ func TestRegisterFile(t *testing.T) { t.Run("successful registration", func(t *testing.T) { t.Parallel() g := setupTestGenkit(t) - pathVal, err := security.NewPath([]string{}) + pathVal, err := security.NewPath([]string{}, nil) if err != nil { t.Fatalf("NewPath() unexpected error: %v", err) } @@ -105,7 +105,7 @@ func TestRegisterFile(t *testing.T) { t.Run("nil genkit", func(t *testing.T) { t.Parallel() - pathVal, err := security.NewPath([]string{}) + pathVal, err := security.NewPath([]string{}, nil) if err != nil { t.Fatalf("NewPath() unexpected error: %v", err) } diff --git a/internal/tui/setup_test.go b/internal/tui/setup_test.go index 9dbfb63..d054f57 100644 --- a/internal/tui/setup_test.go +++ b/internal/tui/setup_test.go @@ -144,7 +144,7 @@ func setupChatFlow(t *testing.T) (*chatFlowSetup, func()) { queries := sqlc.New(pool) sessionStore := session.New(queries, pool, logger) - pathValidator, err := security.NewPath([]string{"."}) + pathValidator, err := security.NewPath([]string{"."}, nil) if err != nil { pool.Close() cancel()