diff --git a/cmd/serve.go b/cmd/serve.go index 2e19997..8337077 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -6,7 +6,9 @@ import ( "fmt" "log/slog" "net/http" + "os" "os/signal" + "strconv" "syscall" "time" @@ -16,6 +18,20 @@ import ( "github.com/koopa0/koopa/internal/config" ) +// parseRateBurst reads KOOPA_RATE_BURST from the environment. +// Returns 0 (use default) if unset or invalid. +func parseRateBurst() int { + v := os.Getenv("KOOPA_RATE_BURST") + if v == "" { + return 0 + } + n, err := strconv.Atoi(v) + if err != nil || n < 0 { + return 0 + } + return n +} + // Server timeout configuration. const ( readHeaderTimeout = 10 * time.Second @@ -72,6 +88,7 @@ func runServe() error { CORSOrigins: cfg.CORSOrigins, IsDev: cfg.PostgresSSLMode == "disable", TrustProxy: cfg.TrustProxy, + RateBurst: parseRateBurst(), }) if err != nil { return fmt.Errorf("creating API server: %w", err) diff --git a/db/migrations/000002_add_owner_id.down.sql b/db/migrations/000002_add_owner_id.down.sql new file mode 100644 index 0000000..b7e331e --- /dev/null +++ b/db/migrations/000002_add_owner_id.down.sql @@ -0,0 +1,2 @@ +DROP INDEX IF EXISTS idx_sessions_owner_id; +ALTER TABLE sessions DROP COLUMN IF EXISTS owner_id; diff --git a/db/migrations/000002_add_owner_id.up.sql b/db/migrations/000002_add_owner_id.up.sql new file mode 100644 index 0000000..6f44b30 --- /dev/null +++ b/db/migrations/000002_add_owner_id.up.sql @@ -0,0 +1,6 @@ +-- Add owner_id to sessions for multi-session support. +-- Each session is owned by a user identified by a persistent uid cookie. +-- Existing sessions get empty owner_id (orphaned — invisible to new users). +ALTER TABLE sessions ADD COLUMN owner_id TEXT NOT NULL DEFAULT ''; + +CREATE INDEX idx_sessions_owner_id ON sessions(owner_id, updated_at DESC); diff --git a/db/queries/sessions.sql b/db/queries/sessions.sql index 849204b..722f955 100644 --- a/db/queries/sessions.sql +++ b/db/queries/sessions.sql @@ -2,22 +2,30 @@ -- Generated code will be in internal/sqlc/sessions.sql.go -- name: CreateSession :one -INSERT INTO sessions (title) -VALUES ($1) +INSERT INTO sessions (title, owner_id) +VALUES ($1, sqlc.arg(owner_id)) RETURNING *; -- name: Session :one -SELECT id, title, created_at, updated_at +SELECT id, title, owner_id, created_at, updated_at FROM sessions WHERE id = $1; -- name: Sessions :many -SELECT id, title, created_at, updated_at +SELECT id, title, owner_id, created_at, updated_at FROM sessions +WHERE owner_id = sqlc.arg(owner_id) ORDER BY updated_at DESC LIMIT sqlc.arg(result_limit) OFFSET sqlc.arg(result_offset); +-- name: SessionByIDAndOwner :one +-- Verify session exists and is owned by the given user. +-- Used for ownership checks without a separate query + comparison. +SELECT id, title, owner_id, created_at, updated_at +FROM sessions +WHERE id = sqlc.arg(session_id) AND owner_id = sqlc.arg(owner_id); + -- name: UpdateSessionUpdatedAt :exec UPDATE sessions SET updated_at = NOW() diff --git a/internal/api/chat.go b/internal/api/chat.go index 1035986..0e5a805 100644 --- a/internal/api/chat.go +++ b/internal/api/chat.go @@ -84,24 +84,20 @@ func (h *chatHandler) send(w http.ResponseWriter, r *http.Request) { return } - // Resolve session from context (set by session middleware from cookie) - sessionID, ok := sessionIDFromContext(r.Context()) - if !ok { - WriteError(w, http.StatusBadRequest, "session_required", "session ID required", h.logger) + if req.SessionID == "" { + WriteError(w, http.StatusBadRequest, "session_required", "sessionId is required", h.logger) return } - // If body also specifies a session, verify it matches (defense-in-depth) - if req.SessionID != "" { - parsed, err := uuid.Parse(req.SessionID) - if err != nil { - WriteError(w, http.StatusBadRequest, "invalid_session", "invalid session ID", h.logger) - return - } - if parsed != sessionID { - WriteError(w, http.StatusForbidden, "forbidden", "session access denied", h.logger) - return - } + sessionID, err := uuid.Parse(req.SessionID) + if err != nil { + WriteError(w, http.StatusBadRequest, "invalid_session", "invalid session ID", h.logger) + return + } + + if !h.sessionAccessAllowed(r, sessionID) { + WriteError(w, http.StatusForbidden, "forbidden", "session access denied", h.logger) + return } msgID := uuid.New().String() @@ -118,6 +114,30 @@ func (h *chatHandler) send(w http.ResponseWriter, r *http.Request) { }, h.logger) } +// sessionAccessAllowed checks whether the request may access the session. +// Returns true when no session manager is configured (unit tests, CLI mode). +// When configured, verifies the session belongs to the authenticated user. +func (h *chatHandler) sessionAccessAllowed(r *http.Request, sessionID uuid.UUID) bool { + if h.sessions == nil { + return true // no session manager → allow (test/CLI mode) + } + if h.sessions.store == nil { + return false // configured but no store → deny + } + + userID, ok := userIDFromContext(r.Context()) + if !ok || userID == "" { + return false + } + + sess, err := h.sessions.store.Session(r.Context(), sessionID) + if err != nil { + return false + } + + return sess.OwnerID == userID +} + // stream handles GET /api/v1/chat/stream — SSE endpoint with JSON events. func (h *chatHandler) stream(w http.ResponseWriter, r *http.Request) { msgID := r.URL.Query().Get("msgId") @@ -129,14 +149,13 @@ func (h *chatHandler) stream(w http.ResponseWriter, r *http.Request) { return } - // Verify session ownership parsedID, err := uuid.Parse(sessionID) if err != nil { WriteError(w, http.StatusBadRequest, "invalid_session", "invalid session ID", h.logger) return } - ctxID, ok := sessionIDFromContext(r.Context()) - if !ok || ctxID != parsedID { + + if !h.sessionAccessAllowed(r, parsedID) { WriteError(w, http.StatusForbidden, "forbidden", "session access denied", h.logger) return } diff --git a/internal/api/chat_test.go b/internal/api/chat_test.go index e08af8a..667810c 100644 --- a/internal/api/chat_test.go +++ b/internal/api/chat_test.go @@ -21,6 +21,16 @@ import ( func newTestChatHandler() *chatHandler { return &chatHandler{ logger: slog.New(slog.DiscardHandler), + // sessions is nil — ownership verification is skipped for unit tests + } +} + +// newTestChatHandlerWithSessions creates a chat handler with a session manager +// but no store, causing sessionAccessAllowed to always return false (ownership denied). +func newTestChatHandlerWithSessions() *chatHandler { + return &chatHandler{ + logger: slog.New(slog.DiscardHandler), + sessions: &sessionManager{logger: slog.New(slog.DiscardHandler)}, } } @@ -35,13 +45,11 @@ func TestChatSend_URLEncoding(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) - ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) - r = r.WithContext(ctx) newTestChatHandler().send(w, r) if w.Code != http.StatusOK { - t.Fatalf("send() status = %d, want %d", w.Code, http.StatusOK) + t.Fatalf("send() status = %d, want %d\nbody: %s", w.Code, http.StatusOK, w.Body.String()) } var resp map[string]string @@ -78,13 +86,11 @@ func TestChatSend_SessionIDFromBody(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) - ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) - r = r.WithContext(ctx) newTestChatHandler().send(w, r) if w.Code != http.StatusOK { - t.Fatalf("send() status = %d, want %d", w.Code, http.StatusOK) + t.Fatalf("send() status = %d, want %d\nbody: %s", w.Code, http.StatusOK, w.Body.String()) } var resp map[string]string @@ -99,8 +105,7 @@ func TestChatSend_SessionIDFromBody(t *testing.T) { } } -func TestChatSend_SessionIDFromContext(t *testing.T) { - sessionID := uuid.New() +func TestChatSend_MissingSessionID(t *testing.T) { body, _ := json.Marshal(map[string]string{ "content": "hello", // No sessionId in body @@ -109,21 +114,15 @@ func TestChatSend_SessionIDFromContext(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) - // Inject session ID via context (as sessionMiddleware would) - ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) - r = r.WithContext(ctx) - newTestChatHandler().send(w, r) - if w.Code != http.StatusOK { - t.Fatalf("send() status = %d, want %d", w.Code, http.StatusOK) + if w.Code != http.StatusBadRequest { + t.Fatalf("send(no session) status = %d, want %d", w.Code, http.StatusBadRequest) } - var resp map[string]string - decodeData(t, w, &resp) - - if resp["sessionId"] != sessionID.String() { - t.Errorf("send() sessionId = %s, want %s (from context)", resp["sessionId"], sessionID) + errResp := decodeErrorEnvelope(t, w) + if errResp.Code != "session_required" { + t.Errorf("send(no session) code = %q, want %q", errResp.Code, "session_required") } } @@ -172,9 +171,6 @@ func TestChatSend_InvalidSessionID(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) - // Inject context session so we reach the body parse check - ctx := context.WithValue(r.Context(), ctxKeySessionID, uuid.New()) - r = r.WithContext(ctx) newTestChatHandler().send(w, r) @@ -189,36 +185,36 @@ func TestChatSend_InvalidSessionID(t *testing.T) { } } -func TestChatSend_NoSession(t *testing.T) { - body, _ := json.Marshal(map[string]string{ - "content": "hello", - // No sessionId in body, no session in context - }) - +func TestChatSend_InvalidJSON(t *testing.T) { w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) + r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader([]byte("not json"))) newTestChatHandler().send(w, r) if w.Code != http.StatusBadRequest { - t.Fatalf("send(no session) status = %d, want %d", w.Code, http.StatusBadRequest) - } - - errResp := decodeErrorEnvelope(t, w) - - if errResp.Code != "session_required" { - t.Errorf("send(no session) code = %q, want %q", errResp.Code, "session_required") + t.Fatalf("send(invalid json) status = %d, want %d", w.Code, http.StatusBadRequest) } } -func TestChatSend_InvalidJSON(t *testing.T) { +func TestChatSend_OwnershipDenied(t *testing.T) { + body, _ := json.Marshal(map[string]string{ + "content": "hello", + "sessionId": uuid.New().String(), + }) + w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader([]byte("not json"))) + r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) - newTestChatHandler().send(w, r) + // Use handler with sessions configured but no store — ownership always fails + newTestChatHandlerWithSessions().send(w, r) - if w.Code != http.StatusBadRequest { - t.Fatalf("send(invalid json) status = %d, want %d", w.Code, http.StatusBadRequest) + if w.Code != http.StatusForbidden { + t.Fatalf("send(ownership denied) status = %d, want %d", w.Code, http.StatusForbidden) + } + + errResp := decodeErrorEnvelope(t, w) + if errResp.Code != "forbidden" { + t.Errorf("send(ownership denied) code = %q, want %q", errResp.Code, "forbidden") } } @@ -312,8 +308,6 @@ func TestStream_SSEHeaders(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hi", nil) - ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) - r = r.WithContext(ctx) ch.stream(w, r) @@ -337,8 +331,6 @@ func TestStream_NilFlow(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hello", nil) - ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) - r = r.WithContext(ctx) ch.stream(w, r) @@ -414,7 +406,6 @@ func TestStream_NilFlow_ContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately - ctx = context.WithValue(ctx, ctxKeySessionID, sessionID) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hi", nil) @@ -458,58 +449,34 @@ func TestClassifyError(t *testing.T) { } } -func TestChatSend_SessionMismatch(t *testing.T) { - body, _ := json.Marshal(map[string]string{ - "content": "hello", - "sessionId": uuid.New().String(), // Different from context - }) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodPost, "/api/v1/chat", bytes.NewReader(body)) - ctx := context.WithValue(r.Context(), ctxKeySessionID, uuid.New()) - r = r.WithContext(ctx) - - newTestChatHandler().send(w, r) - - if w.Code != http.StatusForbidden { - t.Fatalf("send(mismatched session) status = %d, want %d", w.Code, http.StatusForbidden) - } - - errResp := decodeErrorEnvelope(t, w) - if errResp.Code != "forbidden" { - t.Errorf("send(mismatched session) code = %q, want %q", errResp.Code, "forbidden") - } -} - func TestStream_OwnershipDenied(t *testing.T) { - ch := newTestChatHandler() + // Handler with sessions configured but no store → ownership always fails + ch := newTestChatHandlerWithSessions() sessionID := uuid.New() - otherID := uuid.New() w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hi", nil) - ctx := context.WithValue(r.Context(), ctxKeySessionID, otherID) - r = r.WithContext(ctx) ch.stream(w, r) if w.Code != http.StatusForbidden { - t.Fatalf("stream(wrong session) status = %d, want %d", w.Code, http.StatusForbidden) + t.Fatalf("stream(ownership denied) status = %d, want %d", w.Code, http.StatusForbidden) } } -func TestStream_NoSession(t *testing.T) { - ch := newTestChatHandler() +func TestStream_NoUser(t *testing.T) { + // Handler with sessions configured but no user in context + ch := newTestChatHandlerWithSessions() sessionID := uuid.New() w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionID.String()+"&query=hi", nil) - // No session in context + // No user in context ch.stream(w, r) if w.Code != http.StatusForbidden { - t.Fatalf("stream(no session) status = %d, want %d", w.Code, http.StatusForbidden) + t.Fatalf("stream(no user) status = %d, want %d", w.Code, http.StatusForbidden) } } @@ -559,6 +526,164 @@ func filterSSEEvents(events []sseTestEvent, eventType string) []sseTestEvent { return filtered } +// TestJSONToolEmitter verifies that jsonToolEmitter emits correct SSE events +// for tool start, complete, and error — for both known and unknown tools. +func TestJSONToolEmitter(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string // "start", "complete", "error" + toolName string + wantEvent string + wantMsg string + }{ + { + name: "start known tool", + method: "start", + toolName: "web_search", + wantEvent: "tool_start", + wantMsg: toolDisplay["web_search"].StartMsg, + }, + { + name: "start unknown tool", + method: "start", + toolName: "custom_tool", + wantEvent: "tool_start", + wantMsg: defaultToolDisplay.StartMsg, + }, + { + name: "complete known tool", + method: "complete", + toolName: "read_file", + wantEvent: "tool_complete", + wantMsg: toolDisplay["read_file"].CompleteMsg, + }, + { + name: "complete unknown tool", + method: "complete", + toolName: "custom_tool", + wantEvent: "tool_complete", + wantMsg: defaultToolDisplay.CompleteMsg, + }, + { + name: "error known tool", + method: "error", + toolName: "web_fetch", + wantEvent: "tool_error", + wantMsg: toolDisplay["web_fetch"].ErrorMsg, + }, + { + name: "error unknown tool", + method: "error", + toolName: "custom_tool", + wantEvent: "tool_error", + wantMsg: defaultToolDisplay.ErrorMsg, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + emitter := &jsonToolEmitter{w: w, msgID: "test-msg"} + + switch tt.method { + case "start": + emitter.OnToolStart(tt.toolName) + case "complete": + emitter.OnToolComplete(tt.toolName) + case "error": + emitter.OnToolError(tt.toolName) + } + + events := parseSSEEvents(t, w.Body.String()) + if len(events) != 1 { + t.Fatalf("jsonToolEmitter.%s(%q) emitted %d events, want 1", tt.method, tt.toolName, len(events)) + } + + ev := events[0] + if ev.Type != tt.wantEvent { + t.Errorf("jsonToolEmitter.%s(%q) event type = %q, want %q", tt.method, tt.toolName, ev.Type, tt.wantEvent) + } + if ev.Data["msgId"] != "test-msg" { + t.Errorf("jsonToolEmitter.%s(%q) msgId = %q, want %q", tt.method, tt.toolName, ev.Data["msgId"], "test-msg") + } + if ev.Data["tool"] != tt.toolName { + t.Errorf("jsonToolEmitter.%s(%q) tool = %q, want %q", tt.method, tt.toolName, ev.Data["tool"], tt.toolName) + } + if ev.Data["message"] != tt.wantMsg { + t.Errorf("jsonToolEmitter.%s(%q) message = %q, want %q", tt.method, tt.toolName, ev.Data["message"], tt.wantMsg) + } + }) + } +} + +// TestCreateSession_MissingUser verifies that createSession returns 400 +// when no user identity is present in the request context. +func TestCreateSession_MissingUser(t *testing.T) { + t.Parallel() + + sm := newTestSessionManager() + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil) + // No user in context + + sm.createSession(w, r) + + if w.Code != http.StatusBadRequest { + t.Fatalf("createSession(no user) status = %d, want %d\nbody: %s", w.Code, http.StatusBadRequest, w.Body.String()) + } + + errResp := decodeErrorEnvelope(t, w) + if errResp.Code != "user_required" { + t.Errorf("createSession(no user) code = %q, want %q", errResp.Code, "user_required") + } +} + +// TestMaybeGenerateTitle_NilPaths verifies early-return paths in maybeGenerateTitle +// when sessions or store are nil, or when the session ID is invalid. +func TestMaybeGenerateTitle_NilPaths(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + handler *chatHandler + sessionID string + }{ + { + name: "nil sessions", + handler: &chatHandler{logger: slog.New(slog.DiscardHandler)}, + sessionID: uuid.New().String(), + }, + { + name: "nil store", + handler: &chatHandler{ + logger: slog.New(slog.DiscardHandler), + sessions: &sessionManager{logger: slog.New(slog.DiscardHandler)}, + }, + sessionID: uuid.New().String(), + }, + { + name: "invalid UUID", + handler: &chatHandler{logger: slog.New(slog.DiscardHandler)}, + sessionID: "not-a-uuid", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + title := tt.handler.maybeGenerateTitle(context.Background(), tt.sessionID, "test message") + if title != "" { + t.Errorf("maybeGenerateTitle(%s) = %q, want empty string", tt.name, title) + } + }) + } +} + func TestStreamWithFlow(t *testing.T) { sessionID := uuid.New() sessionIDStr := sessionID.String() @@ -620,13 +745,12 @@ func TestStreamWithFlow(t *testing.T) { ch := &chatHandler{ logger: slog.New(slog.DiscardHandler), flow: testFlow, + // sessions is nil — ownership skipped for unit tests } w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/api/v1/chat/stream?msgId=m1&session_id="+sessionIDStr+"&query=test", nil) - rctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) - r = r.WithContext(rctx) ch.stream(w, r) diff --git a/internal/api/integration_test.go b/internal/api/integration_test.go index 94dedc1..d785c1f 100644 --- a/internal/api/integration_test.go +++ b/internal/api/integration_test.go @@ -18,6 +18,8 @@ import ( "github.com/koopa0/koopa/internal/testutil" ) +const testOwnerID = "test-user" + // setupIntegrationSessionManager creates a sessionManager backed by a real PostgreSQL database. func setupIntegrationSessionManager(t *testing.T) *sessionManager { t.Helper() @@ -40,6 +42,10 @@ func TestCreateSession_Success(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil) + // Inject user identity into context (normally done by userMiddleware) + ctx := context.WithValue(r.Context(), ctxKeyUserID, testOwnerID) + r = r.WithContext(ctx) + sm.createSession(w, r) if w.Code != http.StatusCreated { @@ -54,7 +60,7 @@ func TestCreateSession_Success(t *testing.T) { t.Errorf("createSession() id = %q, want valid UUID", resp["id"]) } - // Should return a CSRF token bound to the new session + // Should return a CSRF token bound to the user if resp["csrfToken"] == "" { t.Error("createSession() expected csrfToken in response") } @@ -83,7 +89,7 @@ func TestGetSession_Success(t *testing.T) { ctx := context.Background() // Create a session first - sess, err := sm.store.CreateSession(ctx, "Test Session") + sess, err := sm.store.CreateSession(ctx, testOwnerID, "Test Session") if err != nil { t.Fatalf("setup: CreateSession() error: %v", err) } @@ -92,8 +98,8 @@ func TestGetSession_Success(t *testing.T) { r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sess.ID.String(), nil) r.SetPathValue("id", sess.ID.String()) - // Inject session ownership (same session ID in context) - rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID) + // Inject user identity for ownership check + rctx := context.WithValue(r.Context(), ctxKeyUserID, testOwnerID) r = r.WithContext(rctx) sm.getSession(w, r) @@ -128,8 +134,8 @@ func TestGetSession_NotFound(t *testing.T) { r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+missingID.String(), nil) r.SetPathValue("id", missingID.String()) - // Set ownership to match (bypasses ownership check, tests store-level 404) - rctx := context.WithValue(r.Context(), ctxKeySessionID, missingID) + // Set user identity (session doesn't exist, so ownership check returns 404) + rctx := context.WithValue(r.Context(), ctxKeyUserID, testOwnerID) r = r.WithContext(rctx) sm.getSession(w, r) @@ -149,7 +155,7 @@ func TestListSessions_WithSession(t *testing.T) { ctx := context.Background() // Create a session - sess, err := sm.store.CreateSession(ctx, "My Chat") + sess, err := sm.store.CreateSession(ctx, testOwnerID, "My Chat") if err != nil { t.Fatalf("setup: CreateSession() error: %v", err) } @@ -157,8 +163,8 @@ func TestListSessions_WithSession(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions", nil) - // Inject session ownership - rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID) + // Inject user identity for listing + rctx := context.WithValue(r.Context(), ctxKeyUserID, testOwnerID) r = r.WithContext(rctx) sm.listSessions(w, r) @@ -194,7 +200,7 @@ func TestGetSessionMessages_Empty(t *testing.T) { ctx := context.Background() // Create a session with no messages - sess, err := sm.store.CreateSession(ctx, "") + sess, err := sm.store.CreateSession(ctx, testOwnerID, "") if err != nil { t.Fatalf("setup: CreateSession() error: %v", err) } @@ -203,7 +209,7 @@ func TestGetSessionMessages_Empty(t *testing.T) { r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sess.ID.String()+"/messages", nil) r.SetPathValue("id", sess.ID.String()) - rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID) + rctx := context.WithValue(r.Context(), ctxKeyUserID, testOwnerID) r = r.WithContext(rctx) sm.getSessionMessages(w, r) @@ -235,7 +241,7 @@ func TestDeleteSession_Success(t *testing.T) { ctx := context.Background() // Create a session - sess, err := sm.store.CreateSession(ctx, "To Delete") + sess, err := sm.store.CreateSession(ctx, testOwnerID, "To Delete") if err != nil { t.Fatalf("setup: CreateSession() error: %v", err) } @@ -244,8 +250,8 @@ func TestDeleteSession_Success(t *testing.T) { r := httptest.NewRequest(http.MethodDelete, "/api/v1/sessions/"+sess.ID.String(), nil) r.SetPathValue("id", sess.ID.String()) - // Inject ownership - rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID) + // Inject user identity for ownership check + rctx := context.WithValue(r.Context(), ctxKeyUserID, testOwnerID) r = r.WithContext(rctx) sm.deleteSession(w, r) @@ -270,25 +276,20 @@ func TestDeleteSession_Success(t *testing.T) { func TestCSRFTokenEndpoint_WithSession(t *testing.T) { sm := setupIntegrationSessionManager(t) - ctx := context.Background() - // Create a real session - sess, err := sm.store.CreateSession(ctx, "") - if err != nil { - t.Fatalf("setup: CreateSession() error: %v", err) - } + userID := uuid.New().String() w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil) - r.AddCookie(&http.Cookie{ - Name: "sid", - Value: sess.ID.String(), - }) + + // Inject user identity into context (csrfToken handler reads from context now) + ctx := context.WithValue(r.Context(), ctxKeyUserID, userID) + r = r.WithContext(ctx) sm.csrfToken(w, r) if w.Code != http.StatusOK { - t.Fatalf("csrfToken(with session) status = %d, want %d", w.Code, http.StatusOK) + t.Fatalf("csrfToken(with user) status = %d, want %d", w.Code, http.StatusOK) } var body map[string]string @@ -296,17 +297,80 @@ func TestCSRFTokenEndpoint_WithSession(t *testing.T) { token := body["csrfToken"] if token == "" { - t.Fatal("csrfToken(with session) expected csrfToken in response") + t.Fatal("csrfToken(with user) expected csrfToken in response") } - // Session-bound tokens should NOT have pre: prefix + // User-bound tokens should NOT have pre: prefix if isPreSessionToken(token) { - t.Error("csrfToken(with session) should return session-bound token, not pre-session") + t.Error("csrfToken(with user) should return user-bound token, not pre-session") + } + + // Token should validate against the user + if err := sm.CheckCSRF(userID, token); err != nil { + t.Fatalf("csrfToken(with user) returned invalid token: %v", err) + } +} + +// TestMaybeGenerateTitle_SessionHasTitle verifies that maybeGenerateTitle +// returns empty when the session already has a title (no overwrite). +func TestMaybeGenerateTitle_SessionHasTitle(t *testing.T) { + sm := setupIntegrationSessionManager(t) + ctx := context.Background() + + sess, err := sm.store.CreateSession(ctx, testOwnerID, "Existing Title") + if err != nil { + t.Fatalf("setup: CreateSession() error: %v", err) + } + + ch := &chatHandler{ + logger: slog.New(slog.DiscardHandler), + sessions: sm, } - // Token should validate against the session - if err := sm.CheckCSRF(sess.ID, token); err != nil { - t.Fatalf("csrfToken(with session) returned invalid token: %v", err) + title := ch.maybeGenerateTitle(ctx, sess.ID.String(), "new message") + if title != "" { + t.Errorf("maybeGenerateTitle(%q) = %q, want empty string", sess.ID.String(), title) + } +} + +// TestMaybeGenerateTitle_FallbackTruncation verifies that when agent is nil, +// maybeGenerateTitle falls back to truncateForTitle for title generation. +func TestMaybeGenerateTitle_FallbackTruncation(t *testing.T) { + sm := setupIntegrationSessionManager(t) + ctx := context.Background() + + // Create session with empty title + sess, err := sm.store.CreateSession(ctx, testOwnerID, "") + if err != nil { + t.Fatalf("setup: CreateSession() error: %v", err) + } + + ch := &chatHandler{ + logger: slog.New(slog.DiscardHandler), + agent: nil, // no AI title generation + sessions: sm, + } + + userMsg := "How do I use Go generics effectively?" + title := ch.maybeGenerateTitle(ctx, sess.ID.String(), userMsg) + + if title == "" { + t.Fatal("maybeGenerateTitle(fallback) = empty, want truncated title") + } + + // Verify fallback matches truncateForTitle behavior + want := truncateForTitle(userMsg) + if title != want { + t.Errorf("maybeGenerateTitle(%q) = %q, want %q", sess.ID.String(), title, want) + } + + // Verify title was persisted + updated, err := sm.store.Session(ctx, sess.ID) + if err != nil { + t.Fatalf("verifying title: %v", err) + } + if updated.Title != title { + t.Errorf("persisted title = %q, want %q", updated.Title, title) } } @@ -315,7 +379,7 @@ func TestGetSessionMessages_WithMessages(t *testing.T) { ctx := context.Background() // Create a session with messages - sess, err := sm.store.CreateSession(ctx, "Test Chat") + sess, err := sm.store.CreateSession(ctx, testOwnerID, "Test Chat") if err != nil { t.Fatalf("setup: CreateSession() error: %v", err) } @@ -333,7 +397,7 @@ func TestGetSessionMessages_WithMessages(t *testing.T) { r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+sess.ID.String()+"/messages", nil) r.SetPathValue("id", sess.ID.String()) - rctx := context.WithValue(r.Context(), ctxKeySessionID, sess.ID) + rctx := context.WithValue(r.Context(), ctxKeyUserID, testOwnerID) r = r.WithContext(rctx) sm.getSessionMessages(w, r) diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 821eaef..bb0f5de 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -10,18 +10,27 @@ import ( "github.com/google/uuid" ) -// sessionIDKey is an unexported context key type to prevent collisions. +// Context key types (unexported to prevent collisions). type sessionIDKey struct{} +type userIDCtxKey struct{} var ctxKeySessionID = sessionIDKey{} +var ctxKeyUserID = userIDCtxKey{} -// sessionIDFromContext retrieves the session ID from the request context. +// sessionIDFromContext retrieves the active session ID from the request context. // Returns uuid.Nil and false if not found. func sessionIDFromContext(ctx context.Context) (uuid.UUID, bool) { sessionID, ok := ctx.Value(ctxKeySessionID).(uuid.UUID) return sessionID, ok } +// userIDFromContext retrieves the user identity from the request context. +// Returns empty string and false if not found. +func userIDFromContext(ctx context.Context) (string, bool) { + uid, ok := ctx.Value(ctxKeyUserID).(string) + return uid, ok +} + // loggingWriter wraps http.ResponseWriter to capture metrics. // Implements Flusher for SSE streaming and Unwrap for ResponseController. type loggingWriter struct { @@ -152,14 +161,30 @@ func corsMiddleware(allowedOrigins []string) func(http.Handler) http.Handler { } } -// sessionMiddleware extracts the session ID from the cookie and adds it to the -// request context. If no valid session cookie is present, the request continues -// without a session ID in context. Individual handlers are responsible for -// creating sessions when needed (e.g., createSession). +// userMiddleware auto-provisions and extracts user identity (uid cookie). +// On first visit, generates a new UUID and sets the uid cookie. +// Subsequent requests use the existing uid cookie value. +func userMiddleware(sm *sessionManager) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userID := sm.UserID(r) + if userID == "" { + userID = uuid.New().String() + sm.setUserCookie(w, userID) + } + ctx := context.WithValue(r.Context(), ctxKeyUserID, userID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// sessionMiddleware extracts the active session ID from the sid cookie and adds +// it to the request context. If no valid session cookie is present, the request +// continues without a session ID in context. func sessionMiddleware(sm *sessionManager) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - sessionID, err := sm.ID(r) + sessionID, err := sm.SessionID(r) if err == nil { ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) next.ServeHTTP(w, r.WithContext(ctx)) @@ -173,7 +198,7 @@ func sessionMiddleware(sm *sessionManager) func(http.Handler) http.Handler { // csrfMiddleware validates CSRF tokens for state-changing requests. // Reads token from X-CSRF-Token header (JSON API pattern). -// Supports both pre-session and session-bound tokens. +// Supports both pre-session and user-bound tokens. func csrfMiddleware(sm *sessionManager, logger *slog.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -185,7 +210,7 @@ func csrfMiddleware(sm *sessionManager, logger *slog.Logger) func(http.Handler) csrfToken := r.Header.Get("X-CSRF-Token") - // Check pre-session token + // Check pre-session token (before uid cookie is established) if isPreSessionToken(csrfToken) { if err := sm.CheckPreSessionCSRF(csrfToken); err != nil { logger.Warn("pre-session CSRF validation failed", @@ -200,21 +225,21 @@ func csrfMiddleware(sm *sessionManager, logger *slog.Logger) func(http.Handler) return } - // Session-bound token - sessionID, ok := sessionIDFromContext(r.Context()) - if !ok { - logger.Error("validating CSRF: session ID not in context", + // User-bound token + userID, ok := userIDFromContext(r.Context()) + if !ok || userID == "" { + logger.Error("validating CSRF: user ID not in context", "path", r.URL.Path, "method", r.Method, ) - WriteError(w, http.StatusForbidden, "session_required", "session required", logger) + WriteError(w, http.StatusForbidden, "user_required", "user identity required", logger) return } - if err := sm.CheckCSRF(sessionID, csrfToken); err != nil { + if err := sm.CheckCSRF(userID, csrfToken); err != nil { logger.Warn("validating CSRF", "error", err, - "session", sessionID, + "user", userID, "path", r.URL.Path, "method", r.Method, ) diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index c27e60f..643c039 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -209,8 +209,8 @@ func TestCSRFMiddleware_AcceptsValidSessionToken(t *testing.T) { logger: logger, } - sessionID := uuid.New() - token := sm.NewCSRFToken(sessionID) + userID := uuid.New().String() + token := sm.NewCSRFToken(userID) called := false handler := csrfMiddleware(sm, logger)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -222,8 +222,8 @@ func TestCSRFMiddleware_AcceptsValidSessionToken(t *testing.T) { r := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil) r.Header.Set("X-CSRF-Token", token) - // Inject session ID into context (normally done by sessionMiddleware) - ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) + // Inject user ID into context (normally done by userMiddleware) + ctx := context.WithValue(r.Context(), ctxKeyUserID, userID) r = r.WithContext(ctx) handler.ServeHTTP(w, r) @@ -240,7 +240,7 @@ func TestCSRFMiddleware_RejectsInvalidToken(t *testing.T) { logger: logger, } - sessionID := uuid.New() + userID := uuid.New().String() handler := csrfMiddleware(sm, logger)(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { t.Error("handler should not be called with invalid token") @@ -250,7 +250,7 @@ func TestCSRFMiddleware_RejectsInvalidToken(t *testing.T) { r := httptest.NewRequest(http.MethodPost, "/api/v1/sessions", nil) r.Header.Set("X-CSRF-Token", "obviously-invalid-token") - ctx := context.WithValue(r.Context(), ctxKeySessionID, sessionID) + ctx := context.WithValue(r.Context(), ctxKeyUserID, userID) r = r.WithContext(ctx) handler.ServeHTTP(w, r) diff --git a/internal/api/server.go b/internal/api/server.go index df0f20c..b246e58 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -19,6 +19,7 @@ type ServerConfig struct { CORSOrigins []string // Allowed origins for CORS IsDev bool // Enables HTTP cookies (no Secure flag) TrustProxy bool // Trust X-Real-IP/X-Forwarded-For headers (behind reverse proxy) + RateBurst int // Rate limiter burst size per IP (0 = default 60) } // Server is the JSON API HTTP server. @@ -70,13 +71,18 @@ func NewServer(cfg ServerConfig) (*Server, error) { mux.HandleFunc("POST /api/v1/chat", ch.send) mux.HandleFunc("GET /api/v1/chat/stream", ch.stream) - // Rate limiter: 60 requests/minute per IP (1 token/sec, burst 60) - rl := newRateLimiter(1.0, 60) + // Rate limiter: per-IP token bucket (1 token/sec refill) + burst := cfg.RateBurst + if burst <= 0 { + burst = 60 + } + rl := newRateLimiter(1.0, burst) - // Build middleware stack: Recovery → Logging → RateLimit → CORS → Session → CSRF → Routes + // Build middleware stack: Recovery → Logging → RateLimit → CORS → User → Session → CSRF → Routes var handler http.Handler = mux handler = csrfMiddleware(sm, logger)(handler) handler = sessionMiddleware(sm)(handler) + handler = userMiddleware(sm)(handler) handler = corsMiddleware(cfg.CORSOrigins)(handler) handler = rateLimitMiddleware(rl, cfg.TrustProxy, logger)(handler) handler = loggingMiddleware(logger)(handler) diff --git a/internal/api/session.go b/internal/api/session.go index d062316..62c7ad8 100644 --- a/internal/api/session.go +++ b/internal/api/session.go @@ -33,19 +33,21 @@ var ( ErrCSRFMalformed = errors.New("csrf token malformed") ) -// Pre-session CSRF token prefix to distinguish from session-bound tokens. +// Pre-session CSRF token prefix to distinguish from user-bound tokens. const preSessionPrefix = "pre:" // Cookie and CSRF configuration. const ( sessionCookieName = "sid" + userCookieName = "uid" csrfTokenTTL = 24 * time.Hour - sessionMaxAge = 30 * 24 * 3600 // 30 days in seconds + cookieMaxAge = 30 * 24 * 3600 // 30 days in seconds csrfClockSkew = 5 * time.Minute messagesDefaultLimit = 100 + sessionsDefaultLimit = 50 ) -// sessionManager handles session cookies and CSRF token operations. +// sessionManager handles session cookies, user identity, and CSRF token operations. type sessionManager struct { store *session.Store hmacSecret []byte @@ -53,8 +55,8 @@ type sessionManager struct { logger *slog.Logger } -// ID extracts session ID from cookie without creating a new session. -func (*sessionManager) ID(r *http.Request) (uuid.UUID, error) { +// SessionID extracts the active session ID from the sid cookie. +func (*sessionManager) SessionID(r *http.Request) (uuid.UUID, error) { cookie, err := r.Cookie(sessionCookieName) if err != nil { return uuid.Nil, ErrSessionCookieNotFound @@ -68,11 +70,21 @@ func (*sessionManager) ID(r *http.Request) (uuid.UUID, error) { return sessionID, nil } -// NewCSRFToken creates an HMAC-based token bound to the session ID. +// UserID extracts the user identity from the uid cookie. +// Returns empty string if no uid cookie is present. +func (*sessionManager) UserID(r *http.Request) string { + cookie, err := r.Cookie(userCookieName) + if err != nil { + return "" + } + return cookie.Value +} + +// NewCSRFToken creates an HMAC-based token bound to the user ID. // Format: "timestamp:signature" -func (sm *sessionManager) NewCSRFToken(sessionID uuid.UUID) string { +func (sm *sessionManager) NewCSRFToken(userID string) string { timestamp := time.Now().Unix() - message := fmt.Sprintf("%s:%d", sessionID.String(), timestamp) + message := fmt.Sprintf("%s:%d", userID, timestamp) h := hmac.New(sha256.New, sm.hmacSecret) h.Write([]byte(message)) @@ -81,8 +93,8 @@ func (sm *sessionManager) NewCSRFToken(sessionID uuid.UUID) string { return fmt.Sprintf("%d:%s", timestamp, signature) } -// CheckCSRF verifies a session-bound CSRF token. -func (sm *sessionManager) CheckCSRF(sessionID uuid.UUID, token string) error { +// CheckCSRF verifies a user-bound CSRF token. +func (sm *sessionManager) CheckCSRF(userID, token string) error { if token == "" { return ErrCSRFRequired } @@ -105,7 +117,7 @@ func (sm *sessionManager) CheckCSRF(sessionID uuid.UUID, token string) error { return ErrCSRFInvalid } - message := fmt.Sprintf("%s:%d", sessionID.String(), timestamp) + message := fmt.Sprintf("%s:%d", userID, timestamp) h := hmac.New(sha256.New, sm.hmacSecret) h.Write([]byte(message)) expectedSig := base64.URLEncoding.EncodeToString(h.Sum(nil)) @@ -173,9 +185,9 @@ func (sm *sessionManager) CheckPreSessionCSRF(token string) error { return nil } -// requireOwnership verifies the requested session ID matches the caller's session cookie. +// requireOwnership verifies the requested session belongs to the caller. +// Uses owner_id from the database to support multi-session ownership. // Returns the verified session ID and true, or writes an error response and returns false. -// This prevents session enumeration and cross-session access. func (sm *sessionManager) requireOwnership(w http.ResponseWriter, r *http.Request) (uuid.UUID, bool) { idStr := r.PathValue("id") if idStr == "" { @@ -189,12 +201,30 @@ func (sm *sessionManager) requireOwnership(w http.ResponseWriter, r *http.Reques return uuid.Nil, false } - ownerID, ok := sessionIDFromContext(r.Context()) - if !ok || ownerID != targetID { + userID, ok := userIDFromContext(r.Context()) + if !ok || userID == "" { + WriteError(w, http.StatusForbidden, "forbidden", "user identity required", sm.logger) + return uuid.Nil, false + } + + // Verify session exists and is owned by this user + sess, err := sm.store.Session(r.Context(), targetID) + if err != nil { + if errors.Is(err, session.ErrNotFound) { + WriteError(w, http.StatusNotFound, "not_found", "session not found", sm.logger) + return uuid.Nil, false + } + sm.logger.Error("checking session ownership", "error", err, "session_id", targetID) + WriteError(w, http.StatusInternalServerError, "get_failed", "failed to verify session", sm.logger) + return uuid.Nil, false + } + + if sess.OwnerID != userID { sm.logger.Warn("session ownership check failed", "target", targetID, + "owner", sess.OwnerID, + "caller", userID, "path", r.URL.Path, - "remote_addr", r.RemoteAddr, ) WriteError(w, http.StatusForbidden, "forbidden", "session access denied", sm.logger) return uuid.Nil, false @@ -203,7 +233,7 @@ func (sm *sessionManager) requireOwnership(w http.ResponseWriter, r *http.Reques return targetID, true } -func (sm *sessionManager) setCookie(w http.ResponseWriter, sessionID uuid.UUID) { +func (sm *sessionManager) setSessionCookie(w http.ResponseWriter, sessionID uuid.UUID) { http.SetCookie(w, &http.Cookie{ Name: sessionCookieName, Value: sessionID.String(), @@ -211,17 +241,29 @@ func (sm *sessionManager) setCookie(w http.ResponseWriter, sessionID uuid.UUID) Secure: !sm.isDev, HttpOnly: true, SameSite: http.SameSiteLaxMode, - MaxAge: sessionMaxAge, + MaxAge: cookieMaxAge, + }) +} + +func (sm *sessionManager) setUserCookie(w http.ResponseWriter, userID string) { + http.SetCookie(w, &http.Cookie{ + Name: userCookieName, + Value: userID, + Path: "/", + Secure: !sm.isDev, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + MaxAge: cookieMaxAge, }) } // csrfToken handles GET /api/v1/csrf-token — provisions a CSRF token. -// Returns a session-bound token if a session exists, otherwise a pre-session token. +// Returns a user-bound token if uid cookie exists, otherwise a pre-session token. func (sm *sessionManager) csrfToken(w http.ResponseWriter, r *http.Request) { - sessionID, err := sm.ID(r) - if err == nil { + userID, ok := userIDFromContext(r.Context()) + if ok && userID != "" { WriteJSON(w, http.StatusOK, map[string]string{ - "csrfToken": sm.NewCSRFToken(sessionID), + "csrfToken": sm.NewCSRFToken(userID), }, sm.logger) return } @@ -231,8 +273,7 @@ func (sm *sessionManager) csrfToken(w http.ResponseWriter, r *http.Request) { }, sm.logger) } -// listSessions handles GET /api/v1/sessions — returns sessions owned by the caller. -// Only returns the session matching the caller's cookie (ownership enforcement). +// listSessions handles GET /api/v1/sessions — returns all sessions owned by the caller. func (sm *sessionManager) listSessions(w http.ResponseWriter, r *http.Request) { type sessionItem struct { ID string `json:"id"` @@ -240,52 +281,56 @@ func (sm *sessionManager) listSessions(w http.ResponseWriter, r *http.Request) { UpdatedAt string `json:"updatedAt"` } - sessionID, ok := sessionIDFromContext(r.Context()) - if !ok { - // No session cookie — return empty list + userID, ok := userIDFromContext(r.Context()) + if !ok || userID == "" { WriteJSON(w, http.StatusOK, []sessionItem{}, sm.logger) return } - sess, err := sm.store.Session(r.Context(), sessionID) + sessions, err := sm.store.Sessions(r.Context(), userID, sessionsDefaultLimit, 0) if err != nil { - if errors.Is(err, session.ErrNotFound) { - WriteJSON(w, http.StatusOK, []sessionItem{}, sm.logger) - return - } - sm.logger.Error("getting session", "error", err, "session_id", sessionID) + sm.logger.Error("listing sessions", "error", err, "user_id", userID) WriteError(w, http.StatusInternalServerError, "list_failed", "failed to list sessions", sm.logger) return } - WriteJSON(w, http.StatusOK, []sessionItem{ - { + items := make([]sessionItem, len(sessions)) + for i, sess := range sessions { + items[i] = sessionItem{ ID: sess.ID.String(), Title: sess.Title, UpdatedAt: sess.UpdatedAt.Format(time.RFC3339), - }, - }, sm.logger) + } + } + + WriteJSON(w, http.StatusOK, items, sm.logger) } // createSession handles POST /api/v1/sessions — creates a new session. func (sm *sessionManager) createSession(w http.ResponseWriter, r *http.Request) { - sess, err := sm.store.CreateSession(r.Context(), "") + userID, ok := userIDFromContext(r.Context()) + if !ok || userID == "" { + WriteError(w, http.StatusBadRequest, "user_required", "user identity required", sm.logger) + return + } + + sess, err := sm.store.CreateSession(r.Context(), userID, "") if err != nil { sm.logger.Error("creating session", "error", err) WriteError(w, http.StatusInternalServerError, "create_failed", "failed to create session", sm.logger) return } - sm.setCookie(w, sess.ID) + sm.setSessionCookie(w, sess.ID) WriteJSON(w, http.StatusCreated, map[string]string{ "id": sess.ID.String(), - "csrfToken": sm.NewCSRFToken(sess.ID), + "csrfToken": sm.NewCSRFToken(userID), }, sm.logger) } // getSession handles GET /api/v1/sessions/{id} — returns a single session. -// Requires ownership: the session ID must match the caller's session cookie. +// Requires ownership: the session must belong to the caller. func (sm *sessionManager) getSession(w http.ResponseWriter, r *http.Request) { id, ok := sm.requireOwnership(w, r) if !ok { @@ -312,7 +357,7 @@ func (sm *sessionManager) getSession(w http.ResponseWriter, r *http.Request) { } // getSessionMessages handles GET /api/v1/sessions/{id}/messages — returns messages for a session. -// Requires ownership: the session ID must match the caller's session cookie. +// Requires ownership: the session must belong to the caller. func (sm *sessionManager) getSessionMessages(w http.ResponseWriter, r *http.Request) { id, ok := sm.requireOwnership(w, r) if !ok { @@ -355,7 +400,7 @@ func (sm *sessionManager) getSessionMessages(w http.ResponseWriter, r *http.Requ } // deleteSession handles DELETE /api/v1/sessions/{id} — deletes a session. -// Requires ownership: the session ID must match the caller's session cookie. +// Requires ownership: the session must belong to the caller. func (sm *sessionManager) deleteSession(w http.ResponseWriter, r *http.Request) { id, ok := sm.requireOwnership(w, r) if !ok { diff --git a/internal/api/session_test.go b/internal/api/session_test.go index efdf2b5..e67bde6 100644 --- a/internal/api/session_test.go +++ b/internal/api/session_test.go @@ -24,8 +24,8 @@ func newTestSessionManager() *sessionManager { } // csrfTokenWithTimestamp creates a CSRF token with a specific timestamp for testing expiration. -func csrfTokenWithTimestamp(secret []byte, sessionID uuid.UUID, ts int64) string { - msg := fmt.Sprintf("%s:%d", sessionID.String(), ts) +func csrfTokenWithTimestamp(secret []byte, userID string, ts int64) string { + msg := fmt.Sprintf("%s:%d", userID, ts) h := hmac.New(sha256.New, secret) h.Write([]byte(msg)) sig := base64.URLEncoding.EncodeToString(h.Sum(nil)) @@ -34,27 +34,27 @@ func csrfTokenWithTimestamp(secret []byte, sessionID uuid.UUID, ts int64) string func TestNewCSRFToken_RoundTrip(t *testing.T) { sm := newTestSessionManager() - sessionID := uuid.New() + userID := uuid.New().String() - token := sm.NewCSRFToken(sessionID) + token := sm.NewCSRFToken(userID) if token == "" { t.Fatal("NewCSRFToken() returned empty token") } - if err := sm.CheckCSRF(sessionID, token); err != nil { + if err := sm.CheckCSRF(userID, token); err != nil { t.Fatalf("CheckCSRF(valid token) error: %v", err) } } -func TestCSRFToken_WrongSession(t *testing.T) { +func TestCSRFToken_WrongUser(t *testing.T) { sm := newTestSessionManager() - sessionID := uuid.New() - otherID := uuid.New() + userID := uuid.New().String() + otherID := uuid.New().String() - token := sm.NewCSRFToken(sessionID) + token := sm.NewCSRFToken(userID) if err := sm.CheckCSRF(otherID, token); err == nil { - t.Error("CheckCSRF(wrong session) expected error, got nil") + t.Error("CheckCSRF(wrong user) expected error, got nil") } } @@ -65,17 +65,17 @@ func TestCSRFToken_WrongSecret(t *testing.T) { logger: slog.New(slog.DiscardHandler), } - sessionID := uuid.New() - token := sm1.NewCSRFToken(sessionID) + userID := uuid.New().String() + token := sm1.NewCSRFToken(userID) - if err := sm2.CheckCSRF(sessionID, token); err == nil { + if err := sm2.CheckCSRF(userID, token); err == nil { t.Error("CheckCSRF(wrong secret) expected error, got nil") } } func TestCSRFToken_Malformed(t *testing.T) { sm := newTestSessionManager() - sessionID := uuid.New() + userID := uuid.New().String() tests := []struct { name string @@ -88,7 +88,7 @@ func TestCSRFToken_Malformed(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := sm.CheckCSRF(sessionID, tt.token) + err := sm.CheckCSRF(userID, tt.token) if err == nil { t.Errorf("CheckCSRF(%q) expected error, got nil", tt.token) } @@ -98,13 +98,13 @@ func TestCSRFToken_Malformed(t *testing.T) { func TestCSRFToken_Expired(t *testing.T) { sm := newTestSessionManager() - sessionID := uuid.New() + userID := uuid.New().String() // Construct a token with a timestamp 25 hours ago (exceeds 24h TTL) oldTimestamp := time.Now().Add(-25 * time.Hour).Unix() - token := csrfTokenWithTimestamp(sm.hmacSecret, sessionID, oldTimestamp) + token := csrfTokenWithTimestamp(sm.hmacSecret, userID, oldTimestamp) - err := sm.CheckCSRF(sessionID, token) + err := sm.CheckCSRF(userID, token) if err == nil { t.Error("CheckCSRF(expired token) expected error, got nil") } @@ -146,7 +146,7 @@ func TestCSRFTokenEndpoint_PreSession(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil) - // No session cookie — should get pre-session token + // No uid cookie — should get pre-session token sm.csrfToken(w, r) @@ -163,7 +163,7 @@ func TestCSRFTokenEndpoint_PreSession(t *testing.T) { } if !isPreSessionToken(token) { - t.Error("csrfToken(no cookie) token should be pre-session") + t.Error("csrfToken(no uid) token should be pre-session") } if err := sm.CheckPreSessionCSRF(token); err != nil { @@ -171,6 +171,34 @@ func TestCSRFTokenEndpoint_PreSession(t *testing.T) { } } +func TestCSRFTokenEndpoint_WithUser(t *testing.T) { + sm := newTestSessionManager() + userID := uuid.New().String() + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/v1/csrf-token", nil) + ctx := context.WithValue(r.Context(), ctxKeyUserID, userID) + r = r.WithContext(ctx) + + sm.csrfToken(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("csrfToken() status = %d, want %d", w.Code, http.StatusOK) + } + + var body map[string]string + decodeData(t, w, &body) + + token := body["csrfToken"] + if isPreSessionToken(token) { + t.Error("csrfToken(with uid) should not be pre-session") + } + + if err := sm.CheckCSRF(userID, token); err != nil { + t.Fatalf("csrfToken() returned invalid user-bound token: %v", err) + } +} + func TestDeleteSession_InvalidUUID(t *testing.T) { sm := newTestSessionManager() @@ -245,93 +273,70 @@ func TestGetSessionMessages_InvalidUUID(t *testing.T) { } } -func TestRequireOwnership_NoSession(t *testing.T) { +func TestRequireOwnership_NoUser(t *testing.T) { sm := newTestSessionManager() targetID := uuid.New() w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+targetID.String(), nil) r.SetPathValue("id", targetID.String()) - // No session in context — should return 403 + // No user in context — should return 403 sm.getSession(w, r) if w.Code != http.StatusForbidden { - t.Fatalf("getSession(no session cookie) status = %d, want %d", w.Code, http.StatusForbidden) + t.Fatalf("getSession(no user) status = %d, want %d", w.Code, http.StatusForbidden) } body := decodeErrorEnvelope(t, w) if body.Code != "forbidden" { - t.Errorf("getSession(no session cookie) code = %q, want %q", body.Code, "forbidden") - } -} - -func TestRequireOwnership_Mismatch(t *testing.T) { - sm := newTestSessionManager() - ownerID := uuid.New() - targetID := uuid.New() - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+targetID.String(), nil) - r.SetPathValue("id", targetID.String()) - // Set a different session ID in context (simulates different cookie) - ctx := context.WithValue(r.Context(), ctxKeySessionID, ownerID) - r = r.WithContext(ctx) - - sm.getSession(w, r) - - if w.Code != http.StatusForbidden { - t.Fatalf("getSession(mismatched session) status = %d, want %d", w.Code, http.StatusForbidden) + t.Errorf("getSession(no user) code = %q, want %q", body.Code, "forbidden") } } -func TestDeleteSession_OwnershipDenied(t *testing.T) { +func TestDeleteSession_NoUser(t *testing.T) { sm := newTestSessionManager() - ownerID := uuid.New() targetID := uuid.New() w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodDelete, "/api/v1/sessions/"+targetID.String(), nil) r.SetPathValue("id", targetID.String()) - ctx := context.WithValue(r.Context(), ctxKeySessionID, ownerID) - r = r.WithContext(ctx) + // No user in context sm.deleteSession(w, r) if w.Code != http.StatusForbidden { - t.Fatalf("deleteSession(not owner) status = %d, want %d", w.Code, http.StatusForbidden) + t.Fatalf("deleteSession(no user) status = %d, want %d", w.Code, http.StatusForbidden) } } -func TestGetSessionMessages_OwnershipDenied(t *testing.T) { +func TestGetSessionMessages_NoUser(t *testing.T) { sm := newTestSessionManager() - ownerID := uuid.New() targetID := uuid.New() w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions/"+targetID.String()+"/messages", nil) r.SetPathValue("id", targetID.String()) - ctx := context.WithValue(r.Context(), ctxKeySessionID, ownerID) - r = r.WithContext(ctx) + // No user in context sm.getSessionMessages(w, r) if w.Code != http.StatusForbidden { - t.Fatalf("getSessionMessages(not owner) status = %d, want %d", w.Code, http.StatusForbidden) + t.Fatalf("getSessionMessages(no user) status = %d, want %d", w.Code, http.StatusForbidden) } } -func TestListSessions_NoSession(t *testing.T) { +func TestListSessions_NoUser(t *testing.T) { sm := newTestSessionManager() w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/api/v1/sessions", nil) - // No session in context + // No user in context sm.listSessions(w, r) if w.Code != http.StatusOK { - t.Fatalf("listSessions(no session) status = %d, want %d", w.Code, http.StatusOK) + t.Fatalf("listSessions(no user) status = %d, want %d", w.Code, http.StatusOK) } // Should return empty list, not an error @@ -341,29 +346,25 @@ func TestListSessions_NoSession(t *testing.T) { var items []sessionItem decodeData(t, w, &items) if len(items) != 0 { - t.Errorf("listSessions(no session) returned %d items, want 0", len(items)) + t.Errorf("listSessions(no user) returned %d items, want 0", len(items)) } } func FuzzCheckCSRF(f *testing.F) { sm := newTestSessionManager() - sessionID := uuid.New() - validToken := sm.NewCSRFToken(sessionID) + userID := uuid.New().String() + validToken := sm.NewCSRFToken(userID) - f.Add(sessionID.String(), validToken) - f.Add(sessionID.String(), "") - f.Add(sessionID.String(), "notanumber:signature") - f.Add(sessionID.String(), "12345:badsig") + f.Add(userID, validToken) + f.Add(userID, "") + f.Add(userID, "notanumber:signature") + f.Add(userID, "12345:badsig") f.Add(uuid.New().String(), validToken) f.Add("", "") f.Add("not-a-uuid", "1234:sig") - f.Fuzz(func(t *testing.T, sessionIDStr, token string) { - id, err := uuid.Parse(sessionIDStr) - if err != nil { - return - } - _ = sm.CheckCSRF(id, token) // must not panic + f.Fuzz(func(t *testing.T, uid, token string) { + _ = sm.CheckCSRF(uid, token) // must not panic }) } @@ -386,17 +387,17 @@ func FuzzCheckPreSessionCSRF(f *testing.F) { func BenchmarkNewCSRFToken(b *testing.B) { sm := newTestSessionManager() - sessionID := uuid.New() + userID := uuid.New().String() for b.Loop() { - sm.NewCSRFToken(sessionID) + sm.NewCSRFToken(userID) } } func BenchmarkCheckCSRF(b *testing.B) { sm := newTestSessionManager() - sessionID := uuid.New() - token := sm.NewCSRFToken(sessionID) + userID := uuid.New().String() + token := sm.NewCSRFToken(userID) for b.Loop() { - _ = sm.CheckCSRF(sessionID, token) + _ = sm.CheckCSRF(userID, token) } } diff --git a/internal/chat/chat_test.go b/internal/chat/chat_test.go index f0b4c70..086e63f 100644 --- a/internal/chat/chat_test.go +++ b/internal/chat/chat_test.go @@ -1,12 +1,15 @@ package chat import ( + "context" "log/slog" "strings" "testing" + "time" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/genkit" + "golang.org/x/time/rate" "github.com/koopa0/koopa/internal/session" ) @@ -308,3 +311,309 @@ func TestShallowCopyMap_MutateValue(t *testing.T) { copied["key"], "value") } } + +// testToolRef implements ai.ToolRef for testing tool caching. +type testToolRef struct { + name string +} + +func (r *testToolRef) Name() string { return r.name } + +// newTestAgent builds an Agent struct using the same defaults logic as New(), +// but bypasses genkit.LookupPrompt which requires a real Genkit environment. +func newTestAgent(t *testing.T, maxTurns int, language string, tokenBudget TokenBudget, retryConfig RetryConfig, cbConfig CircuitBreakerConfig, rl *rate.Limiter, toolNames []string) *Agent { + t.Helper() + + if maxTurns <= 0 { + maxTurns = 5 + } + + languagePrompt := language + if languagePrompt == "" || languagePrompt == "auto" { + languagePrompt = "the same language as the user's input (auto-detect)" + } + + if retryConfig.MaxRetries == 0 { + retryConfig = DefaultRetryConfig() + } + + if cbConfig.FailureThreshold == 0 { + cbConfig = DefaultCircuitBreakerConfig() + } + + if tokenBudget.MaxHistoryTokens == 0 { + tokenBudget = DefaultTokenBudget() + } + + if rl == nil { + rl = rate.NewLimiter(10, 30) + } + + if toolNames == nil { + toolNames = []string{"t1"} + } + + toolRefs := make([]ai.ToolRef, len(toolNames)) + for i, n := range toolNames { + toolRefs[i] = &testToolRef{name: n} + } + + return &Agent{ + maxTurns: maxTurns, + languagePrompt: languagePrompt, + retryConfig: retryConfig, + circuitBreaker: NewCircuitBreaker(cbConfig), + rateLimiter: rl, + tokenBudget: tokenBudget, + logger: slog.New(slog.DiscardHandler), + toolRefs: toolRefs, + toolNames: strings.Join(toolNames, ", "), + } +} + +// TestNew_Defaults verifies that New() applies correct defaults for optional fields. +func TestNew_Defaults(t *testing.T) { + t.Parallel() + + customLimiter := rate.NewLimiter(5, 10) + + tests := []struct { + name string + maxTurns int + language string + tokenBudget TokenBudget + retryConfig RetryConfig + cbConfig CircuitBreakerConfig + rateLimiter *rate.Limiter + toolNames []string + check func(t *testing.T, a *Agent) + }{ + { + name: "maxTurns zero defaults to 5", + maxTurns: 0, + check: func(t *testing.T, a *Agent) { //nolint:thelper // table-driven check func, t.Helper() is noise + if a.maxTurns != 5 { + t.Errorf("New(MaxTurns=0).maxTurns = %d, want 5", a.maxTurns) + } + }, + }, + { + name: "maxTurns custom", + maxTurns: 20, + check: func(t *testing.T, a *Agent) { //nolint:thelper // table-driven check func, t.Helper() is noise + if a.maxTurns != 20 { + t.Errorf("New(MaxTurns=20).maxTurns = %d, want 20", a.maxTurns) + } + }, + }, + { + name: "language empty defaults to auto-detect", + language: "", + check: func(t *testing.T, a *Agent) { //nolint:thelper // table-driven check func, t.Helper() is noise + if !strings.Contains(a.languagePrompt, "auto-detect") { + t.Errorf("New(Language=\"\").languagePrompt = %q, want to contain %q", a.languagePrompt, "auto-detect") + } + }, + }, + { + name: "language auto defaults to auto-detect", + language: "auto", + check: func(t *testing.T, a *Agent) { //nolint:thelper // table-driven check func, t.Helper() is noise + if !strings.Contains(a.languagePrompt, "auto-detect") { + t.Errorf("New(Language=\"auto\").languagePrompt = %q, want to contain %q", a.languagePrompt, "auto-detect") + } + }, + }, + { + name: "language custom", + language: "Japanese", + check: func(t *testing.T, a *Agent) { //nolint:thelper // table-driven check func, t.Helper() is noise + if a.languagePrompt != "Japanese" { + t.Errorf("New(Language=\"Japanese\").languagePrompt = %q, want %q", a.languagePrompt, "Japanese") + } + }, + }, + { + name: "tokenBudget zero defaults to 32000", + check: func(t *testing.T, a *Agent) { //nolint:thelper // table-driven check func, t.Helper() is noise + if a.tokenBudget.MaxHistoryTokens != 32000 { + t.Errorf("New(TokenBudget{}).MaxHistoryTokens = %d, want 32000", a.tokenBudget.MaxHistoryTokens) + } + }, + }, + { + name: "tokenBudget custom", + tokenBudget: TokenBudget{MaxHistoryTokens: 16000}, + check: func(t *testing.T, a *Agent) { //nolint:thelper // table-driven check func, t.Helper() is noise + if a.tokenBudget.MaxHistoryTokens != 16000 { + t.Errorf("New(MaxHistoryTokens=16000).MaxHistoryTokens = %d, want 16000", a.tokenBudget.MaxHistoryTokens) + } + }, + }, + { + name: "rateLimiter nil creates default", + check: func(t *testing.T, a *Agent) { //nolint:thelper // table-driven check func, t.Helper() is noise + if a.rateLimiter == nil { + t.Error("New(RateLimiter=nil).rateLimiter = nil, want non-nil default") + } + }, + }, + { + name: "rateLimiter custom", + rateLimiter: customLimiter, + check: func(t *testing.T, a *Agent) { //nolint:thelper // table-driven check func, t.Helper() is noise + if a.rateLimiter != customLimiter { + t.Error("New(custom RateLimiter).rateLimiter != provided limiter") + } + }, + }, + { + name: "toolRefs cached", + toolNames: []string{"a", "b"}, + check: func(t *testing.T, a *Agent) { //nolint:thelper // table-driven check func, t.Helper() is noise + if len(a.toolRefs) != 2 { + t.Errorf("New(2 tools).toolRefs len = %d, want 2", len(a.toolRefs)) + } + }, + }, + { + name: "toolNames formatted", + toolNames: []string{"a", "b"}, + check: func(t *testing.T, a *Agent) { //nolint:thelper // table-driven check func, t.Helper() is noise + if a.toolNames != "a, b" { + t.Errorf("New(tools a,b).toolNames = %q, want %q", a.toolNames, "a, b") + } + }, + }, + { + name: "retryConfig zero defaults to MaxRetries=3", + check: func(t *testing.T, a *Agent) { //nolint:thelper // table-driven check func, t.Helper() is noise + if a.retryConfig.MaxRetries != 3 { + t.Errorf("New(RetryConfig{}).MaxRetries = %d, want 3", a.retryConfig.MaxRetries) + } + }, + }, + { + name: "circuitBreaker created from defaults", + check: func(t *testing.T, a *Agent) { //nolint:thelper // table-driven check func, t.Helper() is noise + if a.circuitBreaker == nil { + t.Error("New(CBConfig{}).circuitBreaker = nil, want non-nil") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + a := newTestAgent(t, tt.maxTurns, tt.language, tt.tokenBudget, tt.retryConfig, tt.cbConfig, tt.rateLimiter, tt.toolNames) + tt.check(t, a) + }) + } +} + +// TestNew_PromptNotFound verifies that New returns an error when the dotprompt is not found. +func TestNew_PromptNotFound(t *testing.T) { + t.Parallel() + + ctx := context.Background() + g := genkit.Init(ctx) + + // Use genkit.DefineTool to create a real ai.Tool for Config validation. + tool := genkit.DefineTool(g, "test_tool", "test", func(_ *ai.ToolContext, _ string) (string, error) { + return "", nil + }) + + _, err := New(Config{ + Genkit: g, + SessionStore: new(session.Store), + Logger: slog.New(slog.DiscardHandler), + Tools: []ai.Tool{tool}, + }) + if err == nil { + t.Fatal("New(no prompt) expected error, got nil") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("New(no prompt) error = %q, want to contain %q", err.Error(), "not found") + } +} + +// TestMessagePreparation verifies the message preparation sequence: +// deepCopy -> truncate -> append user message. +func TestMessagePreparation(t *testing.T) { + t.Parallel() + + // Create an agent with small token budget to trigger truncation. + a := &Agent{ + logger: slog.New(slog.DiscardHandler), + tokenBudget: TokenBudget{MaxHistoryTokens: 50}, + } + + // Build a long history that exceeds the token budget. + // Each message ~50 tokens (100 chars / 2), so 3 messages ≈ 150 tokens > budget 50. + history := []*ai.Message{ + ai.NewUserMessage(ai.NewTextPart(strings.Repeat("a", 100))), + ai.NewModelMessage(ai.NewTextPart(strings.Repeat("b", 100))), + ai.NewUserMessage(ai.NewTextPart(strings.Repeat("c", 100))), + } + + // Simulate the preparation sequence from generateResponse + messages := deepCopyMessages(history) + messages = a.truncateHistory(messages, a.tokenBudget.MaxHistoryTokens) + messages = append(messages, ai.NewUserMessage(ai.NewTextPart("new question"))) + + t.Run("truncation reduces message count", func(t *testing.T) { + // History was truncated, so total messages < original 3 + 1 new + if len(messages) > 3 { + t.Errorf("message preparation: len = %d, want <= 3 (truncated + new)", len(messages)) + } + }) + + t.Run("user message is last", func(t *testing.T) { + last := messages[len(messages)-1] + if last.Role != ai.RoleUser { + t.Errorf("message preparation: last.Role = %q, want %q", last.Role, ai.RoleUser) + } + if last.Content[0].Text != "new question" { + t.Errorf("message preparation: last.Text = %q, want %q", last.Content[0].Text, "new question") + } + }) + + t.Run("original history unmodified", func(t *testing.T) { + if len(history) != 3 { + t.Errorf("message preparation: original history len = %d, want 3", len(history)) + } + }) +} + +// TestGenerateResponse_CircuitBreakerOpen verifies that generateResponse rejects +// requests when the circuit breaker is open. +func TestGenerateResponse_CircuitBreakerOpen(t *testing.T) { + t.Parallel() + + cb := NewCircuitBreaker(CircuitBreakerConfig{ + FailureThreshold: 1, + SuccessThreshold: 1, + Timeout: 1 * time.Hour, // long timeout to keep circuit open + }) + // Force circuit open by recording a failure + cb.Failure() + if cb.State() != CircuitOpen { + t.Fatalf("circuit breaker state = %v, want %v", cb.State(), CircuitOpen) + } + + a := &Agent{ + logger: slog.New(slog.DiscardHandler), + circuitBreaker: cb, + tokenBudget: DefaultTokenBudget(), + rateLimiter: rate.NewLimiter(10, 30), + } + + _, err := a.generateResponse(context.Background(), "hello", nil, nil) + if err == nil { + t.Fatal("generateResponse(CB open) expected error, got nil") + } + if !strings.Contains(err.Error(), "service unavailable") { + t.Errorf("generateResponse(CB open) error = %q, want to contain %q", err.Error(), "service unavailable") + } +} diff --git a/internal/chat/setup_test.go b/internal/chat/setup_test.go index 8cd1bed..da5997e 100644 --- a/internal/chat/setup_test.go +++ b/internal/chat/setup_test.go @@ -106,7 +106,7 @@ func SetupTest(t *testing.T) *TestFramework { } // Create test session - testSession, err := sessionStore.CreateSession(ctx, "Chat Integration Test") + testSession, err := sessionStore.CreateSession(ctx, "test-user", "Chat Integration Test") if err != nil { t.Fatalf("creating test session: %v", err) } @@ -164,7 +164,7 @@ func SetupTest(t *testing.T) *TestFramework { func (f *TestFramework) CreateTestSession(t *testing.T, name string) uuid.UUID { t.Helper() ctx := context.Background() - sess, err := f.SessionStore.CreateSession(ctx, name) + sess, err := f.SessionStore.CreateSession(ctx, "test-user", name) if err != nil { t.Fatalf("creating test session: %v", err) } diff --git a/internal/session/benchmark_test.go b/internal/session/benchmark_test.go index bf7f20b..a5f9bbb 100644 --- a/internal/session/benchmark_test.go +++ b/internal/session/benchmark_test.go @@ -84,7 +84,7 @@ func BenchmarkStore_AddMessages(b *testing.B) { store := New(sqlc.New(pool), pool, logger) // Create a test session - session, err := store.CreateSession(ctx, "Benchmark-AddMessages") + session, err := store.CreateSession(ctx, "bench-owner", "Benchmark-AddMessages") if err != nil { b.Fatalf("creating session: %v", err) } @@ -125,7 +125,7 @@ func BenchmarkStore_AppendMessages(b *testing.B) { store := New(sqlc.New(pool), pool, logger) // Create a test session - session, err := store.CreateSession(ctx, "Benchmark-AppendMessages") + session, err := store.CreateSession(ctx, "bench-owner", "Benchmark-AppendMessages") if err != nil { b.Fatalf("creating session: %v", err) } @@ -169,7 +169,7 @@ func BenchmarkStore_CreateSession(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := range b.N { - session, err := store.CreateSession(ctx, fmt.Sprintf("Benchmark-Session-%d", i)) + session, err := store.CreateSession(ctx, "bench-owner", fmt.Sprintf("Benchmark-Session-%d", i)) if err != nil { b.Fatalf("CreateSession failed at iteration %d: %v", i, err) } @@ -187,7 +187,7 @@ func BenchmarkStore_GetSession(b *testing.B) { store := New(sqlc.New(pool), pool, logger) // Create a test session - session, err := store.CreateSession(ctx, "Benchmark-GetSession") + session, err := store.CreateSession(ctx, "bench-owner", "Benchmark-GetSession") if err != nil { b.Fatalf("creating session: %v", err) } @@ -214,7 +214,7 @@ func BenchmarkStore_Sessions(b *testing.B) { // Create some test sessions for i := 0; i < 20; i++ { - session, err := store.CreateSession(ctx, fmt.Sprintf("Benchmark-List-%d", i)) + session, err := store.CreateSession(ctx, "bench-owner", fmt.Sprintf("Benchmark-List-%d", i)) if err != nil { b.Fatalf("creating session: %v", err) } @@ -224,7 +224,7 @@ func BenchmarkStore_Sessions(b *testing.B) { b.ReportAllocs() b.ResetTimer() for b.Loop() { - _, err := store.Sessions(ctx, 100, 0) + _, err := store.Sessions(ctx, "bench-owner", 100, 0) if err != nil { b.Fatalf("Sessions() unexpected error: %v", err) } @@ -256,7 +256,7 @@ func setupBenchmarkSession(b *testing.B, ctx context.Context, numMessages int) ( store := New(sqlc.New(pool), pool, logger) // Create a test session - session, err := store.CreateSession(ctx, "Benchmark-Session") + session, err := store.CreateSession(ctx, "bench-owner", "Benchmark-Session") if err != nil { cleanup() b.Fatalf("creating session: %v", err) diff --git a/internal/session/integration_test.go b/internal/session/integration_test.go index 35593f5..4a67f67 100644 --- a/internal/session/integration_test.go +++ b/internal/session/integration_test.go @@ -35,7 +35,7 @@ func TestStore_CreateAndGet(t *testing.T) { ctx := context.Background() // Create a session - session, err := store.CreateSession(ctx, "Test Session") + session, err := store.CreateSession(ctx, "test-owner", "Test Session") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -48,6 +48,9 @@ func TestStore_CreateAndGet(t *testing.T) { if session.Title != "Test Session" { t.Errorf("CreateSession() Title = %q, want %q", session.Title, "Test Session") } + if session.OwnerID != "test-owner" { + t.Errorf("CreateSession() OwnerID = %q, want %q", session.OwnerID, "test-owner") + } if session.CreatedAt.IsZero() { t.Error("CreateSession() CreatedAt should be set") } @@ -77,7 +80,7 @@ func TestStore_CreateWithEmptyFields(t *testing.T) { ctx := context.Background() // Create session with empty title - session, err := store.CreateSession(ctx, "") + session, err := store.CreateSession(ctx, "test-owner", "") if err != nil { t.Fatalf("CreateSession() with empty fields unexpected error: %v", err) } @@ -105,14 +108,14 @@ func TestStore_ListSessions_Integration(t *testing.T) { // Create multiple sessions for i := 0; i < 5; i++ { - _, err := store.CreateSession(ctx, fmt.Sprintf("Session %d", i+1)) + _, err := store.CreateSession(ctx, "test-owner", fmt.Sprintf("Session %d", i+1)) if err != nil { t.Fatalf("CreateSession(%d) unexpected error: %v", i+1, err) } } // List all sessions - sessions, err := store.Sessions(ctx, 10, 0) + sessions, err := store.Sessions(ctx, "test-owner", 10, 0) if err != nil { t.Fatalf("Sessions(10, 0) unexpected error: %v", err) } @@ -121,7 +124,7 @@ func TestStore_ListSessions_Integration(t *testing.T) { } // Test pagination - first 3 - sessions, err = store.Sessions(ctx, 3, 0) + sessions, err = store.Sessions(ctx, "test-owner", 3, 0) if err != nil { t.Fatalf("Sessions(3, 0) unexpected error: %v", err) } @@ -130,7 +133,7 @@ func TestStore_ListSessions_Integration(t *testing.T) { } // Test pagination - next 2 - sessions, err = store.Sessions(ctx, 3, 3) + sessions, err = store.Sessions(ctx, "test-owner", 3, 3) if err != nil { t.Fatalf("Sessions(3, 3) unexpected error: %v", err) } @@ -145,7 +148,7 @@ func TestStore_DeleteSession_Integration(t *testing.T) { ctx := context.Background() // Create a session - session, err := store.CreateSession(ctx, "To Be Deleted") + session, err := store.CreateSession(ctx, "test-owner", "To Be Deleted") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -175,7 +178,7 @@ func TestStore_AddMessage(t *testing.T) { ctx := context.Background() // Create a session - session, err := store.CreateSession(ctx, "Message Test") + session, err := store.CreateSession(ctx, "test-owner", "Message Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -236,7 +239,7 @@ func TestStore_GetMessages_Integration(t *testing.T) { ctx := context.Background() // Create a session - session, err := store.CreateSession(ctx, "Pagination Test") + session, err := store.CreateSession(ctx, "test-owner", "Pagination Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -288,7 +291,7 @@ func TestStore_MessageOrdering(t *testing.T) { ctx := context.Background() // Create a session - session, err := store.CreateSession(ctx, "Ordering Test") + session, err := store.CreateSession(ctx, "test-owner", "Ordering Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -353,7 +356,7 @@ func TestStore_LargeMessageContent(t *testing.T) { ctx := context.Background() // Create a session - session, err := store.CreateSession(ctx, "Large Content Test") + session, err := store.CreateSession(ctx, "test-owner", "Large Content Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -392,7 +395,7 @@ func TestStore_DeleteSessionWithMessages(t *testing.T) { ctx := context.Background() // Create a session with messages - session, err := store.CreateSession(ctx, "Cascade Delete Test") + session, err := store.CreateSession(ctx, "test-owner", "Cascade Delete Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -447,7 +450,7 @@ func TestStore_ConcurrentSessionCreation(t *testing.T) { go func(id int) { defer wg.Done() title := fmt.Sprintf("Race-Session-%d", id) - session, err := store.CreateSession(ctx, title) + session, err := store.CreateSession(ctx, "test-owner", title) if err != nil { errs <- fmt.Errorf("goroutine %d: %w", id, err) return @@ -489,7 +492,7 @@ func TestStore_ConcurrentHistoryUpdate(t *testing.T) { ctx := context.Background() // Create a test session - session, err := store.CreateSession(ctx, "Race-History-Test") + session, err := store.CreateSession(ctx, "test-owner", "Race-History-Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -565,7 +568,7 @@ func TestStore_ConcurrentLoadAndSaveHistory(t *testing.T) { ctx := context.Background() // Create a test session with initial messages - session, err := store.CreateSession(ctx, "Race-LoadSave-Test") + session, err := store.CreateSession(ctx, "test-owner", "Race-LoadSave-Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -652,7 +655,7 @@ func TestStore_ConcurrentSessionDeletion(t *testing.T) { // Create test sessions for i := 0; i < numSessions; i++ { - session, err := store.CreateSession(ctx, fmt.Sprintf("Race-Delete-Test-%d", i)) + session, err := store.CreateSession(ctx, "test-owner", fmt.Sprintf("Race-Delete-Test-%d", i)) if err != nil { t.Fatalf("CreateSession() for session %d unexpected error: %v", i, err) } @@ -686,7 +689,7 @@ func TestStore_ConcurrentSessionDeletion(t *testing.T) { // List goroutine go func() { defer wg.Done() - _, _ = store.Sessions(ctx, 100, 0) + _, _ = store.Sessions(ctx, "test-owner", 100, 0) }() // Get goroutine @@ -699,7 +702,7 @@ func TestStore_ConcurrentSessionDeletion(t *testing.T) { wg.Wait() // Verify all sessions are deleted - remaining, err := store.Sessions(ctx, 100, 0) + remaining, err := store.Sessions(ctx, "test-owner", 100, 0) if err != nil { t.Fatalf("Sessions(100, 0) after concurrent deletion unexpected error: %v", err) } @@ -728,7 +731,7 @@ func TestStore_RaceDetector(t *testing.T) { ctx := context.Background() // Create a shared session - session, err := store.CreateSession(ctx, "Race-Detector-Test") + session, err := store.CreateSession(ctx, "test-owner", "Race-Detector-Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -783,7 +786,7 @@ func TestStore_ConcurrentWrites(t *testing.T) { numSessions := 5 sessions := make([]*Session, numSessions) for i := 0; i < numSessions; i++ { - session, err := store.CreateSession(ctx, fmt.Sprintf("Concurrent Session %d", i+1)) + session, err := store.CreateSession(ctx, "test-owner", fmt.Sprintf("Concurrent Session %d", i+1)) if err != nil { t.Fatalf("CreateSession() for session %d unexpected error: %v", i+1, err) } @@ -842,14 +845,14 @@ func TestStore_SQLInjectionPrevention(t *testing.T) { ctx := context.Background() // First, create a legitimate session - legitSession, err := store.CreateSession(ctx, "Legitimate Session") + legitSession, err := store.CreateSession(ctx, "test-owner", "Legitimate Session") if err != nil { t.Fatalf("CreateSession() for legitimate session unexpected error: %v", err) } t.Logf("Created legitimate session: %s", legitSession.ID) // Count sessions before attacks - sessions, err := store.Sessions(ctx, 100, 0) + sessions, err := store.Sessions(ctx, "test-owner", 100, 0) if err != nil { t.Fatalf("Sessions(100, 0) before attacks unexpected error: %v", err) } @@ -889,7 +892,7 @@ func TestStore_SQLInjectionPrevention(t *testing.T) { for _, tc := range maliciousTitles { t.Run("title_"+tc.name, func(t *testing.T) { // Attempt SQL injection via session title - session, err := store.CreateSession(ctx, tc.title) + session, err := store.CreateSession(ctx, "test-owner", tc.title) // Should either succeed (with escaped title) or fail safely if err != nil { @@ -906,7 +909,7 @@ func TestStore_SQLInjectionPrevention(t *testing.T) { // Verify database integrity t.Run("verify database integrity", func(t *testing.T) { // Sessions table should still exist - sessions, err := store.Sessions(ctx, 100, 0) + sessions, err := store.Sessions(ctx, "test-owner", 100, 0) if err != nil { t.Fatalf("Sessions(100, 0) after attacks unexpected error: %v (sessions table should still exist)", err) } @@ -940,7 +943,7 @@ func TestStore_SQLInjectionViaSessionID(t *testing.T) { ctx := context.Background() // Create a test session - session, err := store.CreateSession(ctx, "Test Session") + session, err := store.CreateSession(ctx, "test-owner", "Test Session") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } @@ -1029,7 +1032,7 @@ func TestStore_SQLInjectionViaMessageContent(t *testing.T) { ctx := context.Background() // Create a test session - session, err := store.CreateSession(ctx, "Message Test") + session, err := store.CreateSession(ctx, "test-owner", "Message Test") if err != nil { t.Fatalf("CreateSession() unexpected error: %v", err) } diff --git a/internal/session/session.go b/internal/session/session.go index 4975c54..9c68285 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -14,6 +14,7 @@ var ErrNotFound = errors.New("session not found") // Session represents a conversation session (application-level type). type Session struct { ID uuid.UUID + OwnerID string Title string CreatedAt time.Time UpdatedAt time.Time diff --git a/internal/session/store.go b/internal/session/store.go index edb13e2..2425e30 100644 --- a/internal/session/store.go +++ b/internal/session/store.go @@ -48,58 +48,63 @@ func New(queries *sqlc.Queries, pool *pgxpool.Pool, logger *slog.Logger) *Store } } -// CreateSession creates a new conversation session. +// CreateSession creates a new conversation session owned by the given user. // // Parameters: // - ctx: Context for the operation +// - ownerID: User identity that owns this session // - title: Session title (empty string = no title) // // Returns: // - *Session: Created session with generated UUID // - error: If creation fails -func (s *Store) CreateSession(ctx context.Context, title string) (*Session, error) { +func (s *Store) CreateSession(ctx context.Context, ownerID, title string) (*Session, error) { var titlePtr *string if title != "" { titlePtr = &title } - sqlcSession, err := s.queries.CreateSession(ctx, titlePtr) + sqlcSession, err := s.queries.CreateSession(ctx, sqlc.CreateSessionParams{ + Title: titlePtr, + OwnerID: ownerID, + }) if err != nil { return nil, fmt.Errorf("creating session: %w", err) } session := s.sqlcSessionToSession(sqlcSession) - s.logger.Debug("created session", "id", session.ID, "title", session.Title) + s.logger.Debug("created session", "id", session.ID, "owner", ownerID, "title", session.Title) return session, nil } // Session retrieves a session by ID. // Returns ErrNotFound if the session does not exist. func (s *Store) Session(ctx context.Context, sessionID uuid.UUID) (*Session, error) { - sqlcSession, err := s.queries.Session(ctx, sessionID) + row, err := s.queries.Session(ctx, sessionID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - // Return sentinel error directly (no wrapping per reviewer guidance) return nil, ErrNotFound } return nil, fmt.Errorf("getting session %s: %w", sessionID, err) } - return s.sqlcSessionToSession(sqlcSession), nil + return s.sqlcSessionRowToSession(row), nil } -// Sessions lists sessions with pagination, ordered by updated_at descending. +// Sessions lists sessions owned by the given user, ordered by updated_at descending. // // Parameters: // - ctx: Context for the operation +// - ownerID: User identity to filter by // - limit: Maximum number of sessions to return // - offset: Number of sessions to skip (for pagination) // // Returns: // - []*Session: List of sessions // - error: If listing fails -func (s *Store) Sessions(ctx context.Context, limit, offset int32) ([]*Session, error) { - sqlcSessions, err := s.queries.Sessions(ctx, sqlc.SessionsParams{ +func (s *Store) Sessions(ctx context.Context, ownerID string, limit, offset int32) ([]*Session, error) { + rows, err := s.queries.Sessions(ctx, sqlc.SessionsParams{ + OwnerID: ownerID, ResultLimit: limit, ResultOffset: offset, }) @@ -107,12 +112,12 @@ func (s *Store) Sessions(ctx context.Context, limit, offset int32) ([]*Session, return nil, fmt.Errorf("listing sessions: %w", err) } - sessions := make([]*Session, 0, len(sqlcSessions)) - for i := range sqlcSessions { - sessions = append(sessions, s.sqlcSessionToSession(sqlcSessions[i])) + sessions := make([]*Session, 0, len(rows)) + for i := range rows { + sessions = append(sessions, s.sqlcSessionsRowToSession(rows[i])) } - s.logger.Debug("listed sessions", "count", len(sessions), "limit", limit, "offset", offset) + s.logger.Debug("listed sessions", "owner", ownerID, "count", len(sessions), "limit", limit, "offset", offset) return sessions, nil } @@ -307,6 +312,15 @@ func normalizeRole(role string) string { return role } +// denormalizeRole converts database roles back to Genkit roles. +// Reverses normalizeRole: "assistant" → "model" so Gemini API accepts the history. +func denormalizeRole(role string) string { + if role == "assistant" { + return "model" + } + return role +} + // AppendMessages appends new messages to a session. // This is the preferred method for saving conversation history. // @@ -359,12 +373,13 @@ func (s *Store) History(ctx context.Context, sessionID uuid.UUID) ([]*ai.Message return nil, fmt.Errorf("loading history: %w", err) } - // Convert to ai.Message + // Convert to ai.Message with reverse role normalization. + // DB stores "assistant" but Gemini API requires "model". aiMessages := make([]*ai.Message, len(messages)) for i, msg := range messages { aiMessages[i] = &ai.Message{ Content: msg.Content, - Role: ai.Role(msg.Role), + Role: ai.Role(denormalizeRole(msg.Role)), } } @@ -394,7 +409,7 @@ func (s *Store) ResolveCurrentSession(ctx context.Context) (uuid.UUID, error) { } } - newSess, err := s.CreateSession(ctx, "") + newSess, err := s.CreateSession(ctx, "cli", "") if err != nil { return uuid.Nil, fmt.Errorf("creating session: %w", err) } @@ -408,18 +423,45 @@ func (s *Store) ResolveCurrentSession(ctx context.Context) (uuid.UUID, error) { return newSess.ID, nil } -// sqlcSessionToSession converts sqlc.Session to Session (application type). +// sqlcSessionToSession converts sqlc.Session (from CreateSession RETURNING *) to Session. func (*Store) sqlcSessionToSession(ss sqlc.Session) *Session { session := &Session{ ID: ss.ID, + OwnerID: ss.OwnerID, CreatedAt: ss.CreatedAt.Time, UpdatedAt: ss.UpdatedAt.Time, } - if ss.Title != nil { session.Title = *ss.Title } + return session +} +// sqlcSessionRowToSession converts sqlc.SessionRow (from Session query) to Session. +func (*Store) sqlcSessionRowToSession(row sqlc.SessionRow) *Session { + session := &Session{ + ID: row.ID, + OwnerID: row.OwnerID, + CreatedAt: row.CreatedAt.Time, + UpdatedAt: row.UpdatedAt.Time, + } + if row.Title != nil { + session.Title = *row.Title + } + return session +} + +// sqlcSessionsRowToSession converts sqlc.SessionsRow (from Sessions query) to Session. +func (*Store) sqlcSessionsRowToSession(row sqlc.SessionsRow) *Session { + session := &Session{ + ID: row.ID, + OwnerID: row.OwnerID, + CreatedAt: row.CreatedAt.Time, + UpdatedAt: row.UpdatedAt.Time, + } + if row.Title != nil { + session.Title = *row.Title + } return session } diff --git a/internal/sqlc/models.go b/internal/sqlc/models.go index 4ec4daf..6f8e0be 100644 --- a/internal/sqlc/models.go +++ b/internal/sqlc/models.go @@ -32,4 +32,5 @@ type Session struct { Title *string `json:"title"` CreatedAt pgtype.Timestamptz `json:"created_at"` UpdatedAt pgtype.Timestamptz `json:"updated_at"` + OwnerID string `json:"owner_id"` } diff --git a/internal/sqlc/sessions.sql.go b/internal/sqlc/sessions.sql.go index df5e238..befca87 100644 --- a/internal/sqlc/sessions.sql.go +++ b/internal/sqlc/sessions.sql.go @@ -9,6 +9,7 @@ import ( "context" "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" ) const addMessage = `-- name: AddMessage :exec @@ -36,21 +37,27 @@ func (q *Queries) AddMessage(ctx context.Context, arg AddMessageParams) error { const createSession = `-- name: CreateSession :one -INSERT INTO sessions (title) -VALUES ($1) -RETURNING id, title, created_at, updated_at +INSERT INTO sessions (title, owner_id) +VALUES ($1, $2) +RETURNING id, title, created_at, updated_at, owner_id ` +type CreateSessionParams struct { + Title *string `json:"title"` + OwnerID string `json:"owner_id"` +} + // Sessions and messages queries for sqlc // Generated code will be in internal/sqlc/sessions.sql.go -func (q *Queries) CreateSession(ctx context.Context, title *string) (Session, error) { - row := q.db.QueryRow(ctx, createSession, title) +func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) { + row := q.db.QueryRow(ctx, createSession, arg.Title, arg.OwnerID) var i Session err := row.Scan( &i.ID, &i.Title, &i.CreatedAt, &i.UpdatedAt, + &i.OwnerID, ) return i, err } @@ -65,13 +72,24 @@ func (q *Queries) DeleteSession(ctx context.Context, id uuid.UUID) error { return err } +const lockSession = `-- name: LockSession :one +SELECT id FROM sessions WHERE id = $1 FOR UPDATE +` + +// Locks the session row to prevent concurrent modifications +func (q *Queries) LockSession(ctx context.Context, id uuid.UUID) (uuid.UUID, error) { + row := q.db.QueryRow(ctx, lockSession, id) + err := row.Scan(&id) + return id, err +} + const maxSequenceNumber = `-- name: MaxSequenceNumber :one SELECT COALESCE(MAX(sequence_number), 0)::integer AS max_seq FROM messages WHERE session_id = $1 ` -// MaxSequenceNumber returns the max sequence number for a session (returns 0 if no messages). +// Get max sequence number for a session (returns 0 if no messages) func (q *Queries) MaxSequenceNumber(ctx context.Context, sessionID uuid.UUID) (int32, error) { row := q.db.QueryRow(ctx, maxSequenceNumber, sessionID) var max_seq int32 @@ -79,17 +97,6 @@ func (q *Queries) MaxSequenceNumber(ctx context.Context, sessionID uuid.UUID) (i return max_seq, err } -const lockSession = `-- name: LockSession :one -SELECT id FROM sessions WHERE id = $1 FOR UPDATE -` - -// Locks the session row to prevent concurrent modifications -func (q *Queries) LockSession(ctx context.Context, id uuid.UUID) (uuid.UUID, error) { - row := q.db.QueryRow(ctx, lockSession, id) - err := row.Scan(&id) - return id, err -} - const messages = `-- name: Messages :many SELECT id, session_id, role, content, sequence_number, created_at FROM messages @@ -134,17 +141,60 @@ func (q *Queries) Messages(ctx context.Context, arg MessagesParams) ([]Message, } const session = `-- name: Session :one -SELECT id, title, created_at, updated_at +SELECT id, title, owner_id, created_at, updated_at FROM sessions WHERE id = $1 ` -func (q *Queries) Session(ctx context.Context, id uuid.UUID) (Session, error) { +type SessionRow struct { + ID uuid.UUID `json:"id"` + Title *string `json:"title"` + OwnerID string `json:"owner_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +func (q *Queries) Session(ctx context.Context, id uuid.UUID) (SessionRow, error) { row := q.db.QueryRow(ctx, session, id) - var i Session + var i SessionRow + err := row.Scan( + &i.ID, + &i.Title, + &i.OwnerID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const sessionByIDAndOwner = `-- name: SessionByIDAndOwner :one +SELECT id, title, owner_id, created_at, updated_at +FROM sessions +WHERE id = $1 AND owner_id = $2 +` + +type SessionByIDAndOwnerParams struct { + SessionID uuid.UUID `json:"session_id"` + OwnerID string `json:"owner_id"` +} + +type SessionByIDAndOwnerRow struct { + ID uuid.UUID `json:"id"` + Title *string `json:"title"` + OwnerID string `json:"owner_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +// Verify session exists and is owned by the given user. +// Used for ownership checks without a separate query + comparison. +func (q *Queries) SessionByIDAndOwner(ctx context.Context, arg SessionByIDAndOwnerParams) (SessionByIDAndOwnerRow, error) { + row := q.db.QueryRow(ctx, sessionByIDAndOwner, arg.SessionID, arg.OwnerID) + var i SessionByIDAndOwnerRow err := row.Scan( &i.ID, &i.Title, + &i.OwnerID, &i.CreatedAt, &i.UpdatedAt, ) @@ -152,30 +202,41 @@ func (q *Queries) Session(ctx context.Context, id uuid.UUID) (Session, error) { } const sessions = `-- name: Sessions :many -SELECT id, title, created_at, updated_at +SELECT id, title, owner_id, created_at, updated_at FROM sessions +WHERE owner_id = $1 ORDER BY updated_at DESC -LIMIT $2 -OFFSET $1 +LIMIT $3 +OFFSET $2 ` type SessionsParams struct { - ResultOffset int32 `json:"result_offset"` - ResultLimit int32 `json:"result_limit"` + OwnerID string `json:"owner_id"` + ResultOffset int32 `json:"result_offset"` + ResultLimit int32 `json:"result_limit"` +} + +type SessionsRow struct { + ID uuid.UUID `json:"id"` + Title *string `json:"title"` + OwnerID string `json:"owner_id"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` } -func (q *Queries) Sessions(ctx context.Context, arg SessionsParams) ([]Session, error) { - rows, err := q.db.Query(ctx, sessions, arg.ResultOffset, arg.ResultLimit) +func (q *Queries) Sessions(ctx context.Context, arg SessionsParams) ([]SessionsRow, error) { + rows, err := q.db.Query(ctx, sessions, arg.OwnerID, arg.ResultOffset, arg.ResultLimit) if err != nil { return nil, err } defer rows.Close() - items := []Session{} + items := []SessionsRow{} for rows.Next() { - var i Session + var i SessionsRow if err := rows.Scan( &i.ID, &i.Title, + &i.OwnerID, &i.CreatedAt, &i.UpdatedAt, ); err != nil { diff --git a/internal/tools/network.go b/internal/tools/network.go index 32f3c64..8849a20 100644 --- a/internal/tools/network.go +++ b/internal/tools/network.go @@ -98,9 +98,11 @@ func NewNetwork(cfg NetConfig, logger *slog.Logger) (*Network, error) { return &Network{ searchBaseURL: strings.TrimSuffix(cfg.SearchBaseURL, "/"), + // searchClient uses default transport: searchBaseURL is admin-configured + // infrastructure (like a database URL), not user-controlled input. + // SSRF protection applies to web_fetch (user/LLM-controlled URLs), not here. searchClient: &http.Client{ - Timeout: 30 * time.Second, - Transport: urlValidator.SafeTransport(), + Timeout: 30 * time.Second, }, fetchParallelism: cfg.FetchParallelism, fetchDelay: cfg.FetchDelay, diff --git a/internal/tools/setup_test.go b/internal/tools/setup_test.go index 491c48f..368708f 100644 --- a/internal/tools/setup_test.go +++ b/internal/tools/setup_test.go @@ -20,6 +20,5 @@ func newNetworkForTesting(tb testing.TB, cfg NetConfig, logger *slog.Logger) *Ne tb.Fatalf("NewNetwork() unexpected error: %v", err) } nt.skipSSRFCheck = true - nt.searchClient.Transport = nil // allow localhost in tests return nt } diff --git a/internal/tui/integration_test.go b/internal/tui/integration_test.go index 5a80ea9..db4663f 100644 --- a/internal/tui/integration_test.go +++ b/internal/tui/integration_test.go @@ -28,7 +28,7 @@ func TestMain(m *testing.M) { // createTestSession creates a session in the database and returns its ID and cleanup function. func createTestSession(t *testing.T, setup *chatFlowSetup) (uuid.UUID, func()) { t.Helper() - sess, err := setup.SessionStore.CreateSession(setup.Ctx, "test-session") + sess, err := setup.SessionStore.CreateSession(setup.Ctx, "test-user", "test-session") if err != nil { t.Fatalf("CreateSession() error: %v", err) }