18 changes: 9 additions & 9 deletions db/migrations/000004_trigram_search.up.sql
@@ -1,17 +1,17 @@
-- Enable pg_trgm extension for trigram-based text search (CJK support).
CREATE EXTENSION IF NOT EXISTS pg_trgm;

-- CONCURRENTLY avoids locking the table during index creation.
-- NOTE: CONCURRENTLY cannot run inside a transaction block.
-- golang-migrate runs each file in a transaction by default;
-- the operator must run this migration manually with:
-- psql -f 000004_trigram_search.up.sql
-- or disable transactions in the migration tool.

-- GIN trigram index on messages.text_content for ILIKE fallback search.
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_messages_text_content_trgm
-- NOTE: For large production tables with existing data, create these indexes
-- manually with CONCURRENTLY before running this migration (they will be
-- no-ops due to IF NOT EXISTS):
-- CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_messages_text_content_trgm
-- ON messages USING gin (text_content gin_trgm_ops);
-- CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_memories_content_trgm
-- ON memories USING gin (content gin_trgm_ops);
CREATE INDEX IF NOT EXISTS idx_messages_text_content_trgm
ON messages USING gin (text_content gin_trgm_ops);

-- GIN trigram index on memories.content for similarity() scoring.
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_memories_content_trgm
CREATE INDEX IF NOT EXISTS idx_memories_content_trgm
ON memories USING gin (content gin_trgm_ops);
Comment on lines +12 to 17

Risk of Table Locks and Downtime on Large Tables

Creating GIN indexes without the CONCURRENTLY keyword (lines 12-17) blocks writes to the messages and memories tables for the duration of each index build (reads are still allowed), which can cause downtime in production environments with large datasets. Although the comments warn about this, the migration itself does not enforce safe index creation. To mitigate this risk, consider moving index creation into a separate migration that uses CONCURRENTLY and runs outside a transaction, or ensure that this migration is not run against large production tables. Example:

CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_messages_text_content_trgm
    ON messages USING gin (text_content gin_trgm_ops);

Recommendation:

  • For production environments, manually create these indexes with CONCURRENTLY before running this migration, or modify the migration to use CONCURRENTLY if your migration framework supports it.
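
One way a migration runner could act on this recommendation is to detect statements that use CONCURRENTLY and execute them outside a transaction. The sketch below is a hypothetical helper, not part of golang-migrate or this repository; it uses a naive keyword scan and assumes SQL comments have already been stripped from the statement.

```go
package main

import (
	"fmt"
	"strings"
)

// requiresNoTransaction reports whether stmt contains "INDEX CONCURRENTLY".
// PostgreSQL rejects CREATE INDEX CONCURRENTLY inside a transaction block,
// so a runner would have to execute such a statement on its own.
// Hypothetical helper: it does not parse SQL and does not skip comments.
func requiresNoTransaction(stmt string) bool {
	fields := strings.Fields(strings.ToUpper(stmt))
	for i := 1; i < len(fields); i++ {
		if fields[i] == "CONCURRENTLY" && fields[i-1] == "INDEX" {
			return true
		}
	}
	return false
}

func main() {
	stmts := []string{
		"CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_messages_text_content_trgm ON messages USING gin (text_content gin_trgm_ops);",
		"CREATE INDEX IF NOT EXISTS idx_memories_content_trgm ON memories USING gin (content gin_trgm_ops);",
	}
	for _, s := range stmts {
		fmt.Printf("%t\t%s\n", requiresNoTransaction(s), s)
	}
}
```

A runner built on this idea would send the first statement through a plain connection and could keep wrapping the second in a transaction.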

2 changes: 2 additions & 0 deletions db/migrations/000005_owner_id_check.down.sql
@@ -0,0 +1,2 @@
ALTER TABLE sessions ALTER COLUMN owner_id SET DEFAULT '';

Data Integrity Risk: Setting the default value of owner_id to an empty string ('') may lead to ambiguous or invalid data, as empty strings are not valid identifiers. This can cause issues in application logic and downstream processes.

Recommendation: Consider using NULL as the default if the absence of an owner is valid, or ensure that a valid identifier is always provided.

ALTER TABLE sessions DROP CONSTRAINT IF EXISTS sessions_owner_id_not_empty;
7 changes: 7 additions & 0 deletions db/migrations/000005_owner_id_check.up.sql
@@ -0,0 +1,7 @@
-- Backfill empty owner_id before adding constraint.
-- Sessions created before ownership was enforced may have owner_id = ''.
UPDATE sessions SET owner_id = 'anonymous' WHERE owner_id = '';

Potential Data Integrity Issue:
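The UPDATE sessions SET owner_id = 'anonymous' WHERE owner_id = ''; statement only backfills empty strings, not NULL values. If the column is nullable and any rows have owner_id IS NULL, they will not be updated. They will also pass the new CHECK constraint, because a CHECK accepts a row whenever its expression evaluates to NULL, so those rows silently remain without an owner.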

Recommended Solution:
Update both empty and NULL values:

UPDATE sessions SET owner_id = 'anonymous' WHERE owner_id IS NULL OR owner_id = '';


-- Prevent empty owner_id in sessions (enforces NOT NULL + non-empty).
ALTER TABLE sessions ADD CONSTRAINT sessions_owner_id_not_empty CHECK (owner_id != '');

Constraint Does Not Prevent NULL Values:
The CHECK constraint CHECK (owner_id != '') only prevents empty strings, not NULLs. If the column is not already defined as NOT NULL, NULL values can still be inserted, bypassing the constraint.

Recommended Solution:
Ensure the column is NOT NULL or update the constraint to:

ALTER TABLE sessions ADD CONSTRAINT sessions_owner_id_not_empty CHECK (owner_id IS NOT NULL AND owner_id != '');
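
The difference between the two constraints comes from SQL's three-valued logic: a CHECK rejects a row only when its expression is false, and a NULL result is accepted. The sketch below models that rule in Go as an illustration only; nil stands in for SQL NULL, and both function names are hypothetical.

```go
package main

import "fmt"

// checkPasses models CHECK (owner_id != '').
// NULL != '' evaluates to NULL, and a CHECK accepts NULL results,
// so a nil owner passes.
func checkPasses(ownerID *string) bool {
	if ownerID == nil {
		return true
	}
	return *ownerID != ""
}

// strictCheckPasses models CHECK (owner_id IS NOT NULL AND owner_id != '').
// For a NULL owner the expression is false AND NULL, which is false.
func strictCheckPasses(ownerID *string) bool {
	return ownerID != nil && *ownerID != ""
}

func main() {
	empty, anon := "", "anonymous"
	for _, c := range []struct {
		name  string
		owner *string
	}{{"NULL", nil}, {"empty", &empty}, {"anonymous", &anon}} {
		fmt.Printf("%-10s original=%t strict=%t\n",
			c.name, checkPasses(c.owner), strictCheckPasses(c.owner))
	}
}
```

Only the NULL row is treated differently by the two versions, which is the case the comment above is concerned with.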

ALTER TABLE sessions ALTER COLUMN owner_id DROP DEFAULT;
1 change: 1 addition & 0 deletions internal/api/chat.go
@@ -14,6 +14,7 @@ import (
"time"

"github.com/google/uuid"

"github.com/koopa0/koopa/internal/chat"
"github.com/koopa0/koopa/internal/session"
"github.com/koopa0/koopa/internal/tools"
17 changes: 15 additions & 2 deletions internal/api/doc.go
@@ -4,7 +4,7 @@
//
// The API server uses Go 1.22+ routing with a layered middleware stack:
//
// Recovery → Logging → RateLimit → CORS → Session → CSRF → Routes
// Recovery → RequestID → Logging → CORS → RateLimit → User → Session → CSRF → Routes
//
// Health probes (/health, /ready) bypass the middleware stack via a
// top-level mux, ensuring they remain fast and unauthenticated.
@@ -24,11 +24,24 @@
// - GET /api/v1/sessions/{id} — get session by ID
// - GET /api/v1/sessions/{id}/messages — get session messages
// - DELETE /api/v1/sessions/{id} — delete session
// - GET /api/v1/sessions/{id}/export — export session
//
// Chat (ownership-enforced):
// - POST /api/v1/chat — initiate chat, returns stream URL
// - GET /api/v1/chat/stream — SSE endpoint for streaming responses
//
// Search:
// - GET /api/v1/search — full-text search across messages
//
// Stats:
// - GET /api/v1/stats — usage statistics (sessions, messages, memories)
//
// Memory (ownership-enforced):
// - GET /api/v1/memories — list memories
// - POST /api/v1/memories — create memory
// - DELETE /api/v1/memories/{id} — delete memory
// - GET /api/v1/memories/search — search memories
//
// # CSRF Token Model
//
// Two token types prevent cross-site request forgery:
@@ -39,7 +52,7 @@
// - Session-bound tokens ("timestamp:signature"): bound to a specific
// session via HMAC-SHA256, verified with constant-time comparison.
//
// Both expire after 24 hours with 5 minutes of clock skew tolerance.
// Both expire after 1 hour with 5 minutes of clock skew tolerance.
//
// # Session Ownership
//
2 changes: 1 addition & 1 deletion internal/api/search.go
@@ -45,7 +45,7 @@ func (h *searchHandler) searchMessages(w http.ResponseWriter, r *http.Request) {

results, total, err := h.store.SearchMessages(r.Context(), userID, query, limit, offset)
if err != nil {
h.logger.Error("searching messages", "error", err, "user_id", userID, "query", query)
h.logger.Error("searching messages", "error", err, "user_id", userID, "query_len", len(query))
WriteError(w, http.StatusInternalServerError, "search_failed", "failed to search messages", h.logger)
return
}
Comment on lines 47 to 51

Error Handling Granularity:
The error handling block does not distinguish between user input errors (e.g., malformed queries) and internal server errors. All errors are treated as internal failures, which may result in less informative responses for the client and complicate debugging.

Recommendation:
Consider inspecting the error returned by h.store.SearchMessages and returning more specific error codes/messages for user errors versus internal errors. For example:

if errors.Is(err, session.ErrInvalidQuery) {
    WriteError(w, http.StatusBadRequest, "invalid_query", "query is malformed", h.logger)
    return
}

This improves maintainability and user experience.
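
A minimal sketch of that split, assuming the store exposes a sentinel for bad input. The names errInvalidQuery, errStoreDown and statusForSearchError are hypothetical; the diff does not show whether session.ErrInvalidQuery exists.

```go
package main

import (
	"errors"
	"fmt"
	"net/http"
)

// Hypothetical sentinels standing in for errors the store might return.
var (
	errInvalidQuery = errors.New("invalid query")
	errStoreDown    = errors.New("store unavailable")
)

// statusForSearchError maps a search error to an HTTP status and an error code.
// Client mistakes become 400; anything unrecognized stays a 500.
func statusForSearchError(err error) (status int, code string) {
	switch {
	case err == nil:
		return http.StatusOK, ""
	case errors.Is(err, errInvalidQuery):
		return http.StatusBadRequest, "invalid_query"
	default:
		return http.StatusInternalServerError, "search_failed"
	}
}

func main() {
	wrapped := fmt.Errorf("searching messages: %w", errInvalidQuery)
	for _, err := range []error{nil, wrapped, errStoreDown} {
		status, code := statusForSearchError(err)
		fmt.Println(status, code)
	}
}
```

Because the mapping uses errors.Is, it still recognizes the sentinel after the store wraps it with extra context.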

Comment on lines 46 to 51

Performance and Context Handling:
The search operation (h.store.SearchMessages) may be long-running. The handler passes r.Context() through, but it sets no timeout, and nothing in this diff shows whether the store honors cancellation. If the client disconnects while the store ignores the context, the search continues unnecessarily and consumes server resources.

Recommendation:
Confirm that SearchMessages forwards the context to the underlying database call so the query is cancelled when the client disconnects, and consider adding a per-request timeout so a slow search cannot run unbounded.
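
As an illustration of the pattern, the sketch below bounds a search with context.WithTimeout and has the search return ctx.Err() once the deadline passes. slowSearch is a hypothetical stand-in for the store call, and the durations are arbitrary.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// slowSearch is a hypothetical stand-in for a store query that takes `work`
// to finish. It stops early with ctx.Err() when the context is done.
func slowSearch(ctx context.Context, work time.Duration) error {
	select {
	case <-time.After(work):
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// searchWithTimeout derives a deadline from the parent context
// (r.Context() in a handler) and runs the search under it.
func searchWithTimeout(parent context.Context, timeout, work time.Duration) error {
	ctx, cancel := context.WithTimeout(parent, timeout)
	defer cancel()
	return slowSearch(ctx, work)
}

// timedOut reports whether err was caused by an expired deadline.
func timedOut(err error) bool {
	return errors.Is(err, context.DeadlineExceeded)
}

// runDemo runs a search that needs workMs under a timeout of timeoutMs.
func runDemo(timeoutMs, workMs int) error {
	return searchWithTimeout(context.Background(),
		time.Duration(timeoutMs)*time.Millisecond,
		time.Duration(workMs)*time.Millisecond)
}

func main() {
	fmt.Println("fast search:", runDemo(200, 1))
	fmt.Println("slow search timed out:", timedOut(runDemo(10, 2000)))
}
```

In a real store the select would be replaced by a context-aware database call, which returns a context error on its own when the request is cancelled.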

54 changes: 24 additions & 30 deletions internal/api/session.go
@@ -16,23 +16,17 @@ import (
"time"

"github.com/google/uuid"

"github.com/koopa0/koopa/internal/session"
)

// Sentinel errors for session/CSRF operations.
var (
// ErrSessionCookieNotFound is returned when the session cookie is absent from the request.
ErrSessionCookieNotFound = errors.New("session cookie not found")
// ErrSessionInvalid is returned when the session cookie value is not a valid UUID.
ErrSessionInvalid = errors.New("session ID invalid")
// ErrCSRFRequired is returned when a state-changing request has no CSRF token.
ErrCSRFRequired = errors.New("csrf token required")
// ErrCSRFInvalid is returned when the CSRF token signature does not match.
ErrCSRFInvalid = errors.New("csrf token invalid")
// ErrCSRFExpired is returned when the CSRF token timestamp exceeds csrfTokenTTL.
ErrCSRFExpired = errors.New("csrf token expired")
// ErrCSRFMalformed is returned when the CSRF token format cannot be parsed.
ErrCSRFMalformed = errors.New("csrf token malformed")
errSessionCookieNotFound = errors.New("session cookie not found")
errSessionInvalid = errors.New("session ID invalid")
errCSRFRequired = errors.New("csrf token required")
errCSRFInvalid = errors.New("csrf token invalid")
errCSRFExpired = errors.New("csrf token expired")
errCSRFMalformed = errors.New("csrf token malformed")
)

// Pre-session CSRF token prefix to distinguish from user-bound tokens.
@@ -61,12 +55,12 @@ type sessionManager struct {
func (*sessionManager) SessionID(r *http.Request) (uuid.UUID, error) {
cookie, err := r.Cookie(sessionCookieName)
if err != nil {
return uuid.Nil, ErrSessionCookieNotFound
return uuid.Nil, errSessionCookieNotFound
}

sessionID, err := uuid.Parse(cookie.Value)
if err != nil {
return uuid.Nil, ErrSessionInvalid
return uuid.Nil, errSessionInvalid
}

return sessionID, nil
@@ -109,17 +103,17 @@ func (sm *sessionManager) NewCSRFToken(userID string) string {
// CheckCSRF verifies a user-bound CSRF token.
func (sm *sessionManager) CheckCSRF(userID, token string) error {
if token == "" {
return ErrCSRFRequired
return errCSRFRequired
}

parts := strings.SplitN(token, ":", 2)
if len(parts) != 2 {
return ErrCSRFMalformed
return errCSRFMalformed
}

timestamp, err := strconv.ParseInt(parts[0], 10, 64)
if err != nil {
return ErrCSRFMalformed
return errCSRFMalformed
}

// SECURITY: Compute and verify HMAC BEFORE timestamp checks to prevent
@@ -133,19 +127,19 @@ func (sm *sessionManager) CheckCSRF(userID, token string) error {

actualSig, err := base64.URLEncoding.DecodeString(parts[1])
if err != nil {
return ErrCSRFMalformed
return errCSRFMalformed
}

if subtle.ConstantTimeCompare(actualSig, expectedSig) != 1 {
return ErrCSRFInvalid
return errCSRFInvalid
}

age := time.Since(time.Unix(timestamp, 0))
if age > csrfTokenTTL {
return ErrCSRFExpired
return errCSRFExpired
}
if age < -csrfClockSkew {
return ErrCSRFInvalid
return errCSRFInvalid
}
Comment on lines 141 to 143

Ambiguous handling of future-dated CSRF tokens

The check if age < -csrfClockSkew { return errCSRFInvalid } treats a token whose timestamp lies more than csrfClockSkew in the future as 'invalid', the same error returned for a signature mismatch. Callers and logs therefore cannot tell a bad signature from a clock problem. Consider returning errCSRFMalformed or errCSRFExpired, or a dedicated error, to make the semantics explicit and improve maintainability.

Recommended solution:

if age < -csrfClockSkew {
    return errCSRFMalformed // or errCSRFExpired, depending on intended semantics
}
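
One option along these lines is to move the age check into a small helper with its own error for future-dated tokens. This is a sketch only: errCSRFFuture and checkTokenAge are hypothetical names, and the constant values are taken from the doc.go text in this PR (1 hour TTL, 5 minutes of skew) rather than from the actual constants, which the diff does not show.

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

var (
	errCSRFExpired = errors.New("csrf token expired")
	// errCSRFFuture is a hypothetical dedicated error for future-dated tokens.
	errCSRFFuture = errors.New("csrf token issued in the future")
)

// Assumed values, based on the package documentation in this PR.
const (
	csrfTokenTTL  = time.Hour
	csrfClockSkew = 5 * time.Minute
)

// checkTokenAge classifies a token by its age. A negative age means the
// token's timestamp is ahead of the verifier's clock.
func checkTokenAge(age time.Duration) error {
	switch {
	case age > csrfTokenTTL:
		return errCSRFExpired
	case age < -csrfClockSkew:
		return errCSRFFuture
	default:
		return nil
	}
}

func main() {
	for _, age := range []time.Duration{30 * time.Minute, 2 * time.Hour, -2 * time.Minute, -10 * time.Minute} {
		fmt.Printf("%v: %v\n", age, checkTokenAge(age))
	}
}
```

Both CheckCSRF and CheckPreSessionCSRF could call the same helper, which keeps the two code paths from drifting apart.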


return nil
@@ -168,23 +162,23 @@ func (sm *sessionManager) NewPreSessionCSRFToken() string {
// CheckPreSessionCSRF verifies a pre-session CSRF token.
func (sm *sessionManager) CheckPreSessionCSRF(token string) error {
if token == "" {
return ErrCSRFRequired
return errCSRFRequired
}

if !strings.HasPrefix(token, preSessionPrefix) {
return ErrCSRFMalformed
return errCSRFMalformed
}

tokenBody := strings.TrimPrefix(token, preSessionPrefix)
parts := strings.SplitN(tokenBody, ":", 3)
if len(parts) != 3 {
return ErrCSRFMalformed
return errCSRFMalformed
}

nonce := parts[0]
timestamp, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return ErrCSRFMalformed
return errCSRFMalformed
}

// SECURITY: Compute and verify HMAC BEFORE timestamp checks to prevent
@@ -196,19 +190,19 @@ func (sm *sessionManager) CheckPreSessionCSRF(token string) error {

actualSig, err := base64.URLEncoding.DecodeString(parts[2])
if err != nil {
return ErrCSRFMalformed
return errCSRFMalformed
}

if subtle.ConstantTimeCompare(actualSig, expectedSig) != 1 {
return ErrCSRFInvalid
return errCSRFInvalid
}

age := time.Since(time.Unix(timestamp, 0))
if age > csrfTokenTTL {
return ErrCSRFExpired
return errCSRFExpired
}
if age < -csrfClockSkew {
return ErrCSRFInvalid
return errCSRFInvalid
}
Comment on lines 204 to 206

Ambiguous handling of future-dated pre-session CSRF tokens

The check if age < -csrfClockSkew { return errCSRFInvalid } in CheckPreSessionCSRF mirrors the ambiguity found in CheckCSRF: a future-dated token gets the same error as a signature mismatch. Whichever error is chosen for this case, whether errCSRFMalformed, errCSRFExpired, or a dedicated one, should be the same in both functions.

Recommended solution:

if age < -csrfClockSkew {
    return errCSRFMalformed // or errCSRFExpired, depending on intended semantics
}


return nil
2 changes: 1 addition & 1 deletion internal/chat/chat.go
@@ -103,7 +103,7 @@ func (cfg Config) validate() error {
return errors.New("at least one tool is required")
}
if cfg.MemoryStore != nil && cfg.WG == nil {
return errors.New("WG is required when MemoryStore is set")
return errors.New("wg is required when memory store is set")
}
Comment on lines 105 to 107

Potential Incomplete Validation for Memory-Related Fields

The current validation checks that WG is set when MemoryStore is enabled, but does not enforce that BackgroundCtx is also set. While the constructor defaults BackgroundCtx to context.Background() if not provided, this may not be the intended lifecycle context for memory extraction, potentially leading to premature cancellation or resource leaks. Consider enforcing that BackgroundCtx is non-nil when MemoryStore is set, or document clearly that the default is always safe.

Recommended solution:

if cfg.MemoryStore != nil && cfg.BackgroundCtx == nil {
    return errors.New("background context is required when memory store is set")
}

Alternatively, document the defaulting behavior explicitly in the configuration struct.
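
A sketch of the stricter variant, using a trimmed, hypothetical stand-in for the real Config (the actual struct has more fields, and MemoryStore is not a plain `any` there). validateDemo exists only to exercise the checks.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
)

// Config is a reduced stand-in for the chat Config in this diff.
type Config struct {
	MemoryStore   any
	WG            *sync.WaitGroup
	BackgroundCtx context.Context
}

// validate requires both WG and BackgroundCtx once MemoryStore is set.
func (cfg Config) validate() error {
	if cfg.MemoryStore == nil {
		return nil
	}
	if cfg.WG == nil {
		return errors.New("wg is required when memory store is set")
	}
	if cfg.BackgroundCtx == nil {
		return errors.New("background context is required when memory store is set")
	}
	return nil
}

// validateDemo builds a Config from three switches and validates it.
func validateDemo(withStore, withWG, withCtx bool) error {
	var cfg Config
	if withStore {
		cfg.MemoryStore = struct{}{}
	}
	if withWG {
		cfg.WG = &sync.WaitGroup{}
	}
	if withCtx {
		cfg.BackgroundCtx = context.Background()
	}
	return cfg.validate()
}

func main() {
	fmt.Println(validateDemo(true, true, true))
	fmt.Println(validateDemo(true, true, false))
	fmt.Println(validateDemo(false, false, false))
}
```

Rejecting a nil BackgroundCtx makes the caller choose the lifecycle explicitly; keeping the default and documenting it is the lighter alternative.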

return nil
}
21 changes: 15 additions & 6 deletions internal/config/config.go
@@ -24,6 +24,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"log/slog"
"os"
"path/filepath"
@@ -51,9 +52,6 @@ var (
// ErrInvalidEmbedderModel indicates the embedder model is invalid.
ErrInvalidEmbedderModel = errors.New("invalid embedder model")

// ErrInvalidEmbedderDimension indicates the embedder produces incompatible vector dimensions.
ErrInvalidEmbedderDimension = errors.New("incompatible embedder dimension")

// ErrInvalidPostgresHost indicates the PostgreSQL host is invalid.
ErrInvalidPostgresHost = errors.New("invalid PostgreSQL host")

@@ -158,6 +156,11 @@ type Config struct {
CORSOrigins []string `mapstructure:"cors_origins" json:"cors_origins"`
TrustProxy bool `mapstructure:"trust_proxy" json:"trust_proxy"` // Trust X-Real-IP/X-Forwarded-For headers (set true behind reverse proxy)
DevMode bool `mapstructure:"dev_mode" json:"dev_mode"` // Dev mode: disables Secure flag on cookies (decoupled from DB SSL)

// API keys read from environment at Load() time.
// Unexported to avoid accidental serialization; validated in validateProviderAPIKey().
geminiAPIKey string
openaiAPIKey string
}

// Load loads configuration.
@@ -213,6 +216,12 @@ func Load() (*Config, error) {
return nil, fmt.Errorf("parsing DATABASE_URL: %w", err)
}

// Capture API keys from environment at load time.
// These are read by Genkit plugins directly, but we also need them
// for fail-fast validation (no os.Getenv in business logic).
cfg.geminiAPIKey = os.Getenv("GEMINI_API_KEY")
cfg.openaiAPIKey = os.Getenv("OPENAI_API_KEY")

// CRITICAL: Validate immediately (fail-fast)
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validating configuration: %w", err)
@@ -280,11 +289,11 @@ func setDefaults(v *viper.Viper) {
// 2. DD_API_KEY - Datadog API key (optional, for observability)
// 3. HMAC_SECRET - HMAC secret for CSRF protection (serve mode only)
func bindEnvVariables(v *viper.Viper) {
// Helper to panic on unexpected bind errors (hardcoded strings can't fail)
// If this panics, it's a BUG in our code, not a runtime error
// Helper to fatal on unexpected bind errors (hardcoded strings can't fail)
// If this fails, it's a BUG in our code, not a runtime error
mustBind := func(key, envVar string) {
if err := v.BindEnv(key, envVar); err != nil {
panic(fmt.Sprintf("BUG: failed to bind %q to %q: %v", key, envVar, err))
log.Fatalf("BUG: failed to bind %q to %q: %v", key, envVar, err)
}
}

11 changes: 7 additions & 4 deletions internal/config/validation.go
@@ -3,8 +3,9 @@
import (
"fmt"
"log/slog"
"os"
"slices"

"github.com/koopa0/koopa/internal/rag"
)

// supportedProviders lists all valid AI provider values.
@@ -87,7 +88,7 @@
return fmt.Errorf("%w: postgres_password must be set in config.yaml",
ErrInvalidPostgresPassword)
}
if c.PostgresPassword == "koopa_dev_password" {

Check failure on line 91 in internal/config/validation.go (GitHub Actions / Lint):
string `koopa_dev_password` has 3 occurrences, make it a constant (goconst)
slog.Warn("Using default development password for PostgreSQL",
"warning", "Change postgres_password in config.yaml for production deployments")
}
@@ -127,7 +128,8 @@
}

// requiredVectorDimension must match the pgvector schema: embedding vector(768).
const requiredVectorDimension = 768
// Canonical source: rag.VectorDimension.
var requiredVectorDimension = int(rag.VectorDimension)

// validateEmbedder checks that the configured embedder model produces vectors
// compatible with the database schema. For known models whose native dimension
@@ -200,16 +202,17 @@
}

// validateProviderAPIKey checks that the required API key is set for the configured provider.
// API keys are captured from environment in Load() and stored as unexported fields.
func (c *Config) validateProviderAPIKey() error {
switch c.resolvedProvider() {
case ProviderGemini:
if os.Getenv("GEMINI_API_KEY") == "" {
if c.geminiAPIKey == "" {
return fmt.Errorf("%w: GEMINI_API_KEY environment variable is required for provider %q\n"+
"Get your API key at: https://ai.google.dev/gemini-api/docs/api-key",
ErrMissingAPIKey, c.resolvedProvider())
}
case ProviderOpenAI:
if os.Getenv("OPENAI_API_KEY") == "" {
if c.openaiAPIKey == "" {
return fmt.Errorf("%w: OPENAI_API_KEY environment variable is required for provider %q\n"+
"Get your API key at: https://platform.openai.com/api-keys",
ErrMissingAPIKey, c.resolvedProvider())
Comment on lines 202 to 218

Missing default case in provider API key validation

The validateProviderAPIKey function uses a switch statement to check for required API keys for each provider. However, there is no default case. If a new provider is added to supportedProviders but not handled in this function, the check will silently pass, potentially allowing misconfiguration or security issues.

Recommended solution:
Add a default case to the switch statement to return an error for unknown providers:

    default:
        return fmt.Errorf("%w: unknown provider %q", ErrInvalidProvider, c.resolvedProvider())
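
A self-contained sketch of the switch with a default branch. The provider strings, the keys struct, and the lowercase error names are assumptions for illustration; the real function reads the keys from unexported Config fields.

```go
package main

import (
	"errors"
	"fmt"
)

var (
	errMissingAPIKey   = errors.New("missing API key")
	errInvalidProvider = errors.New("invalid provider")
)

// keys is a hypothetical holder for the keys captured at load time.
type keys struct{ gemini, openai string }

// validateProviderAPIKey fails closed: a provider without an explicit case
// is rejected instead of silently passing.
func validateProviderAPIKey(provider string, k keys) error {
	switch provider {
	case "gemini":
		if k.gemini == "" {
			return fmt.Errorf("%w: GEMINI_API_KEY is required for provider %q", errMissingAPIKey, provider)
		}
	case "openai":
		if k.openai == "" {
			return fmt.Errorf("%w: OPENAI_API_KEY is required for provider %q", errMissingAPIKey, provider)
		}
	default:
		return fmt.Errorf("%w: unknown provider %q", errInvalidProvider, provider)
	}
	return nil
}

// isInvalidProvider reports whether err wraps errInvalidProvider.
func isInvalidProvider(err error) bool {
	return errors.Is(err, errInvalidProvider)
}

func main() {
	fmt.Println(validateProviderAPIKey("gemini", keys{gemini: "k"}))
	fmt.Println(validateProviderAPIKey("openai", keys{}))
	fmt.Println(validateProviderAPIKey("mystery", keys{}))
}
```

One consequence of failing closed is that a supported provider which needs no API key would require its own explicit, empty case, otherwise the default branch would reject it.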
