Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions internal/api/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ const sseTimeout = 5 * time.Minute
// titleMaxLength is the maximum rune length for a fallback session title.
const titleMaxLength = 50

// maxRequestBodySize is the maximum allowed HTTP request body size (1 MB).
const maxRequestBodySize = 1 << 20

// maxChatContentLength is the maximum allowed chat message length in bytes (~8K tokens).
const maxChatContentLength = 32_000

// Tool display info for JSON SSE events.
type toolDisplayInfo struct {
StartMsg string
Expand Down Expand Up @@ -69,11 +75,18 @@ type chatHandler struct {

// send handles POST /api/v1/chat — accepts JSON, sends message to chat flow.
func (h *chatHandler) send(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodySize)

var req struct {
Content string `json:"content"`
SessionID string `json:"sessionId"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
var maxBytesErr *http.MaxBytesError
if errors.As(err, &maxBytesErr) {
WriteError(w, http.StatusRequestEntityTooLarge, "body_too_large", "request body too large", h.logger)
return
}
WriteError(w, http.StatusBadRequest, "invalid_json", "invalid request body", h.logger)
return
}
Expand All @@ -84,6 +97,11 @@ func (h *chatHandler) send(w http.ResponseWriter, r *http.Request) {
return
}

if len(content) > maxChatContentLength {
WriteError(w, http.StatusRequestEntityTooLarge, "content_too_long", "message content exceeds maximum length", h.logger)
return
}

if req.SessionID == "" {
WriteError(w, http.StatusBadRequest, "session_required", "sessionId is required", h.logger)
return
Expand All @@ -96,6 +114,7 @@ func (h *chatHandler) send(w http.ResponseWriter, r *http.Request) {
}

if !h.sessionAccessAllowed(r, sessionID) {
h.logger.Warn("session access denied", "sessionId", req.SessionID, "path", r.URL.Path)
WriteError(w, http.StatusForbidden, "forbidden", "session access denied", h.logger)
return
}
Expand Down Expand Up @@ -149,13 +168,19 @@ func (h *chatHandler) stream(w http.ResponseWriter, r *http.Request) {
return
}

if len(query) > maxChatContentLength {
WriteError(w, http.StatusRequestEntityTooLarge, "content_too_long", "query exceeds maximum length", h.logger)
return
}

parsedID, err := uuid.Parse(sessionID)
if err != nil {
WriteError(w, http.StatusBadRequest, "invalid_session", "invalid session ID", h.logger)
return
}

if !h.sessionAccessAllowed(r, parsedID) {
h.logger.Warn("session access denied", "sessionId", sessionID, "msgId", msgID, "path", r.URL.Path)
WriteError(w, http.StatusForbidden, "forbidden", "session access denied", h.logger)
return
}
Expand Down
67 changes: 67 additions & 0 deletions internal/api/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,73 @@ func TestChatSend_InvalidJSON(t *testing.T) {
}
}

func TestChatSend_BodyTooLarge(t *testing.T) {
// Create a valid JSON body larger than maxRequestBodySize (1 MB).
// The content field must be large enough so the whole JSON exceeds the limit.
largeContent := strings.Repeat("x", maxRequestBodySize)
body, _ := json.Marshal(map[string]string{
"content": largeContent,
"sessionId": uuid.New().String(),
})

w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body))

newTestChatHandler().send(w, r)

if w.Code != http.StatusRequestEntityTooLarge {
t.Fatalf("send(>1MB body) status = %d, want %d\nbody: %s", w.Code, http.StatusRequestEntityTooLarge, w.Body.String())
}

errResp := decodeErrorEnvelope(t, w)
if errResp.Code != "body_too_large" {
t.Errorf("send(>1MB body) code = %q, want %q", errResp.Code, "body_too_large")
}
}
Comment on lines +199 to +221

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Security/Robustness:
The test TestChatSend_BodyTooLarge verifies the status code and error code, but does not assert that the response body is free from sensitive or unexpected information. To strengthen the test, add assertions to ensure the response body does not contain any partial or leaked data.

Recommended solution:

if strings.Contains(w.Body.String(), "hello") {
    t.Error("send(>1MB body) response should not contain partial content")
}


func TestChatSend_ContentTooLong(t *testing.T) {
// Create content that exceeds maxChatContentLength (32K)
longContent := strings.Repeat("x", maxChatContentLength+1)
body, _ := json.Marshal(map[string]string{
"content": longContent,
"sessionId": uuid.New().String(),
})

w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body))

newTestChatHandler().send(w, r)

if w.Code != http.StatusRequestEntityTooLarge {
t.Fatalf("send(>32K content) status = %d, want %d\nbody: %s", w.Code, http.StatusRequestEntityTooLarge, w.Body.String())
}

errResp := decodeErrorEnvelope(t, w)
if errResp.Code != "content_too_long" {
t.Errorf("send(>32K content) code = %q, want %q", errResp.Code, "content_too_long")
}
}

func TestStream_QueryTooLong(t *testing.T) {
ch := newTestChatHandler()
sessionID := uuid.New()
longQuery := strings.Repeat("x", maxChatContentLength+1)

w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query="+url.QueryEscape(longQuery), nil)

ch.stream(w, r)

if w.Code != http.StatusRequestEntityTooLarge {
t.Fatalf("stream(>32K query) status = %d, want %d\nbody: %s", w.Code, http.StatusRequestEntityTooLarge, w.Body.String())
}

errResp := decodeErrorEnvelope(t, w)
if errResp.Code != "content_too_long" {
t.Errorf("stream(>32K query) code = %q, want %q", errResp.Code, "content_too_long")
}
}

func TestChatSend_OwnershipDenied(t *testing.T) {
body, _ := json.Marshal(map[string]string{
"content": "hello",
Expand Down
6 changes: 4 additions & 2 deletions internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,15 @@ func NewServer(cfg ServerConfig) (*Server, error) {
}
rl := newRateLimiter(1.0, burst)

// Build middleware stack: Recovery → Logging → RateLimit → CORS → User → Session → CSRF → Routes
// Build middleware stack (outermost first):
// Recovery → Logging → CORS → RateLimit → User → Session → CSRF → Routes
// CORS must be before RateLimit so preflight OPTIONS gets proper CORS headers.
var handler http.Handler = mux
handler = csrfMiddleware(sm, logger)(handler)
handler = sessionMiddleware(sm)(handler)
handler = userMiddleware(sm)(handler)
handler = corsMiddleware(cfg.CORSOrigins)(handler)
handler = rateLimitMiddleware(rl, cfg.TrustProxy, logger)(handler)
handler = corsMiddleware(cfg.CORSOrigins)(handler)
handler = loggingMiddleware(logger)(handler)
Comment on lines +89 to 90

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Middleware Order: CORS and Logging

The CORS middleware is applied before the logging middleware. This means that CORS preflight requests (OPTIONS) may not be logged, potentially omitting important request traces for debugging or auditing. Consider placing the logging middleware as the outermost layer (before CORS) to ensure all requests, including preflight, are logged:

handler = loggingMiddleware(logger)(handler)
handler = corsMiddleware(cfg.CORSOrigins)(handler)

This change will improve observability without affecting CORS behavior.

handler = recoveryMiddleware(logger)(handler)

Expand Down
18 changes: 14 additions & 4 deletions internal/api/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,14 @@ func (sm *sessionManager) CheckCSRF(userID, token string) error {
message := fmt.Sprintf("%s:%d", userID, timestamp)
h := hmac.New(sha256.New, sm.hmacSecret)
h.Write([]byte(message))
expectedSig := base64.URLEncoding.EncodeToString(h.Sum(nil))
expectedSig := h.Sum(nil)

if subtle.ConstantTimeCompare([]byte(parts[1]), []byte(expectedSig)) != 1 {
actualSig, err := base64.URLEncoding.DecodeString(parts[1])
if err != nil {
return ErrCSRFMalformed
}

if subtle.ConstantTimeCompare(actualSig, expectedSig) != 1 {
return ErrCSRFInvalid
Comment on lines +130 to 131

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential ambiguity in signature length handling:

The comparison subtle.ConstantTimeCompare(actualSig, expectedSig) will return 0 if the lengths differ, but this is implicit. For clarity and maintainability, consider explicitly checking that len(actualSig) == len(expectedSig) before performing the comparison. This makes the logic clearer and avoids confusion for future maintainers.

Recommended solution:

if len(actualSig) != len(expectedSig) || subtle.ConstantTimeCompare(actualSig, expectedSig) != 1 {
    return ErrCSRFInvalid
}

}

Expand Down Expand Up @@ -176,9 +181,14 @@ func (sm *sessionManager) CheckPreSessionCSRF(token string) error {
message := fmt.Sprintf("%s:%d", nonce, timestamp)
h := hmac.New(sha256.New, sm.hmacSecret)
h.Write([]byte(message))
expectedSig := base64.URLEncoding.EncodeToString(h.Sum(nil))
expectedSig := h.Sum(nil)

actualSig, err := base64.URLEncoding.DecodeString(parts[2])
if err != nil {
return ErrCSRFMalformed
}

if subtle.ConstantTimeCompare([]byte(parts[2]), []byte(expectedSig)) != 1 {
if subtle.ConstantTimeCompare(actualSig, expectedSig) != 1 {
return ErrCSRFInvalid
Comment on lines +191 to 192

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential ambiguity in signature length handling:

As in the user-bound CSRF check, subtle.ConstantTimeCompare(actualSig, expectedSig) will return 0 if the lengths differ, but this is implicit. For clarity and maintainability, explicitly check len(actualSig) == len(expectedSig) before comparison.

Recommended solution:

if len(actualSig) != len(expectedSig) || subtle.ConstantTimeCompare(actualSig, expectedSig) != 1 {
    return ErrCSRFInvalid
}

}

Expand Down
2 changes: 1 addition & 1 deletion internal/app/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func TestApp_Fields(t *testing.T) {
defer cancel()

g := genkit.Init(ctx)
pathValidator, err := security.NewPath([]string{"."})
pathValidator, err := security.NewPath([]string{"."}, nil)
if err != nil {
t.Fatalf("security.NewPath() error: %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion internal/app/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,9 @@ func provideSessionStore(pool *pgxpool.Pool) *session.Store {
}

// providePathValidator creates a path validator instance.
// Denies access to prompts/ to protect system prompt files from tool-based access.
func providePathValidator() (*security.Path, error) {
return security.NewPath([]string{"."})
return security.NewPath([]string{"."}, []string{"prompts"})
}

// provideTools creates toolsets, registers them with Genkit, and stores both
Expand Down
2 changes: 1 addition & 1 deletion internal/chat/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func SetupTest(t *testing.T) *TestFramework {
}

// Create toolsets
pathValidator, err := security.NewPath([]string{os.TempDir()})
pathValidator, err := security.NewPath([]string{os.TempDir()}, nil)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test Flexibility Limitation:
Using only os.TempDir() as the allowed path for the path validator may restrict test scenarios and lead to brittle tests if the environment changes. If tests require access to other directories, this configuration will cause failures. Consider parameterizing the allowed paths or making them configurable for broader test coverage and reliability.

Recommended Solution:
Allow the test to specify additional allowed paths, or use a configuration variable:

allowedPaths := []string{os.TempDir(), "/tmp", "/var/tmp"} // Example
pathValidator, err := security.NewPath(allowedPaths, nil)

if err != nil {
t.Fatalf("creating path validator: %v", err)
}
Comment on lines +115 to 118

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error Handling and Test Robustness:
The error handling for security.NewPath is abrupt and does not provide actionable context if the path validator fails. If the failure is due to environmental issues (e.g., os.TempDir() is not writable or accessible), the test will simply fail without guidance. Consider enhancing the error message to include the value of os.TempDir() and possible remediation steps, such as checking directory permissions or environment configuration.

Recommended Solution:

if err != nil {
    t.Fatalf("creating path validator for temp dir '%s': %v. Ensure the directory is writable and accessible.", os.TempDir(), err)
}

Expand Down
2 changes: 1 addition & 1 deletion internal/config/observability.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
// Setup is inlined in internal/app/setup.go (provideOtelShutdown).
type DatadogConfig struct {
// APIKey is the Datadog API key (optional, for observability)
APIKey string `mapstructure:"api_key" json:"api_key"`
APIKey string `mapstructure:"api_key" json:"api_key"` // #nosec G117 -- masked in MarshalJSON, never serialized in plain text
// AgentHost is the Datadog Agent OTLP endpoint (default: localhost:4318)
AgentHost string `mapstructure:"agent_host" json:"agent_host"`
// Environment is the deployment environment tag (default: dev)
Expand Down
4 changes: 2 additions & 2 deletions internal/mcp/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
func BenchmarkServer_Creation(b *testing.B) {
// Setup toolsets once — benchmark only NewServer + tool registration.
tmpDir := b.TempDir()
pathVal, err := security.NewPath([]string{tmpDir})
pathVal, err := security.NewPath([]string{tmpDir}, nil)
if err != nil {
b.Fatalf("creating path validator: %v", err)
}
Expand Down Expand Up @@ -167,7 +167,7 @@ func BenchmarkReadFileInput_Parse(b *testing.B) {
// BenchmarkConfig_Validation benchmarks Config validation.
func BenchmarkConfig_Validation(b *testing.B) {
tmpDir := b.TempDir()
pathVal, err := security.NewPath([]string{tmpDir})
pathVal, err := security.NewPath([]string{tmpDir}, nil)
if err != nil {
b.Fatalf("creating path validator: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/mcp/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func createIntegrationTestConfig(t *testing.T, name string) Config {
t.Fatalf("EvalSymlinks(%q) unexpected error: %v", tmpDir, err)
}

pathVal, err := security.NewPath([]string{realTmpDir})
pathVal, err := security.NewPath([]string{realTmpDir}, nil)
if err != nil {
t.Fatalf("security.NewPath(%q) unexpected error: %v", realTmpDir, err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func newTestHelper(t *testing.T) *testHelper {

func (h *testHelper) createFile() *tools.File {
h.t.Helper()
pathVal, err := security.NewPath([]string{h.tempDir})
pathVal, err := security.NewPath([]string{h.tempDir}, nil)
if err != nil {
h.t.Fatalf("creating path validator: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/security/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// Path Validator: Prevents directory traversal and ensures file operations
// stay within allowed boundaries.
//
// pathValidator, err := security.NewPath([]string{"/safe/dir"})
// pathValidator, err := security.NewPath([]string{"/safe/dir"}, nil)
// if _, err := pathValidator.Validate(userInput); err != nil {
// return fmt.Errorf("invalid path: %w", err)
// }
Expand Down Expand Up @@ -67,7 +67,7 @@
// # Integration Example
//
// // Create validators
// pathVal, _ := security.NewPath([]string{workDir})
// pathVal, _ := security.NewPath([]string{workDir}, nil)
// cmdVal := security.NewCommand()
// urlVal := security.NewURL()
// envVal := security.NewEnv()
Comment on lines +70 to 73

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error Handling Risk:

The integration example discards errors returned by validator constructors (e.g., pathVal, _ := security.NewPath(...)). This can result in silent failures if a validator fails to initialize, potentially leaving the system unprotected. Always check and handle errors explicitly to ensure that security validators are correctly instantiated:

pathVal, err := security.NewPath([]string{workDir}, nil)
if err != nil {
    // handle error (e.g., log and abort initialization)
}

This approach ensures that initialization failures are detected and handled, maintaining the integrity of the security layer.

Expand Down
4 changes: 2 additions & 2 deletions internal/security/fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func FuzzPathValidation(f *testing.F) {

// Create a validator with a safe temporary directory
tmpDir := f.TempDir()
validator, err := NewPath([]string{tmpDir})
validator, err := NewPath([]string{tmpDir}, nil)
if err != nil {
f.Fatalf("creating validator: %v", err)
}
Expand Down Expand Up @@ -130,7 +130,7 @@ func FuzzPathValidationWithSymlinks(f *testing.F) {
}

tmpDir := t.TempDir()
validator, err := NewPath([]string{tmpDir})
validator, err := NewPath([]string{tmpDir}, nil)
if err != nil {
t.Skipf("creating validator: %v", err)
}
Expand Down
50 changes: 44 additions & 6 deletions internal/security/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,25 @@ var (

// ErrPathNullByte indicates the path contains a null byte (CWE-626).
ErrPathNullByte = errors.New("path contains null byte")

// ErrPathDenied indicates the path matches a denied prefix (e.g. prompts/).
ErrPathDenied = errors.New("path is denied")
)

// Path validates and sanitizes file paths to prevent traversal attacks.
// Used to prevent path traversal attacks (CWE-22).
type Path struct {
allowedDirs []string
workDir string
allowedDirs []string
deniedPrefixes []string // absolute paths that are always denied (case-insensitive on macOS HFS+)
workDir string
}

// NewPath creates a new Path validator.
// allowedDirs: list of allowed directories (empty list means only working directory is allowed)
func NewPath(allowedDirs []string) (*Path, error) {
// allowedDirs: list of allowed directories (empty list means only working directory is allowed).
// deniedPrefixes: list of directory prefixes that are always denied even if inside allowed dirs
// (e.g. "prompts/" to protect system prompt files). Compared case-insensitively on
// case-insensitive filesystems (macOS HFS+).
func NewPath(allowedDirs, deniedPrefixes []string) (*Path, error) {
workDir, err := os.Getwd()
if err != nil {
return nil, fmt.Errorf("unable to get working directory: %w", err)
Expand All @@ -47,12 +54,35 @@ func NewPath(allowedDirs []string) (*Path, error) {
absAllowedDirs = append(absAllowedDirs, absDir)
}

// Convert denied prefixes to absolute paths
absDenied := make([]string, 0, len(deniedPrefixes))
for _, prefix := range deniedPrefixes {
absPrefix, err := filepath.Abs(prefix)
if err != nil {
return nil, fmt.Errorf("unable to resolve denied prefix %s: %w", prefix, err)
}
absDenied = append(absDenied, absPrefix)
}

return &Path{
allowedDirs: absAllowedDirs,
workDir: workDir,
allowedDirs: absAllowedDirs,
deniedPrefixes: absDenied,
workDir: workDir,
}, nil
}

// isPathDenied checks if a path matches any denied prefix.
// Uses case-insensitive comparison to handle case-insensitive filesystems (macOS HFS+).
func (v *Path) isPathDenied(absPath string) bool {
for _, denied := range v.deniedPrefixes {
deniedWithSep := filepath.Clean(denied) + string(filepath.Separator)
if strings.EqualFold(absPath, denied) || strings.HasPrefix(strings.ToLower(absPath+string(filepath.Separator)), strings.ToLower(deniedWithSep)) {
return true
}
}
return false
}

// isPathInAllowedDirs checks if a path is within allowed directories
// Returns true if path is in working directory or any allowed directory
func (v *Path) isPathInAllowedDirs(absPath string) bool {
Comment on lines 54 to 88

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential directory containment flaw in isPathInAllowedDirs

The use of strings.HasPrefix to check if a path is within an allowed directory can be unsafe if directory names are prefixes of others (e.g., /tmp/foo and /tmp/foobar). This could allow unintended access:

if strings.HasPrefix(absPathWithSep, dirNorm) || absPath == dir {
    return true
}

Recommendation: Use filepath.Rel(dir, absPath) and check that the result does not start with .. or is not equal to ... This provides a more robust containment check and prevents directory traversal attacks.

Expand Down Expand Up @@ -126,6 +156,14 @@ func (v *Path) Validate(path string) (string, error) {
return "", fmt.Errorf("%w: access denied", ErrPathOutsideAllowed)
}

// 3b. Check if path matches a denied prefix (e.g. prompts/)
if v.isPathDenied(absPath) {
slog.Warn("path denied by prefix rule",
"path", absPath,
"security_event", "denied_prefix_access_attempt")
return "", fmt.Errorf("%w: access denied", ErrPathDenied)
}

// 4. Resolve symbolic links (prevent bypassing restrictions through symlinks)
realPath, err := filepath.EvalSymlinks(absPath)
if err != nil {
Comment on lines 156 to 169

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Denied prefix check missing after symlink resolution

After resolving symlinks with filepath.EvalSymlinks, the code checks if the resolved path is within allowed directories, but does not re-check denied prefixes. This could allow a symlink to bypass denied prefix restrictions:

if realPath != absPath {
    if !v.isPathInAllowedDirs(realPath) {
        // ...
        return "", fmt.Errorf("%w: access denied", ErrSymlinkOutsideAllowed)
    }
    absPath = realPath
}

Recommendation: After symlink resolution, also check isPathDenied(realPath) and return an error if the resolved path matches a denied prefix.

Expand Down
Loading
Loading