diff --git a/.env.example b/.env.example index 5ab7655..d23de21 100644 --- a/.env.example +++ b/.env.example @@ -4,7 +4,7 @@ # --- V-Mail app secrets and settings --- -# The environment the V-Mail app is running in. One of "development" or "production" +# The environment the V-Mail app is running in. One of "development", "test", and "production". VMAIL_ENV=production # The encryption key for AES-GCM encryption. 32-byte (256-bit) cryptographically secure random string, base64-encoded. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a55173e..6c67c47 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -33,11 +33,19 @@ This setup lets you run the Go backend and the React frontend locally for debugg The project includes several utility scripts in the `scripts/` directory. See [`scripts/README.md`](../scripts/README.md) for detailed documentation. -**Available scripts:** - -- **`check.sh`** - Runs all formatting, linting, and tests. Use `./scripts/check.sh` before committing new code and ensure all checks pass locally. -- **`roadmap-burndown.go`** - Analyzes git history of `ROADMAP.md` to generate a CSV burndown chart showing task completion over time. - +## Testing +`scripts/check.sh`uns all formatting, linting, and tests. +Always use `./scripts/check.sh` before committing new code and ensure all checks pass locally. + +More ideas to make it efficient: + +```bash +./scripts/check.sh # Run all checks (backend and frontend) +./scripts/check.sh --backend # Run only backend checks +./scripts/check.sh --frontend # Run only frontend checks +./scripts/check.sh --check # Run a specific check +./scripts/check.sh --help # Show help, including a list of available checks +``` ### Dev process diff --git a/ROADMAP.md b/ROADMAP.md index bfd514c..cce1f67 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -26,6 +26,15 @@ * Goal: Basic offline support. * Tasks: Implement IndexedDB caching for recently viewed emails. Build the sync logic. +## Milestone 2: Missing things + +### **2/7. 🔐 Authentication** + +- [ ] `ValidateToken` in `middleware.go` is currently a stub, and it always returns "test@example.com" without + actually validating the Authelia JWT token. This must be implemented before deploying to production. + The function should parse and validate the JWT token from Authelia, extract the user's email from the token claims, + and verify the token's signature and expiration. + ## Milestone 3: Actions - Goal: Be able to manage email. @@ -805,7 +814,7 @@ Done! 🎉 It works nicely. It's in `/backend/cmd/spike`. See `/backend/README.m * Create its handler function. This function should: * (For now) Assume auth is okay. * Check if a row exists in `user_settings` for this user. - * Return `{"isAuthenticated": true, "isSetupComplete": [true/false]}`. + * Return `{"isSetupComplete": [true/false]}`. * [x] **Create API: settings endpoints:** * In `/backend/internal/db`, create `user_settings.go`. Add `GetUserSettings(userID string)` and `SaveUserSettings(settings UserSettings)` functions. * Add the `GET /api/v1/settings` route and handler. It should call `GetUserSettings` and return the data (without passwords). diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 6f97480..186343a 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -42,21 +42,21 @@ func main() { } // NewServer creates and returns a new HTTP handler for the V-Mail API server. -func NewServer(cfg *config.Config, pool *pgxpool.Pool) http.Handler { +func NewServer(cfg *config.Config, dbPool *pgxpool.Pool) http.Handler { encryptor, err := crypto.NewEncryptor(cfg.EncryptionKeyBase64) if err != nil { log.Fatalf("Failed to create encryptor: %v", err) } - imapPool := imap.NewPool() - imapService := imap.NewService(pool, encryptor) + imapPool := imap.NewPoolWithMaxWorkers(cfg.IMAPMaxWorkers) + imapService := imap.NewService(dbPool, imapPool, encryptor) - authHandler := api.NewAuthHandler(pool) - settingsHandler := api.NewSettingsHandler(pool, encryptor) - foldersHandler := api.NewFoldersHandler(pool, encryptor, imapPool) - threadsHandler := api.NewThreadsHandler(pool, encryptor, imapService) - threadHandler := api.NewThreadHandler(pool, encryptor, imapService) - searchHandler := api.NewSearchHandler(pool, encryptor, imapService) + authHandler := api.NewAuthHandler(dbPool) + settingsHandler := api.NewSettingsHandler(dbPool, encryptor) + foldersHandler := api.NewFoldersHandler(dbPool, encryptor, imapPool) + threadsHandler := api.NewThreadsHandler(dbPool, encryptor, imapService) + threadHandler := api.NewThreadHandler(dbPool, encryptor, imapService) + searchHandler := api.NewSearchHandler(dbPool, encryptor, imapService) mux := http.NewServeMux() diff --git a/backend/cmd/test-server/main.go b/backend/cmd/test-server/main.go index 07d2452..d52258f 100644 --- a/backend/cmd/test-server/main.go +++ b/backend/cmd/test-server/main.go @@ -194,7 +194,10 @@ func setupTestUser(ctx context.Context, pool *pgxpool.Pool, cfg *config.Config, return fmt.Errorf("failed to create encryptor: %w", err) } - imapService := imap.NewService(pool, encryptor) + imapPool := imap.NewPool() + defer imapPool.Close() + + imapService := imap.NewService(pool, imapPool, encryptor) if err := imapService.SyncThreadsForFolder(ctx, userID, "INBOX"); err != nil { log.Printf("Warning: Failed to sync INBOX folder: %v", err) } else { @@ -205,8 +208,8 @@ func setupTestUser(ctx context.Context, pool *pgxpool.Pool, cfg *config.Config, } // startHTTPServer starts the HTTP server and waits for shutdown signals. -func startHTTPServer(cfg *config.Config, pool *pgxpool.Pool, imapServer *testutil.TestIMAPServer, smtpServer *testutil.TestSMTPServer) error { - server := NewServer(cfg, pool) +func startHTTPServer(cfg *config.Config, dbPool *pgxpool.Pool, imapServer *testutil.TestIMAPServer, smtpServer *testutil.TestSMTPServer) error { + server := NewServer(cfg, dbPool) address := ":" + cfg.Port log.Printf("V-Mail test server starting on %s", address) @@ -234,21 +237,21 @@ func startHTTPServer(cfg *config.Config, pool *pgxpool.Pool, imapServer *testuti } // NewServer creates and returns a new HTTP handler for the V-Mail API server. -func NewServer(cfg *config.Config, pool *pgxpool.Pool) http.Handler { +func NewServer(cfg *config.Config, dbPool *pgxpool.Pool) http.Handler { encryptor, err := crypto.NewEncryptor(cfg.EncryptionKeyBase64) if err != nil { log.Fatalf("Failed to create encryptor: %v", err) } - imapPool := imap.NewPool() - imapService := imap.NewService(pool, encryptor) - - authHandler := api.NewAuthHandler(pool) - settingsHandler := api.NewSettingsHandler(pool, encryptor) - foldersHandler := api.NewFoldersHandler(pool, encryptor, imapPool) - threadsHandler := api.NewThreadsHandler(pool, encryptor, imapService) - threadHandler := api.NewThreadHandler(pool, encryptor, imapService) - searchHandler := api.NewSearchHandler(pool, encryptor, imapService) + imapPool := imap.NewPoolWithMaxWorkers(cfg.IMAPMaxWorkers) + imapService := imap.NewService(dbPool, imapPool, encryptor) + + authHandler := api.NewAuthHandler(dbPool) + settingsHandler := api.NewSettingsHandler(dbPool, encryptor) + foldersHandler := api.NewFoldersHandler(dbPool, encryptor, imapPool) + threadsHandler := api.NewThreadsHandler(dbPool, encryptor, imapService) + threadHandler := api.NewThreadHandler(dbPool, encryptor, imapService) + searchHandler := api.NewSearchHandler(dbPool, encryptor, imapService) mux := http.NewServeMux() diff --git a/backend/internal/api/api_test_helpers.go b/backend/internal/api/api_test_helpers.go new file mode 100644 index 0000000..cb5dc8d --- /dev/null +++ b/backend/internal/api/api_test_helpers.go @@ -0,0 +1,69 @@ +package api + +import ( + "context" + "encoding/base64" + "net/http" + "net/http/httptest" + "testing" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/vdavid/vmail/backend/internal/auth" + "github.com/vdavid/vmail/backend/internal/crypto" + "github.com/vdavid/vmail/backend/internal/db" + "github.com/vdavid/vmail/backend/internal/models" +) + +// getTestEncryptor creates a test encryptor with a deterministic key for testing. +func getTestEncryptor(t *testing.T) *crypto.Encryptor { + t.Helper() + + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + base64Key := base64.StdEncoding.EncodeToString(key) + + encryptor, err := crypto.NewEncryptor(base64Key) + if err != nil { + t.Fatalf("Failed to create encryptor: %v", err) + } + return encryptor +} + +// setupTestUserAndSettings creates a test user and saves their settings. +// Returns the userID for use in tests. +func setupTestUserAndSettings(t *testing.T, pool *pgxpool.Pool, encryptor *crypto.Encryptor, email string) string { + t.Helper() + ctx := context.Background() + userID, err := db.GetOrCreateUser(ctx, pool, email) + if err != nil { + t.Fatalf("Failed to create user: %v", err) + } + + encryptedIMAPPassword, _ := encryptor.Encrypt("imap_pass") + encryptedSMTPPassword, _ := encryptor.Encrypt("smtp_pass") + + settings := &models.UserSettings{ + UserID: userID, + UndoSendDelaySeconds: 20, + PaginationThreadsPerPage: 100, + IMAPServerHostname: "imap.test.com", + IMAPUsername: "user", + EncryptedIMAPPassword: encryptedIMAPPassword, + SMTPServerHostname: "smtp.test.com", + SMTPUsername: "user", + EncryptedSMTPPassword: encryptedSMTPPassword, + } + if err := db.SaveUserSettings(ctx, pool, settings); err != nil { + t.Fatalf("Failed to save settings: %v", err) + } + return userID +} + +// createRequestWithUser creates an HTTP request with user email in context. +func createRequestWithUser(method, url, email string) *http.Request { + req := httptest.NewRequest(method, url, nil) + ctx := context.WithValue(req.Context(), auth.UserEmailKey, email) + return req.WithContext(ctx) +} diff --git a/backend/internal/api/auth_handler.go b/backend/internal/api/auth_handler.go index f56efd0..ca95bfb 100644 --- a/backend/internal/api/auth_handler.go +++ b/backend/internal/api/auth_handler.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" "log" "net/http" @@ -41,18 +40,16 @@ func (h *AuthHandler) GetAuthStatus(w http.ResponseWriter, r *http.Request) { } response := models.AuthStatusResponse{ - IsAuthenticated: true, // TODO: Check if user is authenticated IsSetupComplete: isSetupComplete, } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - log.Printf("AuthHandler: Failed to encode response: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) + if !WriteJSONResponse(w, response) { return } } +// checkSetupComplete determines if the user has completed onboarding by checking +// if user settings exist in the database. func (h *AuthHandler) checkSetupComplete(ctx context.Context, email string) (bool, error) { userID, err := db.GetOrCreateUser(ctx, h.pool, email) if err != nil { diff --git a/backend/internal/api/auth_handler_test.go b/backend/internal/api/auth_handler_test.go index 7c52431..1849ed9 100644 --- a/backend/internal/api/auth_handler_test.go +++ b/backend/internal/api/auth_handler_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/vdavid/vmail/backend/internal/auth" "github.com/vdavid/vmail/backend/internal/db" @@ -37,9 +38,6 @@ func TestAuthHandler_GetAuthStatus(t *testing.T) { t.Fatalf("Failed to decode response: %v", err) } - if !response.IsAuthenticated { - t.Error("Expected isAuthenticated to be true") - } if response.IsSetupComplete { t.Error("Expected isSetupComplete to be false for new user") } @@ -85,9 +83,6 @@ func TestAuthHandler_GetAuthStatus(t *testing.T) { t.Fatalf("Failed to decode response: %v", err) } - if !response.IsAuthenticated { - t.Error("Expected isAuthenticated to be true") - } if !response.IsSetupComplete { t.Error("Expected isSetupComplete to be true for user with settings") } @@ -103,4 +98,48 @@ func TestAuthHandler_GetAuthStatus(t *testing.T) { t.Errorf("Expected status 401, got %d", rr.Code) } }) + + t.Run("returns 500 when GetOrCreateUser returns an error", func(t *testing.T) { + req := httptest.NewRequest("GET", "/api/v1/auth/status", nil) + + // Use a cancelled context to simulate database connection failure + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + reqCtx := context.WithValue(cancelledCtx, auth.UserEmailKey, "test@example.com") + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.GetAuthStatus(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", rr.Code) + } + }) + + t.Run("returns 500 when UserSettingsExist returns an error", func(t *testing.T) { + email := "erroruser@example.com" + + // Create user first with valid context + ctx := context.Background() + _, err := db.GetOrCreateUser(ctx, pool, email) + if err != nil { + t.Fatalf("Failed to create user: %v", err) + } + + // Use a context with a deadline that's already passed to cause UserSettingsExist to fail + // Note: GetOrCreateUser might succeed due to ON CONFLICT, but UserSettingsExist will fail + deadlineCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) + defer cancel() + reqCtx := context.WithValue(deadlineCtx, auth.UserEmailKey, email) + + req := httptest.NewRequest("GET", "/api/v1/auth/status", nil) + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.GetAuthStatus(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", rr.Code) + } + }) } diff --git a/backend/internal/api/folders_handler.go b/backend/internal/api/folders_handler.go index 544bf8b..56d363d 100644 --- a/backend/internal/api/folders_handler.go +++ b/backend/internal/api/folders_handler.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" "errors" "log" "net/http" @@ -10,7 +9,6 @@ import ( "strings" "github.com/jackc/pgx/v5/pgxpool" - "github.com/vdavid/vmail/backend/internal/auth" "github.com/vdavid/vmail/backend/internal/crypto" "github.com/vdavid/vmail/backend/internal/db" "github.com/vdavid/vmail/backend/internal/imap" @@ -37,7 +35,7 @@ func NewFoldersHandler(pool *pgxpool.Pool, encryptor *crypto.Encryptor, imapPool func (h *FoldersHandler) GetFolders(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - userID, ok := h.getUserIDFromContext(ctx, w) + userID, ok := GetUserIDFromContext(ctx, w, h.pool) if !ok { return } @@ -47,17 +45,21 @@ func (h *FoldersHandler) GetFolders(w http.ResponseWriter, r *http.Request) { return } - client, ok := h.getIMAPClient(w, userID, settings, imapPassword) - if !ok { - return - } + // Use WithClient to ensure the client is always released + err := h.imapPool.WithClient(userID, settings.IMAPServerHostname, settings.IMAPUsername, imapPassword, func(client imap.IMAPClient) error { + folders, err := client.ListFolders() + if err != nil { + return h.handleListFoldersError(w, userID, err, settings, imapPassword) + } - folders, ok := h.listFoldersWithRetry(w, userID, client, settings, imapPassword) - if !ok { - return - } + h.writeFoldersResponse(w, folders) + return nil + }) - h.writeFoldersResponse(w, folders) + if err != nil { + // Error handling is done inside the callback, so if we get here it's a connection error + h.handleConnectionError(w, err) + } } // getUserSettingsAndPassword retrieves user settings and decrypts the IMAP password. @@ -83,39 +85,26 @@ func (h *FoldersHandler) getUserSettingsAndPassword(ctx context.Context, w http. return settings, imapPassword, true } -// getIMAPClient gets an IMAP client, handling connection errors. -func (h *FoldersHandler) getIMAPClient(w http.ResponseWriter, userID string, settings *models.UserSettings, imapPassword string) (imap.IMAPClient, bool) { - client, err := h.imapPool.GetClient(userID, settings.IMAPServerHostname, settings.IMAPUsername, imapPassword) - if err != nil { - log.Printf("FoldersHandler: Failed to get IMAP client: %v", err) - errMsg := err.Error() - if strings.Contains(errMsg, "i/o timeout") { - http.Error(w, "Connection to IMAP server timed out. Please double-check your server hostname in your Settings and try again.", http.StatusServiceUnavailable) - } else { - http.Error(w, "Failed to connect to IMAP server", http.StatusInternalServerError) - } - return nil, false - } - return client, true -} - -// listFoldersWithRetry lists folders with automatic retry on connection errors. -func (h *FoldersHandler) listFoldersWithRetry(w http.ResponseWriter, userID string, client imap.IMAPClient, settings *models.UserSettings, imapPassword string) ([]*models.Folder, bool) { - folders, err := client.ListFolders() - if err != nil { - return h.handleListFoldersError(w, userID, err, settings, imapPassword) +// handleConnectionError handles errors when getting a client from the pool. +func (h *FoldersHandler) handleConnectionError(w http.ResponseWriter, err error) { + log.Printf("FoldersHandler: Failed to get IMAP client: %v", err) + errMsg := err.Error() + if strings.Contains(errMsg, "i/o timeout") { + http.Error(w, "Connection to IMAP server timed out. Please double-check your server hostname in your Settings and try again.", http.StatusServiceUnavailable) + } else { + http.Error(w, "Failed to connect to IMAP server", http.StatusInternalServerError) } - return folders, true } // handleListFoldersError handles errors from ListFolders, including retry logic. -func (h *FoldersHandler) handleListFoldersError(w http.ResponseWriter, userID string, err error, settings *models.UserSettings, imapPassword string) ([]*models.Folder, bool) { +// Returns an error to propagate to the WithClient callback. +func (h *FoldersHandler) handleListFoldersError(w http.ResponseWriter, userID string, err error, settings *models.UserSettings, imapPassword string) error { log.Printf("FoldersHandler: Failed to list folders: %v", err) errMsg := err.Error() if strings.Contains(errMsg, "SPECIAL-USE") { http.Error(w, "Your IMAP server doesn't support the SPECIAL-USE extension (RFC 6154), which is required for V-Mail to identify folder types. Please contact your email provider or use a different IMAP server.", http.StatusBadRequest) - return nil, false + return err // Return error to stop processing } if h.isBrokenConnectionError(errMsg) { @@ -123,42 +112,43 @@ func (h *FoldersHandler) handleListFoldersError(w http.ResponseWriter, userID st } http.Error(w, "Failed to list folders", http.StatusInternalServerError) - return nil, false + return err // Return error to stop processing } -// isBrokenConnectionError checks if the error indicates a broken connection. +// isBrokenConnectionError checks if the error message indicates a broken connection +// that can be recovered by retrying with a fresh IMAP client. func (h *FoldersHandler) isBrokenConnectionError(errMsg string) bool { return strings.Contains(errMsg, "broken pipe") || strings.Contains(errMsg, "connection reset") || strings.Contains(errMsg, "EOF") } -// retryListFolders retries listing folders after removing the broken connection. -func (h *FoldersHandler) retryListFolders(w http.ResponseWriter, userID string, settings *models.UserSettings, imapPassword string) ([]*models.Folder, bool) { +// retryListFolders retries listing folders after removing the broken connection from the pool. +// This handles transient connection issues by getting a fresh IMAP client and retrying the operation. +// Returns an error to propagate to the WithClient callback. +func (h *FoldersHandler) retryListFolders(w http.ResponseWriter, userID string, settings *models.UserSettings, imapPassword string) error { h.imapPool.RemoveClient(userID) - client, retryErr := h.imapPool.GetClient(userID, settings.IMAPServerHostname, settings.IMAPUsername, imapPassword) - if retryErr != nil { - log.Printf("FoldersHandler: Failed to get IMAP client on retry: %v", retryErr) - http.Error(w, "Failed to connect to IMAP server", http.StatusInternalServerError) - return nil, false - } - - folders, err := client.ListFolders() - if err != nil { - log.Printf("FoldersHandler: Failed to list folders on retry: %v", err) - if strings.Contains(err.Error(), "SPECIAL-USE") { - http.Error(w, "Your IMAP server doesn't support the SPECIAL-USE extension (RFC 6154), which is required for V-Mail to identify folder types. Please contact your email provider or use a different IMAP server.", http.StatusBadRequest) - return nil, false + // Use WithClient for the retry to ensure release happens + return h.imapPool.WithClient(userID, settings.IMAPServerHostname, settings.IMAPUsername, imapPassword, func(client imap.IMAPClient) error { + folders, err := client.ListFolders() + if err != nil { + log.Printf("FoldersHandler: Failed to list folders on retry: %v", err) + if strings.Contains(err.Error(), "SPECIAL-USE") { + http.Error(w, "Your IMAP server doesn't support the SPECIAL-USE extension (RFC 6154), which is required for V-Mail to identify folder types. Please contact your email provider or use a different IMAP server.", http.StatusBadRequest) + return err + } + http.Error(w, "Failed to list folders", http.StatusInternalServerError) + return err } - http.Error(w, "Failed to list folders", http.StatusInternalServerError) - return nil, false - } - return folders, true + h.writeFoldersResponse(w, folders) + return nil + }) } // writeFoldersResponse writes the folders response as JSON. +// Uses a buffered approach to prevent partial writes if JSON encoding fails. func (h *FoldersHandler) writeFoldersResponse(w http.ResponseWriter, folders []*models.Folder) { sortFoldersByRole(folders) @@ -167,29 +157,9 @@ func (h *FoldersHandler) writeFoldersResponse(w http.ResponseWriter, folders []* folderValues[i] = *f } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(folderValues); err != nil { - log.Printf("FoldersHandler: Failed to encode response: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - } -} - -func (h *FoldersHandler) getUserIDFromContext(ctx context.Context, w http.ResponseWriter) (string, bool) { - email, ok := auth.GetUserEmailFromContext(ctx) - if !ok { - log.Println("FoldersHandler: No user email in context") - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return "", false - } - - userID, err := db.GetOrCreateUser(ctx, h.pool, email) - if err != nil { - log.Printf("FoldersHandler: Failed to get/create user: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return "", false + if !WriteJSONResponse(w, folderValues) { + return } - - return userID, true } // sortFoldersByRole sorts folders by role priority, then alphabetically for "other" folders. diff --git a/backend/internal/api/folders_handler_test.go b/backend/internal/api/folders_handler_test.go index 607a4f6..a297d5a 100644 --- a/backend/internal/api/folders_handler_test.go +++ b/backend/internal/api/folders_handler_test.go @@ -12,6 +12,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/vdavid/vmail/backend/internal/auth" "github.com/vdavid/vmail/backend/internal/crypto" + "github.com/vdavid/vmail/backend/internal/db" "github.com/vdavid/vmail/backend/internal/imap" "github.com/vdavid/vmail/backend/internal/models" "github.com/vdavid/vmail/backend/internal/testutil" @@ -55,6 +56,24 @@ func TestFoldersHandler_GetFolders(t *testing.T) { // Note: Testing the actual IMAP connection would require a real IMAP server // or a mock. For now, we test the error handling paths. // Integration tests would test the full IMAP connection flow. + + t.Run("returns 500 when GetOrCreateUser returns an error", func(t *testing.T) { + email := "dberror@example.com" + + // Use a cancelled context to simulate database connection failure + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + req := httptest.NewRequest("GET", "/api/v1/folders", nil) + reqCtx := context.WithValue(cancelledCtx, auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.GetFolders(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", rr.Code) + } + }) } // mockIMAPClient is a mock implementation of IMAPClient for testing @@ -83,7 +102,7 @@ type mockIMAPPool struct { retryClientErr error } -func (m *mockIMAPPool) GetClient(userID, server, username, password string) (imap.IMAPClient, error) { +func (m *mockIMAPPool) WithClient(userID, server, username, password string, fn func(imap.IMAPClient) error) error { m.getClientCalled = true m.getClientCallCount++ m.getClientUserID = userID @@ -92,11 +111,21 @@ func (m *mockIMAPPool) GetClient(userID, server, username, password string) (ima m.getClientPass = password // If this is a retry (second call) and we have a retry client configured, use it + var client imap.IMAPClient + var err error if m.getClientCallCount > 1 && m.retryClient != nil { - return m.retryClient, m.retryClientErr + client = m.retryClient + err = m.retryClientErr + } else { + client = m.getClientResult + err = m.getClientErr + } + + if err != nil { + return err } - return m.getClientResult, m.getClientErr + return fn(client) } func (m *mockIMAPPool) RemoveClient(userID string) { @@ -396,4 +425,203 @@ func TestFoldersHandler_WithMocks(t *testing.T) { t.Errorf("Expected error message to mention SPECIAL-USE, got: %s", body) } }) + + t.Run("returns 500 when decrypting IMAP password fails", func(t *testing.T) { + email := "decrypt-error@example.com" + ctx := context.Background() + + // Create user + userID, err := db.GetOrCreateUser(ctx, pool, email) + if err != nil { + t.Fatalf("Failed to create user: %v", err) + } + + // Create settings with corrupted encrypted password (invalid encrypted data) + corruptedPassword := []byte("not-valid-encrypted-data") + encryptedSMTPPassword, _ := encryptor.Encrypt("smtp_pass") + + settings := &models.UserSettings{ + UserID: userID, + UndoSendDelaySeconds: 20, + PaginationThreadsPerPage: 100, + IMAPServerHostname: "imap.test.com", + IMAPUsername: "user", + EncryptedIMAPPassword: corruptedPassword, + SMTPServerHostname: "smtp.test.com", + SMTPUsername: "user", + EncryptedSMTPPassword: encryptedSMTPPassword, + } + if err := db.SaveUserSettings(ctx, pool, settings); err != nil { + t.Fatalf("Failed to save settings: %v", err) + } + + handler := NewFoldersHandler(pool, encryptor, imap.NewPool()) + rr := callGetFolders(t, handler, email) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", rr.Code) + } + }) + + t.Run("handles timeout error with StatusServiceUnavailable", func(t *testing.T) { + email := "timeout-test@example.com" + setupTestUserAndSettings(t, pool, encryptor, email) + + mockPool := &mockIMAPPool{ + getClientResult: nil, + getClientErr: fmt.Errorf("dial tcp 192.168.1.1:993: i/o timeout"), + } + + handler := NewFoldersHandler(pool, encryptor, mockPool) + rr := callGetFolders(t, handler, email) + + if rr.Code != http.StatusServiceUnavailable { + t.Errorf("Expected status 503, got %d", rr.Code) + } + + // Verify error message + body := rr.Body.String() + if !strings.Contains(body, "timed out") { + t.Errorf("Expected error message to mention timeout, got: %s", body) + } + if !strings.Contains(body, "hostname") { + t.Errorf("Expected error message to mention hostname, got: %s", body) + } + }) +} + +// failingResponseWriter is a ResponseWriter that fails on Write to test error handling. +type failingResponseWriter struct { + http.ResponseWriter + writeShouldFail bool +} + +func (f *failingResponseWriter) Write(p []byte) (int, error) { + if f.writeShouldFail { + return 0, fmt.Errorf("write failed") + } + return f.ResponseWriter.Write(p) +} + +func TestFoldersHandler_WriteResponseErrors(t *testing.T) { + pool := testutil.NewTestDB(t) + defer pool.Close() + + encryptor := getTestEncryptor(t) + email := "write-error@example.com" + setupTestUserAndSettings(t, pool, encryptor, email) + + t.Run("handles write failure gracefully", func(t *testing.T) { + mockClient := &mockIMAPClient{ + listFoldersResult: []*models.Folder{ + {Name: "INBOX", Role: "inbox"}, + }, + listFoldersErr: nil, + } + + mockPool := &mockIMAPPool{ + getClientResult: mockClient, + getClientErr: nil, + } + + handler := NewFoldersHandler(pool, encryptor, mockPool) + + req := httptest.NewRequest("GET", "/api/v1/folders", nil) + reqCtx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + // Create a ResponseWriter that fails on Write + rr := httptest.NewRecorder() + failingWriter := &failingResponseWriter{ + ResponseWriter: rr, + writeShouldFail: true, + } + + handler.GetFolders(failingWriter, req) + + // The handler should handle the write error gracefully (it logs but doesn't crash) + // We can't easily test the error path without checking logs, but we verify it doesn't panic + }) +} + +func TestSortFoldersByRole(t *testing.T) { + tests := []struct { + name string + folders []*models.Folder + expected []string // Expected folder names in order + }{ + { + name: "sorts by role priority", + folders: []*models.Folder{ + {Name: "Archive", Role: "archive"}, + {Name: "INBOX", Role: "inbox"}, + {Name: "Drafts", Role: "drafts"}, + {Name: "Sent", Role: "sent"}, + }, + expected: []string{"INBOX", "Sent", "Drafts", "Archive"}, + }, + { + name: "sorts alphabetically within same role", + folders: []*models.Folder{ + {Name: "Zebra", Role: "other"}, + {Name: "Alpha", Role: "other"}, + {Name: "Beta", Role: "other"}, + }, + expected: []string{"Alpha", "Beta", "Zebra"}, + }, + { + name: "sorts by role then alphabetically", + folders: []*models.Folder{ + {Name: "Zebra", Role: "other"}, + {Name: "INBOX", Role: "inbox"}, + {Name: "Alpha", Role: "other"}, + {Name: "Sent", Role: "sent"}, + {Name: "Beta", Role: "other"}, + }, + expected: []string{"INBOX", "Sent", "Alpha", "Beta", "Zebra"}, + }, + { + name: "handles all role types", + folders: []*models.Folder{ + {Name: "Trash", Role: "trash"}, + {Name: "Spam", Role: "spam"}, + {Name: "INBOX", Role: "inbox"}, + {Name: "Sent", Role: "sent"}, + {Name: "Drafts", Role: "drafts"}, + {Name: "Archive", Role: "archive"}, + }, + expected: []string{"INBOX", "Sent", "Drafts", "Spam", "Trash", "Archive"}, + }, + { + name: "handles empty list", + folders: []*models.Folder{}, + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Make a copy to avoid modifying the original + folders := make([]*models.Folder, len(tt.folders)) + for i, f := range tt.folders { + folders[i] = &models.Folder{ + Name: f.Name, + Role: f.Role, + } + } + + sortFoldersByRole(folders) + + if len(folders) != len(tt.expected) { + t.Errorf("Expected %d folders, got %d", len(tt.expected), len(folders)) + return + } + + for i, expectedName := range tt.expected { + if folders[i].Name != expectedName { + t.Errorf("Expected folder at index %d to be '%s', got '%s'", i, expectedName, folders[i].Name) + } + } + }) + } } diff --git a/backend/internal/api/helpers.go b/backend/internal/api/helpers.go new file mode 100644 index 0000000..00c9109 --- /dev/null +++ b/backend/internal/api/helpers.go @@ -0,0 +1,93 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "log" + "net/http" + "strconv" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/vdavid/vmail/backend/internal/auth" + "github.com/vdavid/vmail/backend/internal/db" +) + +// GetUserIDFromContext extracts the user's email from context, resolves/creates the DB user, +// and writes appropriate HTTP errors when it fails. Returns (userID, true) on success. +// This is a shared helper function used across multiple handlers to ensure consistent +// error handling for user authentication and user ID resolution. +func GetUserIDFromContext(ctx context.Context, w http.ResponseWriter, pool *pgxpool.Pool) (string, bool) { + email, ok := auth.GetUserEmailFromContext(ctx) + if !ok { + log.Println("API: No user email in context") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return "", false + } + + userID, err := db.GetOrCreateUser(ctx, pool, email) + if err != nil { + log.Printf("API: Failed to get/create user: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return "", false + } + + return userID, true +} + +// ParsePaginationParams parses page and limit from query parameters. +// Returns default values (page=1, limit=defaultLimit) if parameters are missing or invalid. +// This is a shared helper function used by multiple handlers for consistent pagination parsing. +func ParsePaginationParams(r *http.Request, defaultLimit int) (page, limit int) { + page = 1 + limit = defaultLimit + + if pageStr := r.URL.Query().Get("page"); pageStr != "" { + if parsed, err := strconv.Atoi(pageStr); err == nil && parsed > 0 { + page = parsed + } + } + + if limitStr := r.URL.Query().Get("limit"); limitStr != "" { + if parsed, err := strconv.Atoi(limitStr); err == nil && parsed > 0 { + limit = parsed + } + } + + return page, limit +} + +// WriteJSONResponse writes a JSON response using a buffered approach to prevent partial writes. +// If encoding fails, it writes an error response and returns false. Otherwise returns true. +// This ensures atomic responses and consistent error handling across all handlers. +func WriteJSONResponse(w http.ResponseWriter, data interface{}) bool { + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(data); err != nil { + log.Printf("API: Failed to encode JSON response: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return false + } + + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write(buf.Bytes()); err != nil { + log.Printf("API: Failed to write JSON response: %v", err) + } + return true +} + +// GetPaginationLimit gets the pagination limit, using user settings if available. +// If limitFromQuery is provided (> 0), it takes precedence. +// Otherwise, it uses the user's setting from the database, or defaults to 100. +// This is a shared helper function used by multiple handlers for consistent pagination limit handling. +func GetPaginationLimit(ctx context.Context, pool *pgxpool.Pool, userID string, limitFromQuery int) int { + if limitFromQuery > 0 { + return limitFromQuery + } + + settings, err := db.GetUserSettings(ctx, pool, userID) + if err == nil { + return settings.PaginationThreadsPerPage + } + + return 100 +} diff --git a/backend/internal/api/search_handler.go b/backend/internal/api/search_handler.go index d8667e2..6490baf 100644 --- a/backend/internal/api/search_handler.go +++ b/backend/internal/api/search_handler.go @@ -2,22 +2,19 @@ package api import ( "context" - "encoding/json" + "errors" "log" "net/http" - "strings" "github.com/jackc/pgx/v5/pgxpool" - "github.com/vdavid/vmail/backend/internal/auth" "github.com/vdavid/vmail/backend/internal/crypto" - "github.com/vdavid/vmail/backend/internal/db" "github.com/vdavid/vmail/backend/internal/imap" ) // SearchHandler handles search-related API requests. type SearchHandler struct { pool *pgxpool.Pool - encryptor *crypto.Encryptor + encryptor *crypto.Encryptor // Not used directly, but required by imapService imapService imap.IMAPService } @@ -34,7 +31,7 @@ func NewSearchHandler(pool *pgxpool.Pool, encryptor *crypto.Encryptor, imapServi func (h *SearchHandler) Search(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - userID, ok := h.getUserIDFromContext(ctx, w) + userID, ok := GetUserIDFromContext(ctx, w, h.pool) if !ok { return } @@ -44,14 +41,19 @@ func (h *SearchHandler) Search(w http.ResponseWriter, r *http.Request) { // Empty query means return all emails // Get pagination params - page, limitFromQuery := parsePaginationParams(r, 100) - limit := h.getPaginationLimit(ctx, userID, limitFromQuery) + page, limitFromQuery := ParsePaginationParams(r, 100) + limit := GetPaginationLimit(ctx, h.pool, userID, limitFromQuery) // Call IMAP service search threads, totalCount, err := h.imapService.Search(ctx, userID, query, page, limit) if err != nil { + // Treat client cancellations as non-errors + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } + // Check if it's a query parsing error (should return 400) - if strings.Contains(err.Error(), "invalid search query") { + if errors.Is(err, imap.ErrInvalidSearchQuery) { log.Printf("SearchHandler: Invalid query: %v", err) http.Error(w, err.Error(), http.StatusBadRequest) return @@ -62,42 +64,8 @@ func (h *SearchHandler) Search(w http.ResponseWriter, r *http.Request) { } // Build and send the response - response := buildPaginationResponse(threads, totalCount, page, limit) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - log.Printf("SearchHandler: Failed to encode response: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) + response := BuildPaginationResponse(threads, totalCount, page, limit) + if !WriteJSONResponse(w, response) { return } } - -func (h *SearchHandler) getUserIDFromContext(ctx context.Context, w http.ResponseWriter) (string, bool) { - email, ok := auth.GetUserEmailFromContext(ctx) - if !ok { - log.Println("SearchHandler: No user email in context") - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return "", false - } - - userID, err := db.GetOrCreateUser(ctx, h.pool, email) - if err != nil { - log.Printf("SearchHandler: Failed to get/create user: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return "", false - } - - return userID, true -} - -func (h *SearchHandler) getPaginationLimit(ctx context.Context, userID string, limitFromQuery int) int { - if limitFromQuery > 0 { - return limitFromQuery - } - - settings, err := db.GetUserSettings(ctx, h.pool, userID) - if err == nil { - return settings.PaginationThreadsPerPage - } - - return 100 -} diff --git a/backend/internal/api/search_handler_test.go b/backend/internal/api/search_handler_test.go index 48bb195..127c9b0 100644 --- a/backend/internal/api/search_handler_test.go +++ b/backend/internal/api/search_handler_test.go @@ -3,10 +3,12 @@ package api import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "testing" + "github.com/vdavid/vmail/backend/internal/auth" "github.com/vdavid/vmail/backend/internal/imap" "github.com/vdavid/vmail/backend/internal/models" "github.com/vdavid/vmail/backend/internal/testutil" @@ -141,8 +143,8 @@ func TestSearchHandler_Search(t *testing.T) { email := "searchuser5@example.com" setupTestUserAndSettings(t, pool, encryptor, email) - // Mock IMAP service to return parser error - mockIMAP.searchErr = &imapError{message: "invalid search query: empty from: value"} + // Mock IMAP service to return parser error wrapped with ErrInvalidSearchQuery + mockIMAP.searchErr = fmt.Errorf("%w: empty from: value", imap.ErrInvalidSearchQuery) req := createRequestWithUser("GET", "/api/v1/search?q=from:", email) rr := httptest.NewRecorder() @@ -153,6 +155,135 @@ func TestSearchHandler_Search(t *testing.T) { t.Errorf("Expected status 400, got %d", rr.Code) } }) + + t.Run("falls back to default limit when GetUserSettings returns error", func(t *testing.T) { + email := "settings-error-search@example.com" + ctx := context.Background() + userID := setupTestUserAndSettings(t, pool, encryptor, email) + + // Delete the user settings to simulate GetUserSettings returning an error + // (it will return NotFound, which getPaginationLimit handles by using default) + if _, err := pool.Exec(ctx, "DELETE FROM user_settings WHERE user_id = $1", userID); err != nil { + t.Fatalf("Failed to delete user settings: %v", err) + } + + // Reset mock state + mockIMAP.searchErr = nil + mockIMAP.searchResult = []*models.Thread{} + mockIMAP.searchCount = 0 + + req := createRequestWithUser("GET", "/api/v1/search?q=test", email) + rr := httptest.NewRecorder() + + handler.Search(rr, req) + + // Should still return 200 OK, using default limit of 100 + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + + // Verify that the search was called with default limit (100) + // Since limitFromQuery is 0, it should use default + if mockIMAP.searchLimit != 100 { + t.Errorf("Expected default limit 100, got %d", mockIMAP.searchLimit) + } + }) + + t.Run("handles JSON encoding failure gracefully", func(t *testing.T) { + email := "json-error-search@example.com" + setupTestUserAndSettings(t, pool, encryptor, email) + + // Reset mock state + threads := []*models.Thread{ + { + ID: "thread-1", + StableThreadID: "stable-1", + Subject: "Test Thread", + UserID: "user-1", + }, + } + mockIMAP.searchResult = threads + mockIMAP.searchCount = 1 + mockIMAP.searchErr = nil + + req := createRequestWithUser("GET", "/api/v1/search?q=test", email) + + // Create a ResponseWriter that fails on Write + rr := httptest.NewRecorder() + failingWriter := &failingResponseWriterSearch{ + ResponseWriter: rr, + writeShouldFail: true, + } + + handler.Search(failingWriter, req) + + // The handler should handle the write error gracefully (it logs but doesn't crash) + // The status code should still be set (200) even if Write fails + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + }) + + t.Run("handles invalid pagination parameters gracefully", func(t *testing.T) { + email := "pagination-invalid-search@example.com" + setupTestUserAndSettings(t, pool, encryptor, email) + + testCases := []struct { + name string + query string + expectedPage int + expectedLimit int + }{ + {"page=0 uses default", "q=test&page=0&limit=50", 1, 50}, + {"page=-1 uses default", "q=test&page=-1&limit=50", 1, 50}, + {"limit=0 uses default", "q=test&page=1&limit=0", 1, 100}, + {"limit=-1 uses default", "q=test&page=1&limit=-1", 1, 100}, + {"both invalid", "q=test&page=0&limit=0", 1, 100}, + {"non-numeric page", "q=test&page=abc&limit=50", 1, 50}, + {"non-numeric limit", "q=test&page=1&limit=xyz", 1, 100}, + {"very large limit", "q=test&page=1&limit=999999", 1, 999999}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Reset mock state for each test case + mockIMAP.searchErr = nil + mockIMAP.searchResult = []*models.Thread{} + mockIMAP.searchCount = 0 + + req := httptest.NewRequest("GET", "/api/v1/search?"+tc.query, nil) + reqCtx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.Search(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + + if mockIMAP.searchPage != tc.expectedPage { + t.Errorf("Expected page %d, got %d", tc.expectedPage, mockIMAP.searchPage) + } + if mockIMAP.searchLimit != tc.expectedLimit { + t.Errorf("Expected limit %d, got %d", tc.expectedLimit, mockIMAP.searchLimit) + } + }) + } + }) +} + +// failingResponseWriterSearch is a ResponseWriter that fails on Write to test error handling. +type failingResponseWriterSearch struct { + http.ResponseWriter + writeShouldFail bool +} + +func (f *failingResponseWriterSearch) Write(p []byte) (int, error) { + if f.writeShouldFail { + return 0, fmt.Errorf("write failed") + } + return f.ResponseWriter.Write(p) } // mockIMAPServiceForSearch is a mock implementation of IMAPService for search tests diff --git a/backend/internal/api/settings_handler.go b/backend/internal/api/settings_handler.go index 3eb4071..02a24ff 100644 --- a/backend/internal/api/settings_handler.go +++ b/backend/internal/api/settings_handler.go @@ -1,14 +1,12 @@ package api import ( - "context" "encoding/json" "errors" "log" "net/http" "github.com/jackc/pgx/v5/pgxpool" - "github.com/vdavid/vmail/backend/internal/auth" "github.com/vdavid/vmail/backend/internal/crypto" "github.com/vdavid/vmail/backend/internal/db" "github.com/vdavid/vmail/backend/internal/models" @@ -32,7 +30,7 @@ func NewSettingsHandler(pool *pgxpool.Pool, encryptor *crypto.Encryptor) *Settin func (h *SettingsHandler) GetSettings(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - userID, ok := h.getUserIDFromContext(ctx, w) + userID, ok := GetUserIDFromContext(ctx, w, h.pool) if !ok { return } @@ -59,10 +57,7 @@ func (h *SettingsHandler) GetSettings(w http.ResponseWriter, r *http.Request) { SMTPPasswordSet: len(settings.EncryptedSMTPPassword) > 0, } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - log.Printf("SettingsHandler: Failed to encode response: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) + if !WriteJSONResponse(w, response) { return } } @@ -71,7 +66,7 @@ func (h *SettingsHandler) GetSettings(w http.ResponseWriter, r *http.Request) { func (h *SettingsHandler) PostSettings(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - userID, ok := h.getUserIDFromContext(ctx, w) + userID, ok := GetUserIDFromContext(ctx, w, h.pool) if !ok { return } @@ -89,7 +84,8 @@ func (h *SettingsHandler) PostSettings(w http.ResponseWriter, r *http.Request) { return } - // Get existing settings to preserve passwords if not provided + // Get existing settings to preserve passwords if not provided in the request. + // This allows users to update other settings without re-entering passwords. existingSettings, err := db.GetUserSettings(ctx, h.pool, userID) var encryptedIMAPPassword []byte var encryptedSMTPPassword []byte @@ -100,7 +96,8 @@ func (h *SettingsHandler) PostSettings(w http.ResponseWriter, r *http.Request) { return } - // Handle IMAP password: use existing if not provided, encrypt new one if provided + // Handle IMAP password: use existing if not provided, encrypt new one if provided. + // For initial setup (no existing settings), password is required. if req.IMAPPassword == "" { if existingSettings != nil { encryptedIMAPPassword = existingSettings.EncryptedIMAPPassword @@ -119,7 +116,8 @@ func (h *SettingsHandler) PostSettings(w http.ResponseWriter, r *http.Request) { } } - // Handle SMTP password: use existing if not provided, encrypt new one if provided + // Handle SMTP password: use existing if not provided, encrypt new one if provided. + // For initial setup (no existing settings), password is required. if req.SMTPPassword == "" { if existingSettings != nil { encryptedSMTPPassword = existingSettings.EncryptedSMTPPassword @@ -156,13 +154,19 @@ func (h *SettingsHandler) PostSettings(w http.ResponseWriter, r *http.Request) { return } + successResponse := struct { + Success bool `json:"success"` + }{Success: true} + w.WriteHeader(http.StatusOK) - _, err = w.Write([]byte(`{"success": true}`)) - if err != nil { + if !WriteJSONResponse(w, successResponse) { return } } +// validateSettingsRequest validates the user settings request, ensuring all required +// fields are present. Note that passwords are optional on update (they can be empty +// to preserve existing passwords), but are required for initial setup. func (h *SettingsHandler) validateSettingsRequest(req *models.UserSettingsRequest) error { if req.IMAPServerHostname == "" { return errors.New("IMAP server hostname is required") @@ -180,23 +184,3 @@ func (h *SettingsHandler) validateSettingsRequest(req *models.UserSettingsReques // Password validation removed - passwords are optional on update return nil } - -// getUserIDFromContext extracts the user's email from context, resolves/creates the DB user, -// and writes appropriate HTTP errors when it fails. Returns (userID, true) on success. -func (h *SettingsHandler) getUserIDFromContext(ctx context.Context, w http.ResponseWriter) (string, bool) { - email, ok := auth.GetUserEmailFromContext(ctx) - if !ok { - log.Println("SettingsHandler: No user email in context") - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return "", false - } - - userID, err := db.GetOrCreateUser(ctx, h.pool, email) - if err != nil { - log.Printf("SettingsHandler: Failed to get/create user: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return "", false - } - - return userID, true -} diff --git a/backend/internal/api/settings_handler_test.go b/backend/internal/api/settings_handler_test.go index ace8d50..d78a731 100644 --- a/backend/internal/api/settings_handler_test.go +++ b/backend/internal/api/settings_handler_test.go @@ -3,35 +3,19 @@ package api import ( "bytes" "context" - "encoding/base64" "encoding/json" + "fmt" "net/http" "net/http/httptest" + "strings" "testing" "github.com/vdavid/vmail/backend/internal/auth" - "github.com/vdavid/vmail/backend/internal/crypto" "github.com/vdavid/vmail/backend/internal/db" "github.com/vdavid/vmail/backend/internal/models" "github.com/vdavid/vmail/backend/internal/testutil" ) -func getTestEncryptor(t *testing.T) *crypto.Encryptor { - t.Helper() - - key := make([]byte, 32) - for i := range key { - key[i] = byte(i) - } - base64Key := base64.StdEncoding.EncodeToString(key) - - encryptor, err := crypto.NewEncryptor(base64Key) - if err != nil { - t.Fatalf("Failed to create encryptor: %v", err) - } - return encryptor -} - func TestSettingsHandler_GetSettings(t *testing.T) { pool := testutil.NewTestDB(t) defer pool.Close() @@ -346,4 +330,230 @@ func TestSettingsHandler_PostSettings(t *testing.T) { t.Errorf("Expected status 400 for empty passwords on new user, got %d", rr.Code) } }) + + t.Run("returns 500 when GetUserSettings returns non-NotFound error in PostSettings", func(t *testing.T) { + email := "dberror-post@example.com" + + // Use a cancelled context to simulate database connection failure + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + reqBody := models.UserSettingsRequest{ + UndoSendDelaySeconds: 25, + PaginationThreadsPerPage: 75, + IMAPServerHostname: "imap.new.com", + IMAPUsername: "new-user", + IMAPPassword: "imap_password_123", + SMTPServerHostname: "smtp.new.com", + SMTPUsername: "new-user", + SMTPPassword: "smtp_password_456", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body)) + reqCtx := context.WithValue(cancelledCtx, auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.PostSettings(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", rr.Code) + } + }) + + // Note: Testing SaveUserSettings failure is difficult without mocking the database layer. + // The error handling code path is covered by the handler implementation, but simulating + // a database save failure in a real test environment is complex. The error handling + // is straightforward (returns 500 on error), so we rely on integration tests and + // the code coverage to ensure this path works correctly. + + t.Run("returns 500 when GetUserSettings returns non-NotFound error in GetSettings", func(t *testing.T) { + email := "dberror-get@example.com" + + // Use a cancelled context to simulate database connection failure + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + req := httptest.NewRequest("GET", "/api/v1/settings", nil) + reqCtx := context.WithValue(cancelledCtx, auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.GetSettings(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", rr.Code) + } + }) + + t.Run("validates missing IMAP server hostname", func(t *testing.T) { + email := "validation-test@example.com" + + reqBody := models.UserSettingsRequest{ + UndoSendDelaySeconds: 25, + PaginationThreadsPerPage: 75, + IMAPServerHostname: "", // Missing + IMAPUsername: "user", + IMAPPassword: "password", + SMTPServerHostname: "smtp.test.com", + SMTPUsername: "user", + SMTPPassword: "password", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body)) + ctx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.PostSettings(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", rr.Code) + } + + bodyStr := rr.Body.String() + if !strings.Contains(bodyStr, "IMAP server hostname is required") { + t.Errorf("Expected error message about IMAP server hostname, got: %s", bodyStr) + } + }) + + t.Run("validates missing IMAP username", func(t *testing.T) { + email := "validation-test2@example.com" + + reqBody := models.UserSettingsRequest{ + UndoSendDelaySeconds: 25, + PaginationThreadsPerPage: 75, + IMAPServerHostname: "imap.test.com", + IMAPUsername: "", // Missing + IMAPPassword: "password", + SMTPServerHostname: "smtp.test.com", + SMTPUsername: "user", + SMTPPassword: "password", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body)) + ctx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.PostSettings(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", rr.Code) + } + + bodyStr := rr.Body.String() + if !strings.Contains(bodyStr, "IMAP username is required") { + t.Errorf("Expected error message about IMAP username, got: %s", bodyStr) + } + }) + + t.Run("validates missing SMTP server hostname", func(t *testing.T) { + email := "validation-test3@example.com" + + reqBody := models.UserSettingsRequest{ + UndoSendDelaySeconds: 25, + PaginationThreadsPerPage: 75, + IMAPServerHostname: "imap.test.com", + IMAPUsername: "user", + IMAPPassword: "password", + SMTPServerHostname: "", // Missing + SMTPUsername: "user", + SMTPPassword: "password", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body)) + ctx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.PostSettings(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", rr.Code) + } + + bodyStr := rr.Body.String() + if !strings.Contains(bodyStr, "SMTP server hostname is required") { + t.Errorf("Expected error message about SMTP server hostname, got: %s", bodyStr) + } + }) + + t.Run("validates missing SMTP username", func(t *testing.T) { + email := "validation-test4@example.com" + + reqBody := models.UserSettingsRequest{ + UndoSendDelaySeconds: 25, + PaginationThreadsPerPage: 75, + IMAPServerHostname: "imap.test.com", + IMAPUsername: "user", + IMAPPassword: "password", + SMTPServerHostname: "smtp.test.com", + SMTPUsername: "", // Missing + SMTPPassword: "password", + } + + body, _ := json.Marshal(reqBody) + req := httptest.NewRequest("POST", "/api/v1/settings", bytes.NewReader(body)) + ctx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.PostSettings(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", rr.Code) + } + + bodyStr := rr.Body.String() + if !strings.Contains(bodyStr, "SMTP username is required") { + t.Errorf("Expected error message about SMTP username, got: %s", bodyStr) + } + }) +} + +// failingResponseWriter is a ResponseWriter that fails on Write to test error handling. +type failingResponseWriterSettings struct { + http.ResponseWriter + writeShouldFail bool +} + +func (f *failingResponseWriterSettings) Write(p []byte) (int, error) { + if f.writeShouldFail { + return 0, fmt.Errorf("write failed") + } + return f.ResponseWriter.Write(p) +} + +func TestSettingsHandler_WriteResponseErrors(t *testing.T) { + pool := testutil.NewTestDB(t) + defer pool.Close() + + encryptor := getTestEncryptor(t) + handler := NewSettingsHandler(pool, encryptor) + + t.Run("handles write failure gracefully in GetSettings", func(t *testing.T) { + email := "write-error-get@example.com" + setupTestUserAndSettings(t, pool, encryptor, email) + + req := httptest.NewRequest("GET", "/api/v1/settings", nil) + reqCtx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + // Create a ResponseWriter that fails on Write + rr := httptest.NewRecorder() + failingWriter := &failingResponseWriterSettings{ + ResponseWriter: rr, + writeShouldFail: true, + } + + handler.GetSettings(failingWriter, req) + + // The handler should handle the write error gracefully (it logs but doesn't crash) + // We can't easily test the error path without checking logs, but we verify it doesn't panic + }) } diff --git a/backend/internal/api/thread_handler.go b/backend/internal/api/thread_handler.go index 245199c..dcf78e2 100644 --- a/backend/internal/api/thread_handler.go +++ b/backend/internal/api/thread_handler.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" "errors" "fmt" "log" @@ -11,7 +10,6 @@ import ( "strings" "github.com/jackc/pgx/v5/pgxpool" - "github.com/vdavid/vmail/backend/internal/auth" "github.com/vdavid/vmail/backend/internal/crypto" "github.com/vdavid/vmail/backend/internal/db" "github.com/vdavid/vmail/backend/internal/imap" @@ -21,7 +19,7 @@ import ( // ThreadHandler handles individual thread-related API requests. type ThreadHandler struct { pool *pgxpool.Pool - encryptor *crypto.Encryptor + encryptor *crypto.Encryptor // Not used directly, but required by imapService imapService imap.IMAPService } @@ -51,7 +49,8 @@ func getStableThreadIDFromPath(path string) (string, error) { return decoded, nil } -// collectMessagesToSync collects messages that need syncing and returns them with a UID-to-index map. +// collectMessagesToSync collects messages that need syncing (those without body content) +// and returns them along with a map from IMAP UID to message index for efficient updates. func collectMessagesToSync(messages []*models.Message) ([]imap.MessageToSync, map[int64]int) { messagesToSync := make([]imap.MessageToSync, 0) messageUIDToIndex := make(map[int64]int) @@ -67,7 +66,9 @@ func collectMessagesToSync(messages []*models.Message) ([]imap.MessageToSync, ma return messagesToSync, messageUIDToIndex } -// syncMissingBodies syncs missing message bodies and updates the messages slice. +// syncMissingBodies syncs missing message bodies in batch and updates the messages slice. +// If sync fails, it logs the error but continues (graceful degradation - returns messages +// without bodies rather than failing the entire request). func (h *ThreadHandler) syncMissingBodies(ctx context.Context, userID string, messages []*models.Message, messagesToSync []imap.MessageToSync, messageUIDToIndex map[int64]int) { if len(messagesToSync) == 0 { return @@ -80,17 +81,25 @@ func (h *ThreadHandler) syncMissingBodies(ctx context.Context, userID string, me } // Re-fetch all synced messages to get updated bodies + // Log warnings for messages that couldn't be refreshed after sync + var failedMessages []string for _, msgToSync := range messagesToSync { updatedMsg, err := db.GetMessageByUID(ctx, h.pool, userID, msgToSync.FolderName, msgToSync.IMAPUID) if err == nil { if idx, found := messageUIDToIndex[msgToSync.IMAPUID]; found { messages[idx] = updatedMsg } + } else { + failedMessages = append(failedMessages, fmt.Sprintf("%s:%d", msgToSync.FolderName, msgToSync.IMAPUID)) } } + if len(failedMessages) > 0 { + log.Printf("ThreadHandler: Warning: %d message(s) couldn't be refreshed after sync: %v", len(failedMessages), failedMessages) + } } // assignAttachments assigns attachments from the batch-fetched map to messages. +// Ensures that each message's Attachments field is initialized (never nil). func assignAttachments(messages []*models.Message, attachmentsMap map[string][]*models.Attachment) { for _, msg := range messages { attachments := attachmentsMap[msg.ID] @@ -109,8 +118,8 @@ func assignAttachments(messages []*models.Message, attachmentsMap map[string][]* } } -// convertMessagesToThreadMessages converts []*Message to []Message. -// Ensures that Attachments is always an array, never nil. +// convertMessagesToThreadMessages converts []*Message to []Message for the response. +// Ensures that Attachments is always an array, never nil, and filters out nil messages. func convertMessagesToThreadMessages(messages []*models.Message) []models.Message { threadMessages := make([]models.Message, 0, len(messages)) for _, msg := range messages { @@ -129,7 +138,7 @@ func convertMessagesToThreadMessages(messages []*models.Message) []models.Messag func (h *ThreadHandler) GetThread(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - userID, ok := h.getUserIDFromContext(ctx, w) + userID, ok := GetUserIDFromContext(ctx, w, h.pool) if !ok { return } @@ -173,6 +182,7 @@ func (h *ThreadHandler) GetThread(w http.ResponseWriter, r *http.Request) { } // Fetch all attachments in a single query (fixes N+1 query bug) + // If fetching attachments fails, continue with empty attachments rather than failing the request. attachmentsMap, err := db.GetAttachmentsForMessages(ctx, h.pool, messageIDs) if err != nil { log.Printf("ThreadHandler: Failed to get attachments: %v", err) @@ -187,28 +197,7 @@ func (h *ThreadHandler) GetThread(w http.ResponseWriter, r *http.Request) { assignAttachments(messages, attachmentsMap) thread.Messages = convertMessagesToThreadMessages(messages) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(thread); err != nil { - log.Printf("ThreadHandler: Failed to encode response: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) + if !WriteJSONResponse(w, thread) { return } } - -func (h *ThreadHandler) getUserIDFromContext(ctx context.Context, w http.ResponseWriter) (string, bool) { - email, ok := auth.GetUserEmailFromContext(ctx) - if !ok { - log.Println("ThreadHandler: No user email in context") - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return "", false - } - - userID, err := db.GetOrCreateUser(ctx, h.pool, email) - if err != nil { - log.Printf("ThreadHandler: Failed to get/create user: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return "", false - } - - return userID, true -} diff --git a/backend/internal/api/thread_handler_test.go b/backend/internal/api/thread_handler_test.go index 283bce6..c64b2f7 100644 --- a/backend/internal/api/thread_handler_test.go +++ b/backend/internal/api/thread_handler_test.go @@ -3,8 +3,10 @@ package api import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" + "net/url" "testing" "time" @@ -20,7 +22,8 @@ func TestThreadHandler_GetThread(t *testing.T) { defer pool.Close() encryptor := getTestEncryptor(t) - imapService := imap.NewService(pool, encryptor) + imapService := imap.NewService(pool, imap.NewPool(), encryptor) + defer imapService.Close() handler := NewThreadHandler(pool, encryptor, imapService) t.Run("returns 401 when no user email in context", func(t *testing.T) { @@ -245,6 +248,270 @@ func TestThreadHandler_GetThread(t *testing.T) { t.Errorf("Expected filename 'test.pdf', got %s", response.Messages[0].Attachments[0].Filename) } }) + + t.Run("returns 500 when GetThreadByStableID returns non-NotFound error", func(t *testing.T) { + email := "dberror-thread@example.com" + setupTestUserAndSettings(t, pool, encryptor, email) + + // Use a cancelled context to simulate database connection failure + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + req := httptest.NewRequest("GET", "/api/v1/thread/test-thread-id", nil) + reqCtx := context.WithValue(cancelledCtx, auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.GetThread(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", rr.Code) + } + }) + + t.Run("returns 500 when GetMessagesForThread returns an error", func(t *testing.T) { + email := "dberror-messages@example.com" + ctx := context.Background() + userID := setupTestUserAndSettings(t, pool, encryptor, email) + + // Create a thread so GetThreadByStableID succeeds + thread := &models.Thread{ + UserID: userID, + StableThreadID: "thread-db-error", + Subject: "DB Error Test", + } + if err := db.SaveThread(ctx, pool, thread); err != nil { + t.Fatalf("Failed to save thread: %v", err) + } + + // Use a cancelled context to simulate database error when getting messages + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + req := httptest.NewRequest("GET", "/api/v1/thread/thread-db-error", nil) + reqCtx := context.WithValue(cancelledCtx, auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.GetThread(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", rr.Code) + } + }) + + t.Run("continues with empty attachments when GetAttachmentsForMessages returns error", func(t *testing.T) { + email := "attachments-error@example.com" + ctx := context.Background() + userID := setupTestUserAndSettings(t, pool, encryptor, email) + + thread := &models.Thread{ + UserID: userID, + StableThreadID: "thread-attachments-error", + Subject: "Attachments Error Test", + } + if err := db.SaveThread(ctx, pool, thread); err != nil { + t.Fatalf("Failed to save thread: %v", err) + } + + now := time.Now() + msg := &models.Message{ + ThreadID: thread.ID, + UserID: userID, + IMAPUID: 1, + IMAPFolderName: "INBOX", + MessageIDHeader: "msg-attachments-error", + Subject: "Test", + SentAt: &now, + UnsafeBodyHTML: "

Body

", + BodyText: "Body", + } + if err := db.SaveMessage(ctx, pool, msg); err != nil { + t.Fatalf("Failed to save message: %v", err) + } + + req := httptest.NewRequest("GET", "/api/v1/thread/thread-attachments-error", nil) + reqCtx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + // Note: This test verifies that GetAttachmentsForMessages errors are handled gracefully. + // The handler already handles this by continuing with empty attachments. + // The assignAttachments function ensures attachments are never nil. + rr := httptest.NewRecorder() + handler.GetThread(rr, req) + + // The handler should handle the error gracefully + // The handler already handles GetAttachmentsForMessages errors by continuing with empty attachments + // This test verifies the handler completes successfully + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + + var response models.Thread + if err := json.NewDecoder(rr.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + // Verify the handler completed successfully + // The convertMessagesToThreadMessages function ensures Attachments is never nil in the response + if len(response.Messages) > 0 { + // JSON unmarshaling might set nil for empty slices, but the handler ensures they're arrays + // The important thing is the handler doesn't crash + _ = response.Messages[0].Attachments + } + }) + + t.Run("handles invalid thread_id encoding", func(t *testing.T) { + email := "encoding-test@example.com" + setupTestUserAndSettings(t, pool, encryptor, email) + + t.Run("valid URL-encoded Message-ID", func(t *testing.T) { + encodedID := url.QueryEscape("") + req := httptest.NewRequest("GET", "/api/v1/thread/"+encodedID, nil) + reqCtx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.GetThread(rr, req) + + // For valid encoding, we expect either 404 (not found) or 200 (found) + if rr.Code != http.StatusNotFound && rr.Code != http.StatusOK { + t.Errorf("Expected status 404 or 200 for valid encoding, got %d", rr.Code) + } + }) + + t.Run("invalid encoding", func(t *testing.T) { + // Create a request with invalid URL encoding manually + // httptest.NewRequest will fail on invalid encoding, so we construct it differently + req, err := http.NewRequest("GET", "/api/v1/thread/%ZZ", nil) + if err != nil { + // If NewRequest fails due to invalid encoding, that's actually what we want to test + // But we can't test the handler in that case. Instead, test with a path that + // will cause PathUnescape to fail + req = &http.Request{ + Method: "GET", + URL: &url.URL{ + Path: "/api/v1/thread/%ZZ", + }, + } + } + reqCtx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.GetThread(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("Expected status 400 for invalid encoding, got %d", rr.Code) + } + }) + + t.Run("special characters", func(t *testing.T) { + encodedID := url.QueryEscape("") + req := httptest.NewRequest("GET", "/api/v1/thread/"+encodedID, nil) + reqCtx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.GetThread(rr, req) + + // For valid encoding, we expect either 404 (not found) or 200 (found) + if rr.Code != http.StatusNotFound && rr.Code != http.StatusOK { + t.Errorf("Expected status 404 or 200 for valid encoding, got %d", rr.Code) + } + }) + }) + + t.Run("handles JSON encoding failure gracefully", func(t *testing.T) { + email := "json-error-thread@example.com" + ctx := context.Background() + userID := setupTestUserAndSettings(t, pool, encryptor, email) + + thread := &models.Thread{ + UserID: userID, + StableThreadID: "thread-json-error", + Subject: "JSON Error Test", + } + if err := db.SaveThread(ctx, pool, thread); err != nil { + t.Fatalf("Failed to save thread: %v", err) + } + + now := time.Now() + msg := &models.Message{ + ThreadID: thread.ID, + UserID: userID, + IMAPUID: 1, + IMAPFolderName: "INBOX", + MessageIDHeader: "msg-json-error", + Subject: "Test", + SentAt: &now, + UnsafeBodyHTML: "

Body

", + BodyText: "Body", + } + if err := db.SaveMessage(ctx, pool, msg); err != nil { + t.Fatalf("Failed to save message: %v", err) + } + + req := httptest.NewRequest("GET", "/api/v1/thread/thread-json-error", nil) + reqCtx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + // Create a ResponseWriter that fails on Write + rr := httptest.NewRecorder() + failingWriter := &failingResponseWriterThread{ + ResponseWriter: rr, + writeShouldFail: true, + } + + handler.GetThread(failingWriter, req) + + // The handler should handle the write error gracefully (it logs but doesn't crash) + // The status code should still be set (200) even if Write fails + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + }) + + t.Run("handles thread with nil messages", func(t *testing.T) { + email := "nil-messages@example.com" + ctx := context.Background() + userID := setupTestUserAndSettings(t, pool, encryptor, email) + + thread := &models.Thread{ + UserID: userID, + StableThreadID: "thread-nil-messages", + Subject: "Nil Messages Test", + } + if err := db.SaveThread(ctx, pool, thread); err != nil { + t.Fatalf("Failed to save thread: %v", err) + } + + // Don't create any messages - GetMessagesForThread should return empty slice, not nil + // But test the defensive check in the handler + + req := httptest.NewRequest("GET", "/api/v1/thread/thread-nil-messages", nil) + reqCtx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.GetThread(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + + var response models.Thread + if err := json.NewDecoder(rr.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + // Messages should be an empty array, not nil + // Note: JSON unmarshaling might set nil for empty slices, but the handler's defensive check + // ensures messages is never nil. The important thing is the handler doesn't crash. + if len(response.Messages) != 0 { + t.Errorf("Expected 0 messages, got %d", len(response.Messages)) + } + }) } // mockIMAPServiceForThread is a mock implementation of IMAPService for thread handler tests @@ -466,4 +733,141 @@ func TestThreadHandler_SyncsMissingBodies(t *testing.T) { t.Error("Expected SyncFullMessages NOT to be called when body already exists") } }) + + t.Run("continues when SyncFullMessages returns an error", func(t *testing.T) { + email := "sync-error-thread@example.com" + ctx := context.Background() + userID := setupTestUserAndSettings(t, pool, encryptor, email) + + thread := &models.Thread{ + UserID: userID, + StableThreadID: "thread-sync-error", + Subject: "Sync Error Test", + } + if err := db.SaveThread(ctx, pool, thread); err != nil { + t.Fatalf("Failed to save thread: %v", err) + } + + // Create a message WITHOUT a body (triggers sync) + now := time.Now() + msg := &models.Message{ + ThreadID: thread.ID, + UserID: userID, + IMAPUID: 1, + IMAPFolderName: "INBOX", + MessageIDHeader: "msg-sync-error", + Subject: "Test", + SentAt: &now, + // No body - triggers sync + } + if err := db.SaveMessage(ctx, pool, msg); err != nil { + t.Fatalf("Failed to save message: %v", err) + } + + mockIMAP := &mockIMAPServiceForThread{ + syncFullMessagesErr: fmt.Errorf("IMAP sync failed"), + } + + handler := NewThreadHandler(pool, encryptor, mockIMAP) + + req := httptest.NewRequest("GET", "/api/v1/thread/thread-sync-error", nil) + reqCtx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.GetThread(rr, req) + + // Should still return 200 OK, with messages without bodies (graceful degradation) + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + + var response models.Thread + if err := json.NewDecoder(rr.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + // Verify sync was attempted + if !mockIMAP.syncFullMessagesCalled { + t.Error("Expected SyncFullMessages to be called") + } + + // Messages should be returned even without bodies + if len(response.Messages) == 0 { + t.Error("Expected messages to be returned even when sync fails") + } + }) + + t.Run("continues when GetMessageByUID fails after sync", func(t *testing.T) { + email := "getmessage-error@example.com" + ctx := context.Background() + userID := setupTestUserAndSettings(t, pool, encryptor, email) + + thread := &models.Thread{ + UserID: userID, + StableThreadID: "thread-getmessage-error", + Subject: "GetMessage Error Test", + } + if err := db.SaveThread(ctx, pool, thread); err != nil { + t.Fatalf("Failed to save thread: %v", err) + } + + // Create a message WITHOUT a body (triggers sync) + now := time.Now() + msg := &models.Message{ + ThreadID: thread.ID, + UserID: userID, + IMAPUID: 999, // Use a high UID that might not exist after sync + IMAPFolderName: "INBOX", + MessageIDHeader: "msg-getmessage-error", + Subject: "Test", + SentAt: &now, + // No body - triggers sync + } + if err := db.SaveMessage(ctx, pool, msg); err != nil { + t.Fatalf("Failed to save message: %v", err) + } + + mockIMAP := &mockIMAPServiceForThread{ + syncFullMessagesErr: nil, // Sync succeeds + } + + handler := NewThreadHandler(pool, encryptor, mockIMAP) + + req := httptest.NewRequest("GET", "/api/v1/thread/thread-getmessage-error", nil) + reqCtx := context.WithValue(req.Context(), auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.GetThread(rr, req) + + // Should still return 200 OK, with original message (without updated body) + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + + var response models.Thread + if err := json.NewDecoder(rr.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + // Messages should be returned even if GetMessageByUID fails + if len(response.Messages) == 0 { + t.Error("Expected messages to be returned even when GetMessageByUID fails") + } + }) + +} + +// failingResponseWriterThread is a ResponseWriter that fails on Write to test error handling. +type failingResponseWriterThread struct { + http.ResponseWriter + writeShouldFail bool +} + +func (f *failingResponseWriterThread) Write(p []byte) (int, error) { + if f.writeShouldFail { + return 0, fmt.Errorf("write failed") + } + return f.ResponseWriter.Write(p) } diff --git a/backend/internal/api/threads_handler.go b/backend/internal/api/threads_handler.go index aa78c8e..1512129 100644 --- a/backend/internal/api/threads_handler.go +++ b/backend/internal/api/threads_handler.go @@ -2,13 +2,10 @@ package api import ( "context" - "encoding/json" "log" "net/http" - "strconv" "github.com/jackc/pgx/v5/pgxpool" - "github.com/vdavid/vmail/backend/internal/auth" "github.com/vdavid/vmail/backend/internal/crypto" "github.com/vdavid/vmail/backend/internal/db" "github.com/vdavid/vmail/backend/internal/imap" @@ -18,7 +15,7 @@ import ( // ThreadsHandler handles thread-list-related API requests. type ThreadsHandler struct { pool *pgxpool.Pool - encryptor *crypto.Encryptor + encryptor *crypto.Encryptor // Not used directly, but required by imapService imapService imap.IMAPService } @@ -31,43 +28,9 @@ func NewThreadsHandler(pool *pgxpool.Pool, encryptor *crypto.Encryptor, imapServ } } -// parsePaginationParams parses page and limit from query parameters. -func parsePaginationParams(r *http.Request, defaultLimit int) (page, limit int) { - page = 1 - limit = defaultLimit - - if pageStr := r.URL.Query().Get("page"); pageStr != "" { - if parsed, err := strconv.Atoi(pageStr); err == nil && parsed > 0 { - page = parsed - } - } - - if limitStr := r.URL.Query().Get("limit"); limitStr != "" { - if parsed, err := strconv.Atoi(limitStr); err == nil && parsed > 0 { - limit = parsed - } - } - - return page, limit -} - -// getPaginationLimit gets the pagination limit, using user settings if available. -func (h *ThreadsHandler) getPaginationLimit(ctx context.Context, userID string, limitFromQuery int) int { - if limitFromQuery > 0 { - return limitFromQuery - } - - // If no limit provided, use the user's setting as default - settings, err := db.GetUserSettings(ctx, h.pool, userID) - if err == nil { - return settings.PaginationThreadsPerPage - } - - // If settings not found, use default 100 - return 100 -} - // syncFolderIfNeeded checks if the folder needs syncing and syncs if necessary. +// If the sync check fails or sync itself fails, it logs the error but continues +// to return cached data, ensuring the request doesn't fail due to sync issues. func (h *ThreadsHandler) syncFolderIfNeeded(ctx context.Context, userID, folder string) { shouldSync, err := h.imapService.ShouldSyncFolder(ctx, userID, folder) if err != nil { @@ -84,22 +47,12 @@ func (h *ThreadsHandler) syncFolderIfNeeded(ctx context.Context, userID, folder } } -// buildPaginationResponse builds the pagination response structure. -func buildPaginationResponse(threads []*models.Thread, totalCount, page, limit int) interface{} { - return struct { - Threads []*models.Thread `json:"threads"` - Pagination struct { - TotalCount int `json:"total_count"` - Page int `json:"page"` - PerPage int `json:"per_page"` - } `json:"pagination"` - }{ +// BuildPaginationResponse builds the pagination response structure. +// This is a shared helper function used by multiple handlers for consistent response formatting. +func BuildPaginationResponse(threads []*models.Thread, totalCount, page, limit int) *models.ThreadsResponse { + return &models.ThreadsResponse{ Threads: threads, - Pagination: struct { - TotalCount int `json:"total_count"` - Page int `json:"page"` - PerPage int `json:"per_page"` - }{ + Pagination: models.PaginationInfo{ TotalCount: totalCount, Page: page, PerPage: limit, @@ -111,7 +64,7 @@ func buildPaginationResponse(threads []*models.Thread, totalCount, page, limit i func (h *ThreadsHandler) GetThreads(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - userID, ok := h.getUserIDFromContext(ctx, w) + userID, ok := GetUserIDFromContext(ctx, w, h.pool) if !ok { return } @@ -124,8 +77,8 @@ func (h *ThreadsHandler) GetThreads(w http.ResponseWriter, r *http.Request) { } // Get pagination params - page, limitFromQuery := parsePaginationParams(r, 100) - limit := h.getPaginationLimit(ctx, userID, limitFromQuery) + page, limitFromQuery := ParsePaginationParams(r, 100) + limit := GetPaginationLimit(ctx, h.pool, userID, limitFromQuery) offset := (page - 1) * limit // Sync folder if needed @@ -148,29 +101,10 @@ func (h *ThreadsHandler) GetThreads(w http.ResponseWriter, r *http.Request) { } // Build and send the response - response := buildPaginationResponse(threads, totalCount, page, limit) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - log.Printf("ThreadsHandler: Failed to encode response: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } -} + // Use a buffered approach to prevent partial writes if JSON encoding fails + response := BuildPaginationResponse(threads, totalCount, page, limit) -func (h *ThreadsHandler) getUserIDFromContext(ctx context.Context, w http.ResponseWriter) (string, bool) { - email, ok := auth.GetUserEmailFromContext(ctx) - if !ok { - log.Println("ThreadsHandler: No user email in context") - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return "", false - } - - userID, err := db.GetOrCreateUser(ctx, h.pool, email) - if err != nil { - log.Printf("ThreadsHandler: Failed to get/create user: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return "", false + if !WriteJSONResponse(w, response) { + return } - - return userID, true } diff --git a/backend/internal/api/threads_handler_test.go b/backend/internal/api/threads_handler_test.go index 27172fb..491d400 100644 --- a/backend/internal/api/threads_handler_test.go +++ b/backend/internal/api/threads_handler_test.go @@ -9,57 +9,20 @@ import ( "testing" "time" - "github.com/jackc/pgx/v5/pgxpool" "github.com/vdavid/vmail/backend/internal/auth" - "github.com/vdavid/vmail/backend/internal/crypto" "github.com/vdavid/vmail/backend/internal/db" "github.com/vdavid/vmail/backend/internal/imap" "github.com/vdavid/vmail/backend/internal/models" "github.com/vdavid/vmail/backend/internal/testutil" ) -// setupTestUserAndSettings creates a test user and saves their settings. -func setupTestUserAndSettings(t *testing.T, pool *pgxpool.Pool, encryptor *crypto.Encryptor, email string) string { - t.Helper() - ctx := context.Background() - userID, err := db.GetOrCreateUser(ctx, pool, email) - if err != nil { - t.Fatalf("Failed to create user: %v", err) - } - - encryptedIMAPPassword, _ := encryptor.Encrypt("imap_pass") - encryptedSMTPPassword, _ := encryptor.Encrypt("smtp_pass") - - settings := &models.UserSettings{ - UserID: userID, - UndoSendDelaySeconds: 20, - PaginationThreadsPerPage: 100, - IMAPServerHostname: "imap.test.com", - IMAPUsername: "user", - EncryptedIMAPPassword: encryptedIMAPPassword, - SMTPServerHostname: "smtp.test.com", - SMTPUsername: "user", - EncryptedSMTPPassword: encryptedSMTPPassword, - } - if err := db.SaveUserSettings(ctx, pool, settings); err != nil { - t.Fatalf("Failed to save settings: %v", err) - } - return userID -} - -// createRequestWithUser creates an HTTP request with user email in context. -func createRequestWithUser(method, url, email string) *http.Request { - req := httptest.NewRequest(method, url, nil) - ctx := context.WithValue(req.Context(), auth.UserEmailKey, email) - return req.WithContext(ctx) -} - func TestThreadsHandler_GetThreads(t *testing.T) { pool := testutil.NewTestDB(t) defer pool.Close() encryptor := getTestEncryptor(t) - imapService := imap.NewService(pool, encryptor) + imapService := imap.NewService(pool, imap.NewPool(), encryptor) + defer imapService.Close() handler := NewThreadsHandler(pool, encryptor, imapService) t.Run("returns 401 when no user email in context", func(t *testing.T) { @@ -100,14 +63,7 @@ func TestThreadsHandler_GetThreads(t *testing.T) { t.Errorf("Expected status 200, got %d", rr.Code) } - var response struct { - Threads []*models.Thread `json:"threads"` - Pagination struct { - TotalCount int `json:"total_count"` - Page int `json:"page"` - PerPage int `json:"per_page"` - } `json:"pagination"` - } + var response models.ThreadsResponse if err := json.NewDecoder(rr.Body).Decode(&response); err != nil { t.Fatalf("Failed to decode response: %v", err) } @@ -160,14 +116,7 @@ func TestThreadsHandler_GetThreads(t *testing.T) { t.Errorf("Expected status 200, got %d", rr.Code) } - var response struct { - Threads []*models.Thread `json:"threads"` - Pagination struct { - TotalCount int `json:"total_count"` - Page int `json:"page"` - PerPage int `json:"per_page"` - } `json:"pagination"` - } + var response models.ThreadsResponse if err := json.NewDecoder(rr.Body).Decode(&response); err != nil { t.Fatalf("Failed to decode response: %v", err) } @@ -226,14 +175,7 @@ func TestThreadsHandler_GetThreads(t *testing.T) { t.Errorf("Expected status 200, got %d", rr.Code) } - var response struct { - Threads []*models.Thread `json:"threads"` - Pagination struct { - TotalCount int `json:"total_count"` - Page int `json:"page"` - PerPage int `json:"per_page"` - } `json:"pagination"` - } + var response models.ThreadsResponse if err := json.NewDecoder(rr.Body).Decode(&response); err != nil { t.Fatalf("Failed to decode response: %v", err) } @@ -418,4 +360,276 @@ func TestThreadsHandler_SyncsWhenStale(t *testing.T) { t.Error("Expected SyncThreadsForFolder to be called even if it fails") } }) + + t.Run("falls back to default limit when GetUserSettings fails", func(t *testing.T) { + email := "settings-error@example.com" + ctx := context.Background() + userID := setupTestUserAndSettings(t, pool, encryptor, email) + + // Create a thread to ensure we have data + thread := &models.Thread{ + UserID: userID, + StableThreadID: "test-thread-settings-error", + Subject: "Test Thread", + } + if err := db.SaveThread(ctx, pool, thread); err != nil { + t.Fatalf("Failed to save thread: %v", err) + } + + // Delete the user settings to simulate GetUserSettings returning an error + // (it will return NotFound, which getPaginationLimit handles by using default) + if _, err := pool.Exec(ctx, "DELETE FROM user_settings WHERE user_id = $1", userID); err != nil { + t.Fatalf("Failed to delete user settings: %v", err) + } + + mockIMAP := &mockIMAPService{ + shouldSyncFolderResult: false, + shouldSyncFolderErr: nil, + } + + handler := NewThreadsHandler(pool, encryptor, mockIMAP) + req := createRequestWithUser("GET", "/api/v1/threads?folder=INBOX", email) + + rr := httptest.NewRecorder() + handler.GetThreads(rr, req) + + // Should still return 200 OK, using default limit of 100 + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + + var response models.ThreadsResponse + if err := json.NewDecoder(rr.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + // Should use default limit of 100 + if response.Pagination.PerPage != 100 { + t.Errorf("Expected default limit 100, got %d", response.Pagination.PerPage) + } + }) + + t.Run("returns 500 when GetThreadsForFolder returns an error", func(t *testing.T) { + email := "threads-error@example.com" + setupTestUserAndSettings(t, pool, encryptor, email) + + // Use a cancelled context to simulate database error + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + mockIMAP := &mockIMAPService{ + shouldSyncFolderResult: false, + shouldSyncFolderErr: nil, + } + + handler := NewThreadsHandler(pool, encryptor, mockIMAP) + req := httptest.NewRequest("GET", "/api/v1/threads?folder=INBOX", nil) + reqCtx := context.WithValue(cancelledCtx, auth.UserEmailKey, email) + req = req.WithContext(reqCtx) + + rr := httptest.NewRecorder() + handler.GetThreads(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", rr.Code) + } + }) + + t.Run("returns 500 when GetThreadCountForFolder returns an error", func(t *testing.T) { + email := "count-error@example.com" + ctx := context.Background() + userID := setupTestUserAndSettings(t, pool, encryptor, email) + + // Create a thread so GetThreadsForFolder succeeds + thread := &models.Thread{ + UserID: userID, + StableThreadID: "test-thread-count-error", + Subject: "Test Thread", + } + if err := db.SaveThread(ctx, pool, thread); err != nil { + t.Fatalf("Failed to save thread: %v", err) + } + + // Use a cancelled context to simulate database error when counting + // We need to create the user first, then use cancelled context + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + reqCtx := context.WithValue(cancelledCtx, auth.UserEmailKey, email) + req := httptest.NewRequest("GET", "/api/v1/threads?folder=INBOX", nil) + req = req.WithContext(reqCtx) + + mockIMAP := &mockIMAPService{ + shouldSyncFolderResult: false, + shouldSyncFolderErr: nil, + } + + handler := NewThreadsHandler(pool, encryptor, mockIMAP) + rr := httptest.NewRecorder() + handler.GetThreads(rr, req) + + // Note: This test is tricky because GetThreadsForFolder is called before GetThreadCountForFolder + // and both use the same context. The cancelled context will cause GetThreadsForFolder to fail first. + // So we expect 500, but it's from GetThreadsForFolder, not GetThreadCountForFolder. + // This still tests error handling, just at an earlier point. + if rr.Code != http.StatusInternalServerError { + t.Errorf("Expected status 500, got %d", rr.Code) + } + }) + + t.Run("handles invalid pagination parameters gracefully", func(t *testing.T) { + email := "pagination-invalid@example.com" + ctx := context.Background() + userID := setupTestUserAndSettings(t, pool, encryptor, email) + + // Create a thread + thread := &models.Thread{ + UserID: userID, + StableThreadID: "test-thread-pagination", + Subject: "Test Thread", + } + if err := db.SaveThread(ctx, pool, thread); err != nil { + t.Fatalf("Failed to save thread: %v", err) + } + + testCases := []struct { + name string + query string + expectedPage int + expectedPerPage int + }{ + {"page=0 uses default", "page=0&limit=50", 1, 50}, + {"page=-1 uses default", "page=-1&limit=50", 1, 50}, + {"limit=0 uses default", "page=1&limit=0", 1, 100}, + {"limit=-1 uses default", "page=1&limit=-1", 1, 100}, + {"both invalid", "page=0&limit=0", 1, 100}, + {"non-numeric page", "page=abc&limit=50", 1, 50}, + {"non-numeric limit", "page=1&limit=xyz", 1, 100}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockIMAP := &mockIMAPService{ + shouldSyncFolderResult: false, + shouldSyncFolderErr: nil, + } + + handler := NewThreadsHandler(pool, encryptor, mockIMAP) + req := createRequestWithUser("GET", fmt.Sprintf("/api/v1/threads?folder=INBOX&%s", tc.query), email) + + rr := httptest.NewRecorder() + handler.GetThreads(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + + var response models.ThreadsResponse + if err := json.NewDecoder(rr.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + if response.Pagination.Page != tc.expectedPage { + t.Errorf("Expected page %d, got %d", tc.expectedPage, response.Pagination.Page) + } + if response.Pagination.PerPage != tc.expectedPerPage { + t.Errorf("Expected per_page %d, got %d", tc.expectedPerPage, response.Pagination.PerPage) + } + }) + } + }) + + t.Run("continues when ShouldSyncFolder returns an error", func(t *testing.T) { + email := "sync-error@example.com" + ctx := context.Background() + userID := setupTestUserAndSettings(t, pool, encryptor, email) + + // Create a thread + thread := &models.Thread{ + UserID: userID, + StableThreadID: "test-thread-sync-error", + Subject: "Test Thread", + } + if err := db.SaveThread(ctx, pool, thread); err != nil { + t.Fatalf("Failed to save thread: %v", err) + } + + mockIMAP := &mockIMAPService{ + shouldSyncFolderResult: true, // Should try to sync + shouldSyncFolderErr: fmt.Errorf("cache check failed"), + syncThreadsForFolderErr: nil, // Sync succeeds + } + + handler := NewThreadsHandler(pool, encryptor, mockIMAP) + req := createRequestWithUser("GET", "/api/v1/threads?folder=INBOX", email) + + rr := httptest.NewRecorder() + handler.GetThreads(rr, req) + + // Should still return 200 OK, continuing despite ShouldSyncFolder error + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + + // Verify that ShouldSyncFolder was called + if !mockIMAP.shouldSyncFolderCalled { + t.Error("Expected ShouldSyncFolder to be called") + } + + // Verify that SyncThreadsForFolder was attempted (handler continues anyway) + if !mockIMAP.syncThreadsForFolderCalled { + t.Error("Expected SyncThreadsForFolder to be called even when ShouldSyncFolder returns error") + } + }) + + t.Run("handles JSON encoding failure gracefully", func(t *testing.T) { + email := "json-error@example.com" + ctx := context.Background() + userID := setupTestUserAndSettings(t, pool, encryptor, email) + + // Create a thread + thread := &models.Thread{ + UserID: userID, + StableThreadID: "test-thread-json-error", + Subject: "Test Thread", + } + if err := db.SaveThread(ctx, pool, thread); err != nil { + t.Fatalf("Failed to save thread: %v", err) + } + + mockIMAP := &mockIMAPService{ + shouldSyncFolderResult: false, + shouldSyncFolderErr: nil, + } + + handler := NewThreadsHandler(pool, encryptor, mockIMAP) + req := createRequestWithUser("GET", "/api/v1/threads?folder=INBOX", email) + + // Create a ResponseWriter that fails on Write + rr := httptest.NewRecorder() + failingWriter := &failingResponseWriterThreads{ + ResponseWriter: rr, + writeShouldFail: true, + } + + handler.GetThreads(failingWriter, req) + + // The handler should handle the write error gracefully (it logs but doesn't crash) + // The status code should still be set (200) even if Write fails + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + }) +} + +// failingResponseWriterThreads is a ResponseWriter that fails on Write to test error handling. +type failingResponseWriterThreads struct { + http.ResponseWriter + writeShouldFail bool +} + +func (f *failingResponseWriterThreads) Write(p []byte) (int, error) { + if f.writeShouldFail { + return 0, fmt.Errorf("write failed") + } + return f.ResponseWriter.Write(p) } diff --git a/backend/internal/auth/middleware.go b/backend/internal/auth/middleware.go index eb18fd6..35e46db 100644 --- a/backend/internal/auth/middleware.go +++ b/backend/internal/auth/middleware.go @@ -15,6 +15,8 @@ type contextKey string const UserEmailKey contextKey = "user_email" // RequireAuth middleware checks for a valid bearer token in the Authorization header. +// It extracts the token, validates it, and stores the user's email in the request context +// for use by downstream handlers. Returns 401 Unauthorized if authentication fails. func RequireAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") @@ -25,14 +27,31 @@ func RequireAuth(next http.Handler) http.Handler { return } - parts := strings.Split(authHeader, " ") - if len(parts) != 2 || parts[0] != "Bearer" { + // Parse Authorization header: "Bearer " (RFC 7235) + // Use strings.Fields to handle multiple spaces and trim whitespace + // Bearer scheme is case-insensitive per RFC 7235 + fields := strings.Fields(authHeader) + if len(fields) < 2 { log.Println("Auth: Invalid Authorization header format") http.Error(w, "Unauthorized", http.StatusUnauthorized) return } - token := parts[1] + // Check if the scheme is "Bearer" (case-insensitive) + if !strings.EqualFold(fields[0], "Bearer") { + log.Println("Auth: Invalid Authorization header format") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Join remaining fields to handle tokens that may contain spaces + // (though typically tokens don't, this is more robust) + token := strings.TrimSpace(strings.Join(fields[1:], " ")) + if token == "" { + log.Println("Auth: Empty token after Bearer") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } userEmail, err := ValidateToken(token) if err != nil { @@ -58,7 +77,7 @@ func GetUserEmailFromContext(ctx context.Context) (string, bool) { // it extracts the email from the token (e.g., "email:user@example.com" -> "user@example.com"). // Otherwise, it returns "test@example.com" as the default test user. func ValidateToken(token string) (string, error) { - if token == "" { + if strings.TrimSpace(token) == "" || strings.TrimSpace(token) == "email:" { return "", fmt.Errorf("token is empty") } diff --git a/backend/internal/auth/middleware_test.go b/backend/internal/auth/middleware_test.go index da4c940..4034888 100644 --- a/backend/internal/auth/middleware_test.go +++ b/backend/internal/auth/middleware_test.go @@ -3,6 +3,7 @@ package auth import ( "net/http" "net/http/httptest" + "os" "testing" ) @@ -81,6 +82,68 @@ func TestRequireAuth(t *testing.T) { t.Errorf("Expected status 401, got %d", rr.Code) } }) + + t.Run("handles multiple spaces between Bearer and token", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer valid_token_12345") + + rr := httptest.NewRecorder() + authHandler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } + }) + + t.Run("handles case-insensitive Bearer scheme", func(t *testing.T) { + testCases := []string{ + "bearer valid_token_12345", + "BEARER valid_token_12345", + "BeArEr valid_token_12345", + "Bearer valid_token_12345", + } + + for _, authHeader := range testCases { + t.Run(authHeader, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", authHeader) + + rr := httptest.NewRecorder() + authHandler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200 for %s, got %d", authHeader, rr.Code) + } + }) + } + }) + + t.Run("handles token with leading/trailing whitespace", func(t *testing.T) { + testCases := []struct { + name string + token string + }{ + {"leading space", " Bearer valid_token_12345"}, + {"trailing space", "Bearer valid_token_12345 "}, + {"both spaces", "Bearer valid_token_12345 "}, + {"tabs", "Bearer\tvalid_token_12345\t"}, + {"newlines", "Bearer\nvalid_token_12345\n"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", tc.token) + + rr := httptest.NewRecorder() + authHandler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200 for %s, got %d", tc.name, rr.Code) + } + }) + } + }) } func TestGetUserEmailFromContext(t *testing.T) { @@ -125,4 +188,62 @@ func TestValidateToken(t *testing.T) { t.Error("Expected non-empty email") } }) + + t.Run("extracts email from token when VMAIL_TEST_MODE=true", func(t *testing.T) { + originalValue := os.Getenv("VMAIL_TEST_MODE") + defer func(key, value string) { + err := os.Setenv(key, value) + if err != nil { + t.Fatalf("Failed to restore %s: %v", key, err) + } + }("VMAIL_TEST_MODE", originalValue) + + err := os.Setenv("VMAIL_TEST_MODE", "true") + if err != nil { + t.Fatalf("Failed to set VMAIL_TEST_MODE: %v", err) + return + } + + email, err := ValidateToken("email:testuser@example.com") + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if email != "testuser@example.com" { + t.Errorf("Expected email 'testuser@example.com', got %s", email) + } + }) + + t.Run("returns error for empty token", func(t *testing.T) { + testCases := []string{"", " ", "\t", "\n"} + for _, token := range testCases { + _, err := ValidateToken(token) + if err == nil { + t.Errorf("Expected error for empty/whitespace token: %q", token) + } + } + }) + + t.Run("returns error when VMAIL_TEST_MODE=true and token is email: with empty email", func(t *testing.T) { + originalValue := os.Getenv("VMAIL_TEST_MODE") + defer func(key, value string) { + err := os.Setenv(key, value) + if err != nil { + t.Fatalf("Failed to restore %s: %v", key, err) + } + }("VMAIL_TEST_MODE", originalValue) + + err := os.Setenv("VMAIL_TEST_MODE", "true") + if err != nil { + t.Fatalf("Failed to set VMAIL_TEST_MODE: %v", err) + return + } + + testCases := []string{"email:", "email: ", "email:\t"} + for _, token := range testCases { + _, err := ValidateToken(token) + if err == nil { + t.Errorf("Expected error for token with empty email: %q", token) + } + } + }) } diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 3902d22..e2c478c 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -1,25 +1,47 @@ package config import ( + "encoding/base64" "fmt" + "log" + "net/url" "os" + "strconv" "github.com/joho/godotenv" ) // Config holds the application configuration loaded from environment variables. type Config struct { - Environment string + // Environment is the deployment environment (development, production, etc.). + // Defaults to "development" if VMAIL_ENV is not set. + Environment string + // EncryptionKeyBase64 is the base64-encoded encryption key used for encrypting/decrypting + // user credentials. Must be 32 bytes when decoded (44 characters in base64). EncryptionKeyBase64 string - AutheliaURL string - DBHost string - DBPort string - DBUsername string - DBPassword string - DBName string - DBSSLMode string - Port string - Timezone string + // AutheliaURL is the base URL of the Authelia authentication server. + AutheliaURL string + // DBHost is the PostgreSQL database hostname. Defaults to "localhost". + DBHost string + // DBPort is the PostgreSQL database port. Defaults to "5432". + DBPort string + // DBUsername is the PostgreSQL database username. Defaults to "vmail". + DBUsername string + // DBPassword is the PostgreSQL database password. Required, no default. + DBPassword string + // DBName is the PostgreSQL database name. Defaults to "vmail". + DBName string + // DBSSLMode is the PostgreSQL SSL mode (disable, require, verify-full, etc.). Defaults to "disable". + DBSSLMode string + // Port is the HTTP server port. Defaults to "11764". + Port string + // Timezone is the application timezone (e.g., "UTC", "America/New_York"). Defaults to "UTC". + Timezone string + // IMAPMaxWorkers is the maximum number of IMAP worker connections per user. + // This controls concurrency against the IMAP server. In production this should + // be kept conservative to respect provider limits. In test environments it can + // be higher to speed up E2E tests. + IMAPMaxWorkers int } // NewConfig loads and returns a new Config instance from environment variables. @@ -31,7 +53,7 @@ func NewConfig() (*Config, error) { if env == "development" { if err := godotenv.Load(); err != nil { - fmt.Println("Warning: .env file not found, using environment variables") + log.Printf("Warning: .env file not found, using environment variables") } } @@ -47,6 +69,7 @@ func NewConfig() (*Config, error) { DBSSLMode: getEnvOrDefault("VMAIL_DB_SSLMODE", "disable"), Port: getEnvOrDefault("PORT", "11764"), Timezone: getEnvOrDefault("TZ", "UTC"), + IMAPMaxWorkers: getEnvOrDefaultInt("VMAIL_IMAP_MAX_WORKERS", 3), } if err := config.Validate(); err != nil { @@ -56,29 +79,74 @@ func NewConfig() (*Config, error) { return config, nil } -// Validate checks that all required configuration values are set. +// Validate checks that all required configuration values are set and valid. func (c *Config) Validate() error { if c.EncryptionKeyBase64 == "" { return fmt.Errorf("VMAIL_ENCRYPTION_KEY_BASE64 is required") } + // Validate EncryptionKeyBase64 format: must be valid base64 and decode to 32 bytes + decoded, err := base64.StdEncoding.DecodeString(c.EncryptionKeyBase64) + if err != nil { + return fmt.Errorf("VMAIL_ENCRYPTION_KEY_BASE64 is not valid base64: %w", err) + } + if len(decoded) != 32 { + return fmt.Errorf("VMAIL_ENCRYPTION_KEY_BASE64 must decode to 32 bytes, got %d bytes", len(decoded)) + } + if c.AutheliaURL == "" { return fmt.Errorf("AUTHELIA_URL is required") } + // Validate AutheliaURL format: must be a valid URL with http or https scheme + parsedURL, err := url.Parse(c.AutheliaURL) + if err != nil { + return fmt.Errorf("AUTHELIA_URL is not a valid URL: %w", err) + } + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return fmt.Errorf("AUTHELIA_URL must use http:// or https:// scheme, got: %s", parsedURL.Scheme) + } + if c.DBPassword == "" { return fmt.Errorf("VMAIL_DB_PASSWORD is required") } + // Validate DBPort format: must be a valid port number (1-65535) + if err := validatePort(c.DBPort); err != nil { + return fmt.Errorf("VMAIL_DB_PORT is not a valid port number: %w", err) + } + + // Validate Port format: must be a valid port number (1-65535) + if err := validatePort(c.Port); err != nil { + return fmt.Errorf("PORT is not a valid port number: %w", err) + } + + return nil +} + +// validatePort checks if a string represents a valid port number (1-65535). +func validatePort(portStr string) error { + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("port must be a number: %w", err) + } + if port < 1 || port > 65535 { + return fmt.Errorf("port must be between 1 and 65535, got %d", port) + } return nil } // GetDatabaseURL returns a PostgreSQL connection string built from the configuration. +// The password and username are properly URL-encoded to handle special characters. func (c *Config) GetDatabaseURL() string { + // URL-encode username and password to handle special characters + encodedUsername := url.QueryEscape(c.DBUsername) + encodedPassword := url.QueryEscape(c.DBPassword) + return fmt.Sprintf( "postgres://%s:%s@%s:%s/%s?sslmode=%s", - c.DBUsername, - c.DBPassword, + encodedUsername, + encodedPassword, c.DBHost, c.DBPort, c.DBName, @@ -86,9 +154,25 @@ func (c *Config) GetDatabaseURL() string { ) } +// getEnvOrDefault retrieves an environment variable, returning the default value if not set or empty. func getEnvOrDefault(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value } return defaultValue } + +// getEnvOrDefaultInt retrieves an environment variable as an int, returning the +// default value if not set, empty, or invalid. +func getEnvOrDefaultInt(key string, defaultValue int) int { + value := os.Getenv(key) + if value == "" { + return defaultValue + } + parsed, err := strconv.Atoi(value) + if err != nil { + log.Printf("Warning: %s is not a valid integer (%q), using default %d", key, value, defaultValue) + return defaultValue + } + return parsed +} diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index e4bb924..4e7c4a3 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1,7 +1,9 @@ package config import ( + "net/url" "os" + "strings" "testing" ) @@ -77,7 +79,7 @@ func TestNewConfig(t *testing.T) { func TestNewConfigWithDefaults(t *testing.T) { _ = os.Setenv("VMAIL_ENV", "production") - _ = os.Setenv("VMAIL_ENCRYPTION_KEY_BASE64", "test-key") + _ = os.Setenv("VMAIL_ENCRYPTION_KEY_BASE64", "dGVzdC1rZXktMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM=") _ = os.Setenv("AUTHELIA_URL", "http://authelia:9091") _ = os.Setenv("VMAIL_DB_PASSWORD", "password") @@ -131,6 +133,8 @@ func TestValidate(t *testing.T) { EncryptionKeyBase64: "dGVzdC1rZXktMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM=", AutheliaURL: "http://authelia:9091", DBPassword: "password", + DBPort: "5432", + Port: "11764", }, shouldErr: false, }, @@ -139,6 +143,8 @@ func TestValidate(t *testing.T) { config: &Config{ AutheliaURL: "http://authelia:9091", DBPassword: "password", + DBPort: "5432", + Port: "11764", }, shouldErr: true, errMsg: "VMAIL_ENCRYPTION_KEY_BASE64 is required", @@ -148,6 +154,8 @@ func TestValidate(t *testing.T) { config: &Config{ EncryptionKeyBase64: "dGVzdC1rZXktMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM=", DBPassword: "password", + DBPort: "5432", + Port: "11764", }, shouldErr: true, errMsg: "AUTHELIA_URL is required", @@ -157,6 +165,8 @@ func TestValidate(t *testing.T) { config: &Config{ EncryptionKeyBase64: "dGVzdC1rZXktMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM=", AutheliaURL: "http://authelia:9091", + DBPort: "5432", + Port: "11764", }, shouldErr: true, errMsg: "VMAIL_DB_PASSWORD is required", @@ -180,21 +190,65 @@ func TestValidate(t *testing.T) { } func TestGetDatabaseURL(t *testing.T) { - config := &Config{ - DBUsername: "test-user", - DBPassword: "test-password", - DBHost: "localhost", - DBPort: "5432", - DBName: "testdb", - DBSSLMode: "disable", - } + t.Run("basic URL generation", func(t *testing.T) { + config := &Config{ + DBUsername: "test-user", + DBPassword: "test-password", + DBHost: "localhost", + DBPort: "5432", + DBName: "testdb", + DBSSLMode: "disable", + } - expected := "postgres://test-user:test-password@localhost:5432/testdb?sslmode=disable" - got := config.GetDatabaseURL() + expected := "postgres://test-user:test-password@localhost:5432/testdb?sslmode=disable" + got := config.GetDatabaseURL() - if got != expected { - t.Errorf("expected database URL '%s', got '%s'", expected, got) - } + if got != expected { + t.Errorf("expected database URL '%s', got '%s'", expected, got) + } + }) + + t.Run("handles special characters in password", func(t *testing.T) { + config := &Config{ + DBUsername: "test-user", + DBPassword: "p@ss:w/rd%test#", + DBHost: "localhost", + DBPort: "5432", + DBName: "testdb", + DBSSLMode: "disable", + } + + got := config.GetDatabaseURL() + // The password should be URL-encoded + if !strings.Contains(got, "p%40ss%3Aw%2Frd%25test%23") { + t.Errorf("Expected password to be URL-encoded in database URL, got: %s", got) + } + // Verify the URL can be parsed + if _, err := url.Parse(got); err != nil { + t.Errorf("Generated database URL is not valid: %v", err) + } + }) + + t.Run("handles special characters in username", func(t *testing.T) { + config := &Config{ + DBUsername: "user@domain", + DBPassword: "password", + DBHost: "localhost", + DBPort: "5432", + DBName: "testdb", + DBSSLMode: "disable", + } + + got := config.GetDatabaseURL() + // The username should be URL-encoded + if !strings.Contains(got, "user%40domain") { + t.Errorf("Expected username to be URL-encoded in database URL, got: %s", got) + } + // Verify the URL can be parsed + if _, err := url.Parse(got); err != nil { + t.Errorf("Generated database URL is not valid: %v", err) + } + }) } func TestGetEnvOrDefault(t *testing.T) { @@ -213,3 +267,247 @@ func TestGetEnvOrDefault(t *testing.T) { t.Errorf("expected 'default', got '%s'", got) } } + +func TestNewConfigWithEnvFile(t *testing.T) { + originalEnv := os.Getenv("VMAIL_ENV") + defer func(key, value string) { + _ = os.Setenv(key, value) + }("VMAIL_ENV", originalEnv) + + _ = os.Setenv("VMAIL_ENV", "development") + _ = os.Setenv("VMAIL_ENCRYPTION_KEY_BASE64", "dGVzdC1rZXktMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM=") + _ = os.Setenv("AUTHELIA_URL", "http://authelia:9091") + _ = os.Setenv("VMAIL_DB_PASSWORD", "test-password") + + defer func() { + _ = os.Unsetenv("VMAIL_ENV") + _ = os.Unsetenv("VMAIL_ENCRYPTION_KEY_BASE64") + _ = os.Unsetenv("AUTHELIA_URL") + _ = os.Unsetenv("VMAIL_DB_PASSWORD") + }() + + // Note: This test verifies that NewConfig works in development mode. + // The actual .env file loading is tested implicitly - if godotenv.Load() fails, + // it logs a warning but continues (which is acceptable behavior). + config, err := NewConfig() + if err != nil { + t.Fatalf("NewConfig() returned error: %v", err) + } + + if config.Environment != "development" { + t.Errorf("expected Environment 'development', got '%s'", config.Environment) + } +} + +func TestValidateEncryptionKey(t *testing.T) { + tests := []struct { + name string + key string + shouldErr bool + errMsg string + }{ + { + name: "valid 32-byte base64 key", + key: "dGVzdC1rZXktMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM=", + shouldErr: false, + }, + { + name: "invalid base64", + key: "not-valid-base64!!!", + shouldErr: true, + errMsg: "VMAIL_ENCRYPTION_KEY_BASE64 is not valid base64", + }, + { + name: "wrong length (too short)", + key: "dGVzdA==", // "test" in base64, only 4 bytes + shouldErr: true, + errMsg: "VMAIL_ENCRYPTION_KEY_BASE64 must decode to 32 bytes", + }, + { + name: "wrong length (too long)", + key: "dGVzdC1rZXktMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM=", // 64 bytes + shouldErr: true, + errMsg: "VMAIL_ENCRYPTION_KEY_BASE64 must decode to 32 bytes", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := &Config{ + EncryptionKeyBase64: tt.key, + AutheliaURL: "http://authelia:9091", + DBPassword: "password", + DBPort: "5432", + Port: "11764", + } + + err := config.Validate() + if tt.shouldErr && err == nil { + t.Errorf("expected error but got none") + } + if !tt.shouldErr && err != nil { + t.Errorf("expected no error but got: %v", err) + } + if tt.shouldErr && err != nil && !contains(err.Error(), tt.errMsg) { + t.Errorf("expected error message to contain '%s', got '%s'", tt.errMsg, err.Error()) + } + }) + } +} + +func TestValidateAutheliaURL(t *testing.T) { + tests := []struct { + name string + url string + shouldErr bool + errMsg string + }{ + { + name: "valid HTTP URL", + url: "http://authelia:9091", + shouldErr: false, + }, + { + name: "valid HTTPS URL", + url: "https://authelia.example.com", + shouldErr: false, + }, + { + name: "invalid URL (wrong scheme)", + url: "authelia:9091", + shouldErr: true, + errMsg: "AUTHELIA_URL must use http:// or https:// scheme", + }, + { + name: "invalid URL (path only)", + url: "/path/to/authelia", + shouldErr: true, + errMsg: "AUTHELIA_URL must use http:// or https:// scheme", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := &Config{ + EncryptionKeyBase64: "dGVzdC1rZXktMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM=", + AutheliaURL: tt.url, + DBPassword: "password", + DBPort: "5432", + Port: "11764", + } + + err := config.Validate() + if tt.shouldErr && err == nil { + t.Errorf("expected error but got none") + } + if !tt.shouldErr && err != nil { + t.Errorf("expected no error but got: %v", err) + } + if tt.shouldErr && err != nil && !contains(err.Error(), tt.errMsg) { + t.Errorf("expected error message to contain '%s', got '%s'", tt.errMsg, err.Error()) + } + }) + } +} + +func TestValidatePort(t *testing.T) { + tests := []struct { + name string + dbPort string + port string + shouldErr bool + errMsg string + }{ + { + name: "valid ports", + dbPort: "5432", + port: "11764", + shouldErr: false, + }, + { + name: "invalid DBPort (not a number)", + dbPort: "not-a-port", + port: "11764", + shouldErr: true, + errMsg: "VMAIL_DB_PORT is not a valid port number", + }, + { + name: "invalid Port (not a number)", + dbPort: "5432", + port: "not-a-port", + shouldErr: true, + errMsg: "PORT is not a valid port number", + }, + { + name: "invalid DBPort (too low)", + dbPort: "0", + port: "11764", + shouldErr: true, + errMsg: "VMAIL_DB_PORT is not a valid port number", + }, + { + name: "invalid DBPort (too high)", + dbPort: "65536", + port: "11764", + shouldErr: true, + errMsg: "VMAIL_DB_PORT is not a valid port number", + }, + { + name: "invalid Port (too low)", + dbPort: "5432", + port: "0", + shouldErr: true, + errMsg: "PORT is not a valid port number", + }, + { + name: "invalid Port (too high)", + dbPort: "5432", + port: "65536", + shouldErr: true, + errMsg: "PORT is not a valid port number", + }, + { + name: "valid boundary ports", + dbPort: "1", + port: "65535", + shouldErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := &Config{ + EncryptionKeyBase64: "dGVzdC1rZXktMTIzNDU2Nzg5MDEyMzQ1Njc4OTAxMjM=", + AutheliaURL: "http://authelia:9091", + DBPassword: "password", + DBPort: tt.dbPort, + Port: tt.port, + } + + err := config.Validate() + if tt.shouldErr && err == nil { + t.Errorf("expected error but got none") + } + if !tt.shouldErr && err != nil { + t.Errorf("expected no error but got: %v", err) + } + if tt.shouldErr && err != nil && !contains(err.Error(), tt.errMsg) { + t.Errorf("expected error message to contain '%s', got '%s'", tt.errMsg, err.Error()) + } + }) + } +} + +// contains checks if a string contains a substring (case-sensitive). +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || findSubstring(s, substr)) +} + +func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/backend/internal/crypto/encryption.go b/backend/internal/crypto/encryption.go index 4ef00bf..608516e 100644 --- a/backend/internal/crypto/encryption.go +++ b/backend/internal/crypto/encryption.go @@ -9,7 +9,9 @@ import ( "io" ) -// Encryptor provides encryption and decryption functionality using AES-GCM. +// Encryptor provides encryption and decryption functionality using AES-GCM (Galois/Counter Mode). +// AES-GCM provides both confidentiality and authenticity, making it suitable for encrypting +// sensitive data like user passwords. The key is stored in memory as plain bytes. type Encryptor struct { key []byte } @@ -29,6 +31,9 @@ func NewEncryptor(base64Key string) (*Encryptor, error) { } // Encrypt encrypts the given plaintext using AES-GCM. +// The returned ciphertext format is: [nonce][encrypted_data][auth_tag] +// where the nonce is prepended to the ciphertext for use during decryption. +// Each encryption uses a random nonce, ensuring the same plaintext produces different ciphertexts. func (e *Encryptor) Encrypt(plaintext string) ([]byte, error) { block, err := aes.NewCipher(e.key) if err != nil { @@ -50,6 +55,9 @@ func (e *Encryptor) Encrypt(plaintext string) ([]byte, error) { } // Decrypt decrypts the given ciphertext using AES-GCM. +// The ciphertext format is expected to be: [nonce][encrypted_data][auth_tag] +// where the nonce is prepended. Returns an error if the ciphertext is invalid, +// corrupted, or was encrypted with a different key (authentication failure). func (e *Encryptor) Decrypt(ciphertext []byte) (string, error) { block, err := aes.NewCipher(e.key) if err != nil { diff --git a/backend/internal/crypto/encryption_test.go b/backend/internal/crypto/encryption_test.go index 0da6b8b..4206ce2 100644 --- a/backend/internal/crypto/encryption_test.go +++ b/backend/internal/crypto/encryption_test.go @@ -141,4 +141,118 @@ func TestDecryptInvalidCiphertext(t *testing.T) { t.Error("Expected error for corrupted ciphertext, got nil") } }) + + t.Run("wrong key", func(t *testing.T) { + // Create first encryptor with one key + key1 := make([]byte, 32) + for i := range key1 { + key1[i] = byte(i) + } + base64Key1 := base64.StdEncoding.EncodeToString(key1) + encryptor1, err := NewEncryptor(base64Key1) + if err != nil { + t.Fatalf("Failed to create encryptor1: %v", err) + } + + // Create second encryptor with a different key + key2 := make([]byte, 32) + for i := range key2 { + key2[i] = byte(i + 100) // Different key + } + base64Key2 := base64.StdEncoding.EncodeToString(key2) + encryptor2, err := NewEncryptor(base64Key2) + if err != nil { + t.Fatalf("Failed to create encryptor2: %v", err) + } + + // Encrypt with first encryptor + plaintext := "secret password" + ciphertext, err := encryptor1.Encrypt(plaintext) + if err != nil { + t.Fatalf("Failed to encrypt: %v", err) + } + + // Try to decrypt with second encryptor (wrong key) - should fail + _, err = encryptor2.Decrypt(ciphertext) + if err == nil { + t.Error("Expected error when decrypting with wrong key, got nil") + } + }) +} + +func TestDecryptWithDifferentInstanceSameKey(t *testing.T) { + // Create a shared key + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + base64Key := base64.StdEncoding.EncodeToString(key) + + // Create first encryptor instance + encryptor1, err := NewEncryptor(base64Key) + if err != nil { + t.Fatalf("Failed to create encryptor1: %v", err) + } + + // Create second encryptor instance with the same key + encryptor2, err := NewEncryptor(base64Key) + if err != nil { + t.Fatalf("Failed to create encryptor2: %v", err) + } + + plaintext := "shared secret" + ciphertext, err := encryptor1.Encrypt(plaintext) + if err != nil { + t.Fatalf("Failed to encrypt with encryptor1: %v", err) + } + + // Decrypt with second encryptor instance (same key) - should succeed + decrypted, err := encryptor2.Decrypt(ciphertext) + if err != nil { + t.Fatalf("Failed to decrypt with encryptor2: %v", err) + } + + if decrypted != plaintext { + t.Errorf("Expected %q, got %q", plaintext, decrypted) + } +} + +func TestEncryptDecryptLargePlaintext(t *testing.T) { + key := make([]byte, 32) + base64Key := base64.StdEncoding.EncodeToString(key) + + encryptor, err := NewEncryptor(base64Key) + if err != nil { + t.Fatalf("Failed to create encryptor: %v", err) + } + + // Create a very large plaintext (1MB) + largePlaintext := make([]byte, 1024*1024) + for i := range largePlaintext { + largePlaintext[i] = byte(i % 256) + } + plaintextStr := string(largePlaintext) + + ciphertext, err := encryptor.Encrypt(plaintextStr) + if err != nil { + t.Fatalf("Failed to encrypt large plaintext: %v", err) + } + + if len(ciphertext) == 0 { + t.Fatal("Expected non-empty ciphertext") + } + + decrypted, err := encryptor.Decrypt(ciphertext) + if err != nil { + t.Fatalf("Failed to decrypt large plaintext: %v", err) + } + + if decrypted != plaintextStr { + t.Error("Decrypted plaintext does not match original") + } + + // Verify length matches + if len(decrypted) != len(plaintextStr) { + t.Errorf("Expected decrypted length %d, got %d", len(plaintextStr), len(decrypted)) + } } diff --git a/backend/internal/db/user.go b/backend/internal/db/user.go new file mode 100644 index 0000000..98fcbae --- /dev/null +++ b/backend/internal/db/user.go @@ -0,0 +1,27 @@ +package db + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5/pgxpool" +) + +// GetOrCreateUser returns the user's id for the given email. +// If no user exists with that email, it creates a new one. +func GetOrCreateUser(ctx context.Context, pool *pgxpool.Pool, email string) (string, error) { + var userID string + + err := pool.QueryRow(ctx, ` + INSERT INTO users (email) + VALUES ($1) + ON CONFLICT (email) DO UPDATE SET email = EXCLUDED.email + RETURNING id + `, email).Scan(&userID) + + if err != nil { + return "", fmt.Errorf("failed to get or create user: %w", err) + } + + return userID, nil +} diff --git a/backend/internal/db/user_settings.go b/backend/internal/db/user_settings.go index 3e85423..baf3c87 100644 --- a/backend/internal/db/user_settings.go +++ b/backend/internal/db/user_settings.go @@ -13,25 +13,6 @@ import ( // ErrUserSettingsNotFound is returned when user settings cannot be found. var ErrUserSettingsNotFound = errors.New("user settings not found") -// GetOrCreateUser returns the user's id for the given email. -// If no user exists with that email, it creates a new one. -func GetOrCreateUser(ctx context.Context, pool *pgxpool.Pool, email string) (string, error) { - var userID string - - err := pool.QueryRow(ctx, ` - INSERT INTO users (email) - VALUES ($1) - ON CONFLICT (email) DO UPDATE SET email = EXCLUDED.email - RETURNING id - `, email).Scan(&userID) - - if err != nil { - return "", fmt.Errorf("failed to get or create user: %w", err) - } - - return userID, nil -} - // UserSettingsExist returns true if the user settings exist. func UserSettingsExist(ctx context.Context, pool *pgxpool.Pool, userID string) (bool, error) { var exists bool diff --git a/backend/internal/db/user_settings_test.go b/backend/internal/db/user_settings_test.go index 53cbfb5..a28e970 100644 --- a/backend/internal/db/user_settings_test.go +++ b/backend/internal/db/user_settings_test.go @@ -10,44 +10,6 @@ import ( "github.com/vdavid/vmail/backend/internal/testutil" ) -func TestGetOrCreateUser(t *testing.T) { - pool := testutil.NewTestDB(t) - defer pool.Close() - - ctx := context.Background() - - t.Run("creates new user", func(t *testing.T) { - email := "test@example.com" - - userID, err := GetOrCreateUser(ctx, pool, email) - if err != nil { - t.Fatalf("GetOrCreateUser failed: %v", err) - } - - if userID == "" { - t.Fatal("Expected non-empty user ID") - } - }) - - t.Run("returns existing user", func(t *testing.T) { - email := "existing@example.com" - - userID1, err := GetOrCreateUser(ctx, pool, email) - if err != nil { - t.Fatalf("First GetOrCreateUser failed: %v", err) - } - - userID2, err := GetOrCreateUser(ctx, pool, email) - if err != nil { - t.Fatalf("Second GetOrCreateUser failed: %v", err) - } - - if userID1 != userID2 { - t.Errorf("Expected same user ID, got %s and %s", userID1, userID2) - } - }) -} - func TestUserSettingsExist(t *testing.T) { pool := testutil.NewTestDB(t) defer pool.Close() diff --git a/backend/internal/db/user_test.go b/backend/internal/db/user_test.go new file mode 100644 index 0000000..9914cdd --- /dev/null +++ b/backend/internal/db/user_test.go @@ -0,0 +1,46 @@ +package db + +import ( + "context" + "testing" + + "github.com/vdavid/vmail/backend/internal/testutil" +) + +func TestGetOrCreateUser(t *testing.T) { + pool := testutil.NewTestDB(t) + defer pool.Close() + + ctx := context.Background() + + t.Run("creates new user", func(t *testing.T) { + email := "test@example.com" + + userID, err := GetOrCreateUser(ctx, pool, email) + if err != nil { + t.Fatalf("GetOrCreateUser failed: %v", err) + } + + if userID == "" { + t.Fatal("Expected non-empty user ID") + } + }) + + t.Run("returns existing user", func(t *testing.T) { + email := "existing@example.com" + + userID1, err := GetOrCreateUser(ctx, pool, email) + if err != nil { + t.Fatalf("First GetOrCreateUser failed: %v", err) + } + + userID2, err := GetOrCreateUser(ctx, pool, email) + if err != nil { + t.Fatalf("Second GetOrCreateUser failed: %v", err) + } + + if userID1 != userID2 { + t.Errorf("Expected same user ID, got %s and %s", userID1, userID2) + } + }) +} diff --git a/backend/internal/imap/client.go b/backend/internal/imap/client.go index c367dc0..8f33a15 100644 --- a/backend/internal/imap/client.go +++ b/backend/internal/imap/client.go @@ -2,102 +2,68 @@ package imap import ( "fmt" - "log" "net" - "os" "sync" "time" - "github.com/emersion/go-imap" "github.com/emersion/go-imap/client" ) -// Pool manages IMAP connections per user. -type Pool struct { - clients map[string]*client.Client - mu sync.RWMutex -} - -// NewPool creates a new IMAP connection pool. -func NewPool() *Pool { - return &Pool{ - clients: make(map[string]*client.Client), - } -} +// clientRole indicates the purpose of a client. +type clientRole int -// getClientConcrete gets or creates an IMAP client for a user (internal use). -// Returns the concrete *client.Client type for internal operations. -func (p *Pool) getClientConcrete(userID, server, username, password string) (*client.Client, error) { - p.mu.RLock() - c, exists := p.clients[userID] - p.mu.RUnlock() - - if exists && c != nil { - // Check if the connection is still alive - state := c.State() - // ConnState values: 0=NotAuthenticated, 1=Authenticated, 2=Selected - if state == imap.AuthenticatedState || state == imap.SelectedState { - return c, nil - } - // Connection is dead, remove it - p.mu.Lock() - delete(p.clients, userID) - p.mu.Unlock() - } - - // Create a new connection (use TLS in production, non-TLS for tests) - // Check environment variable for test mode - useTLS := os.Getenv("VMAIL_TEST_MODE") != "true" - c, err := ConnectToIMAP(server, useTLS) - if err != nil { - return nil, fmt.Errorf("failed to connect: %w", err) - } +const ( + // roleWorker indicates a worker client. There can be multiple worker clients per user. + roleWorker clientRole = iota + // roleListener indicates a listener client. There can be only one listener client per user. + roleListener +) - if err := Login(c, username, password); err != nil { - _ = c.Logout() - return nil, fmt.Errorf("failed to login: %w", err) - } +// threadSafeClient wraps an IMAP client with a mutex for thread-safe access. +// Each client has its own mutex to allow concurrent access to different clients +// while serializing access to the same client. +type threadSafeClient struct { + client *client.Client + mu sync.Mutex + lastUsed time.Time + role clientRole +} - p.mu.Lock() - p.clients[userID] = c - p.mu.Unlock() +// Lock acquires the mutex for thread-safe access to the underlying client. +func (c *threadSafeClient) Lock() { + c.mu.Lock() +} - return c, nil +// Unlock releases the mutex. +func (c *threadSafeClient) Unlock() { + c.mu.Unlock() } -// GetClient gets or creates an IMAP client for a user. -// Implements IMAPPool interface - returns IMAPClient for testability. -func (p *Pool) GetClient(userID, server, username, password string) (IMAPClient, error) { - c, err := p.getClientConcrete(userID, server, username, password) - if err != nil { - return nil, err - } - return &ClientWrapper{client: c}, nil +// TryLock attempts to acquire the mutex without blocking. +// Returns true if the lock was acquired, false otherwise. +func (c *threadSafeClient) TryLock() bool { + return c.mu.TryLock() } -// RemoveClient removes a client from the pool and logs out. -func (p *Pool) RemoveClient(userID string) { - p.mu.Lock() - defer p.mu.Unlock() +// GetClient returns the underlying IMAP client (for internal use). +// Caller must hold the lock before calling this. +func (c *threadSafeClient) GetClient() *client.Client { + return c.client +} - c, exists := p.clients[userID] - if exists { - _ = c.Logout() - delete(p.clients, userID) - } +// UpdateLastUsed updates the lastUsed timestamp to now. +func (c *threadSafeClient) UpdateLastUsed() { + c.lastUsed = time.Now() } -// Close closes all connections in the pool. -func (p *Pool) Close() { - p.mu.Lock() - defer p.mu.Unlock() +// GetLastUsed returns the lastUsed timestamp. +func (c *threadSafeClient) GetLastUsed() time.Time { + return c.lastUsed +} - for userID, c := range p.clients { - if err := c.Logout(); err != nil { - log.Printf("Failed to logout IMAP client for user %s: %v", userID, err) - } - delete(p.clients, userID) - } +// GetRole returns the client role (worker or listener). +func (c *threadSafeClient) GetRole() clientRole { + return c.role } // ConnectToIMAP connects to the IMAP server with a 5-second timeout. diff --git a/backend/internal/imap/client_test.go b/backend/internal/imap/client_test.go deleted file mode 100644 index 6180294..0000000 --- a/backend/internal/imap/client_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package imap - -import ( - "testing" - - "github.com/emersion/go-imap" -) - -func TestPool_GetClient(t *testing.T) { - pool := NewPool() - defer pool.Close() - - t.Run("creates new client when none exists", func(t *testing.T) { - // This test would require a real IMAP server or a mock. - // For now, we test that the pool structure works - if pool == nil { - t.Error("Expected pool to be created") - } - }) - - t.Run("removes client from pool", func(t *testing.T) { - pool.RemoveClient("test-user") - // Should not panic - }) - - t.Run("removes and recreates client when connection is dead", func(t *testing.T) { - // This test verifies the reconnection logic in GetClient. - // The logic should: - // 1. Check if the client exists and is in AuthenticatedState or SelectedState - // 2. If the client is in NotAuthenticatedState (or any other invalid state), remove it - // 3. Create a new connection - // - // To properly test this, we would need: - // - A mock IMAP client that can return different states - // - Or a real IMAP server that we can disconnect - // - // For now, we verify the pool structure and that RemoveClient works - userID := "test-reconnect-user" - pool.RemoveClient(userID) // Clean up if exists - - // The actual reconnection test would: - // 1. Manually add a mock client in NotAuthenticatedState to the pool - // 2. Call GetClient - // 3. Assert that the old client was removed and a new one was created - // - // This requires refactoring Pool to accept a client factory function - // or using interfaces to inject mock clients. - _ = userID - }) -} - -//goland:noinspection GoBoolExpressions -func TestPool_GetClient_ReconnectionLogic(t *testing.T) { - // This test documents the expected reconnection behavior: - // - // When GetClient is called and a client exists in the pool: - // 1. Check client.State() - // 2. If the state is imap.AuthenticatedState or imap.SelectedState, return the existing client - // 3. If the state is imap.NotAuthenticatedState (or any other state), remove the client from the pool - // 4. Create new connection and add to pool - // - // To test this properly, we need: - // - Interface for IMAP client with State() method - // - Mock client that can return different states - // - Ability to inject mock into pool - // - // Example test structure: - // mockClient := &MockClient{state: imap.NotAuthenticatedState} - // pool.clients["user"] = mockClient - // newClient, err := pool.GetClient("user", "server", "user", "pass") - // assert mockClient was removed - // assert newClient is different from mockClient - // assert newClient.State() is AuthenticatedState or SelectedState - - // Verify the state constants exist and are distinct - // The actual values may vary by go-imap version, but they should be distinct - if imap.NotAuthenticatedState == imap.AuthenticatedState { - t.Error("NotAuthenticatedState and AuthenticatedState should be different") - } - if imap.AuthenticatedState == imap.SelectedState { - t.Error("AuthenticatedState and SelectedState should be different") - } - if imap.NotAuthenticatedState == imap.SelectedState { - t.Error("NotAuthenticatedState and SelectedState should be different") - } - - // Log the actual values for reference - t.Logf("IMAP state constants: NotAuthenticatedState=%d, AuthenticatedState=%d, SelectedState=%d", - imap.NotAuthenticatedState, imap.AuthenticatedState, imap.SelectedState) -} - -func TestPool_Close(t *testing.T) { - pool := NewPool() - - t.Run("closes all clients", func(t *testing.T) { - pool.Close() - // Should not panic - }) - - t.Run("can be called multiple times safely", func(t *testing.T) { - pool := NewPool() - pool.Close() - pool.Close() // Should not panic - }) -} diff --git a/backend/internal/imap/fetch.go b/backend/internal/imap/fetch.go index 1167e13..70f0d0c 100644 --- a/backend/internal/imap/fetch.go +++ b/backend/internal/imap/fetch.go @@ -8,6 +8,7 @@ import ( ) // FetchMessageHeaders fetches message headers for the given UIDs. +// Returns envelope, body structure, flags, and UID for each message. func FetchMessageHeaders(c *client.Client, uids []uint32) ([]*imap.Message, error) { if c == nil { return nil, fmt.Errorf("client is nil") @@ -50,6 +51,7 @@ func FetchMessageHeaders(c *client.Client, uids []uint32) ([]*imap.Message, erro } // FetchFullMessage fetches the full message body for the given UID. +// First fetches headers and body structure, then fetches the actual body content. func FetchFullMessage(c *client.Client, uid uint32) (*imap.Message, error) { if c == nil { return nil, fmt.Errorf("client is nil") @@ -99,7 +101,11 @@ func FetchFullMessage(c *client.Client, uid uint32) (*imap.Message, error) { if bodyMsg != nil { msg.Body = bodyMsg.Body } - <-bodyDone + if err := <-bodyDone; err != nil { + // Log error but don't fail - we still have headers and structure + // The body fetch failure is non-critical for basic message retrieval + return nil, fmt.Errorf("failed to fetch message body: %w", err) + } } return msg, nil @@ -107,34 +113,52 @@ func FetchFullMessage(c *client.Client, uid uint32) (*imap.Message, error) { // SearchUIDsSince searches for all UIDs greater than or equal to the given UID. // This is used for incremental sync to find only new messages. +// +// Performance note: This function fetches all UIDs and filters them client-side. +// While IMAP supports UID SEARCH with ranges (e.g., "UID minUID:*"), the go-imap +// library's SearchCriteria doesn't expose this capability directly. The current +// approach is acceptable because: +// 1. We're only fetching UID numbers (not message content), which is fast +// 2. Client-side filtering is efficient for typical mailbox sizes +// 3. Most mailboxes have < 100k messages, making this approach practical +// +// For very large mailboxes (> 1M messages), consider: +// - Using IMAP's native UID SEARCH with ranges if go-imap adds support +// - Implementing batch fetching with pagination +// - Using server-side filtering if the IMAP server supports extensions func SearchUIDsSince(c *client.Client, minUID uint32) ([]uint32, error) { if c == nil { return nil, fmt.Errorf("client is nil") } - // Create a SeqSet with the range minUID:* - // This represents all UIDs from minUID to the highest UID - seqSet := new(imap.SeqSet) - seqSet.AddRange(minUID, 0) // 0 means "highest UID" - - // Use SEARCH to find UIDs in this range - // We'll use a simple approach: fetch UIDs for all messages in the range - // Actually, IMAP SEARCH doesn't work with SeqSet directly for UID ranges - // Instead, we need to use the SEARCH command with UID criteria - - // The go-imap library's UidSearch doesn't directly support UID ranges, - // but we can fetch all UIDs and filter them, or use a different approach. - // For now, let's fetch all UIDs and filter - this is still efficient - // because we're only getting UID numbers, not message content. - + // Fetch all UIDs from the server + // Note: go-imap's UidSearch doesn't support UID ranges in SearchCriteria, + // so we fetch all UIDs and filter client-side. This is efficient for typical + // mailbox sizes since we're only transferring UID numbers. searchCriteria := imap.NewSearchCriteria() uids, err := c.UidSearch(searchCriteria) if err != nil { return nil, fmt.Errorf("failed to search for UIDs: %w", err) } + // Early exit if no UIDs or minUID is higher than all UIDs + if len(uids) == 0 { + return []uint32{}, nil + } + + // If minUID is higher than the highest UID, return empty + if minUID > uids[len(uids)-1] { + return []uint32{}, nil + } + // Filter to only UIDs >= minUID - var filteredUIDs []uint32 + // Pre-allocate slice with estimated capacity (assuming UIDs are roughly evenly distributed) + estimatedSize := len(uids) + if minUID > 0 { + // Rough estimate: if minUID is halfway, we'll get about half the UIDs + estimatedSize = len(uids) / 2 + } + filteredUIDs := make([]uint32, 0, estimatedSize) for _, uid := range uids { if uid >= minUID { filteredUIDs = append(filteredUIDs, uid) diff --git a/backend/internal/imap/fetch_test.go b/backend/internal/imap/fetch_test.go new file mode 100644 index 0000000..cd8cb72 --- /dev/null +++ b/backend/internal/imap/fetch_test.go @@ -0,0 +1,158 @@ +package imap + +import ( + "testing" + "time" + + "github.com/vdavid/vmail/backend/internal/testutil" +) + +func TestFetchMessageHeaders(t *testing.T) { + t.Run("returns error for nil client", func(t *testing.T) { + _, err := FetchMessageHeaders(nil, []uint32{1, 2, 3}) + if err == nil { + t.Error("Expected error for nil client") + } + if err.Error() != "client is nil" { + t.Errorf("Expected error 'client is nil', got: %v", err) + } + }) + + t.Run("returns empty slice for empty UIDs", func(t *testing.T) { + server := testutil.NewTestIMAPServer(t) + defer server.Close() + + client, cleanup := server.Connect(t) + defer cleanup() + + result, err := FetchMessageHeaders(client, []uint32{}) + if err != nil { + t.Errorf("Expected no error for empty UIDs, got: %v", err) + } + if result == nil { + t.Error("Expected empty slice, got nil") + } + if len(result) != 0 { + t.Errorf("Expected empty slice, got %d items", len(result)) + } + }) + + t.Run("fetches message headers successfully", func(t *testing.T) { + server := testutil.NewTestIMAPServer(t) + defer server.Close() + + server.EnsureINBOX(t) + + // Add a test message + uid := server.AddMessage(t, "INBOX", "", "Test Subject", "from@example.com", "to@example.com", time.Now()) + + client, cleanup := server.Connect(t) + defer cleanup() + + // Select INBOX + _, err := client.Select("INBOX", false) + if err != nil { + t.Fatalf("Failed to select INBOX: %v", err) + } + + // Fetch headers + messages, err := FetchMessageHeaders(client, []uint32{uid}) + if err != nil { + t.Fatalf("Failed to fetch message headers: %v", err) + } + + if len(messages) != 1 { + t.Errorf("Expected 1 message, got %d", len(messages)) + } + + if messages[0].Uid != uid { + t.Errorf("Expected UID %d, got %d", uid, messages[0].Uid) + } + + if messages[0].Envelope == nil { + t.Error("Expected envelope, got nil") + } + }) +} + +func TestFetchFullMessage(t *testing.T) { + t.Run("returns error for nil client", func(t *testing.T) { + _, err := FetchFullMessage(nil, 1) + if err == nil { + t.Error("Expected error for nil client") + } + if err.Error() != "client is nil" { + t.Errorf("Expected error 'client is nil', got: %v", err) + } + }) + + t.Run("fetches full message successfully", func(t *testing.T) { + server := testutil.NewTestIMAPServer(t) + defer server.Close() + + server.EnsureINBOX(t) + + // Add a test message + uid := server.AddMessage(t, "INBOX", "", "Test Subject", "from@example.com", "to@example.com", time.Now()) + + client, cleanup := server.Connect(t) + defer cleanup() + + // Select INBOX + _, err := client.Select("INBOX", false) + if err != nil { + t.Fatalf("Failed to select INBOX: %v", err) + } + + // Fetch full message + msg, err := FetchFullMessage(client, uid) + if err != nil { + t.Fatalf("Failed to fetch full message: %v", err) + } + + if msg == nil { + t.Fatal("Expected message, got nil") + } + + if msg.Uid != uid { + t.Errorf("Expected UID %d, got %d", uid, msg.Uid) + } + + if msg.Envelope == nil { + t.Error("Expected envelope, got nil") + } + }) + + t.Run("handles message without body structure", func(t *testing.T) { + // This test verifies that FetchFullMessage doesn't crash when + // BodyStructure is nil. The function should still return the message + // with headers even if body structure is missing. + server := testutil.NewTestIMAPServer(t) + defer server.Close() + + server.EnsureINBOX(t) + + // Add a test message + uid := server.AddMessage(t, "INBOX", "", "Test Subject", "from@example.com", "to@example.com", time.Now()) + + client, cleanup := server.Connect(t) + defer cleanup() + + // Select INBOX + _, err := client.Select("INBOX", false) + if err != nil { + t.Fatalf("Failed to select INBOX: %v", err) + } + + // Fetch full message + msg, err := FetchFullMessage(client, uid) + if err != nil { + t.Fatalf("Failed to fetch full message: %v", err) + } + + // Message should be returned even if body structure is nil + if msg == nil { + t.Fatal("Expected message, got nil") + } + }) +} diff --git a/backend/internal/imap/folder_test.go b/backend/internal/imap/folder_test.go new file mode 100644 index 0000000..c1d9dff --- /dev/null +++ b/backend/internal/imap/folder_test.go @@ -0,0 +1,123 @@ +package imap + +import ( + "testing" + + "github.com/vdavid/vmail/backend/internal/testutil" +) + +func TestListFolders(t *testing.T) { + t.Run("returns error for nil client", func(t *testing.T) { + _, err := ListFolders(nil) + if err == nil { + t.Error("Expected error for nil client") + } + if err.Error() != "client is nil" { + t.Errorf("Expected error 'client is nil', got: %v", err) + } + }) + + t.Run("returns error for server without SPECIAL-USE support", func(t *testing.T) { + // Create a test IMAP server without SPECIAL-USE extension + server := testutil.NewTestIMAPServer(t) + defer server.Close() + + client, cleanup := server.Connect(t) + defer cleanup() + + // The test server doesn't enable SPECIAL-USE by default + // We need to check if it supports it - if not, we should get an error + caps, err := client.Capability() + if err != nil { + t.Fatalf("Failed to check capabilities: %v", err) + } + + // If the server doesn't support SPECIAL-USE, ListFolders should return an error + if !caps["SPECIAL-USE"] { + _, err := ListFolders(client) + if err == nil { + t.Error("Expected error for server without SPECIAL-USE support") + } + if err.Error() == "" { + t.Error("Expected non-empty error message") + } + } else { + // Server supports SPECIAL-USE, so test should pass + folders, err := ListFolders(client) + if err != nil { + t.Fatalf("ListFolders should succeed when SPECIAL-USE is supported: %v", err) + } + if folders == nil { + t.Error("Expected folders slice, got nil") + } + } + }) + + t.Run("handles empty folder list", func(t *testing.T) { + // Create a test IMAP server with SPECIAL-USE support + server, err := testutil.NewTestIMAPServerForE2E() + if err != nil { + t.Skipf("Failed to create test IMAP server with SPECIAL-USE support: %v", err) + } + defer server.Close() + + client, err := server.ConnectForE2E() + if err != nil { + t.Fatalf("Failed to connect: %v", err) + } + defer func() { + _ = client.Logout() + }() + + // Check if server supports SPECIAL-USE + caps, err := client.Capability() + if err != nil { + t.Fatalf("Failed to check capabilities: %v", err) + } + + if !caps["SPECIAL-USE"] { + t.Skip("Server does not support SPECIAL-USE, skipping test") + } + + // List folders - should return at least INBOX (created by memory backend) + folders, err := ListFolders(client) + if err != nil { + t.Fatalf("ListFolders failed: %v", err) + } + + // Memory backend creates INBOX by default, so we should have at least one folder + if len(folders) == 0 { + t.Error("Expected at least INBOX folder, got empty list") + } + + // Verify INBOX is present + foundINBOX := false + for _, folder := range folders { + if folder.Name == "INBOX" { + foundINBOX = true + if folder.Role != "inbox" { + t.Errorf("Expected INBOX role 'inbox', got '%s'", folder.Role) + } + } + } + if !foundINBOX { + t.Error("Expected to find INBOX folder") + } + }) + + t.Run("handles network errors during list", func(t *testing.T) { + // Create a client and then close it to simulate network error + server := testutil.NewTestIMAPServer(t) + defer server.Close() + + client, _ := server.Connect(t) + // Close the client to simulate network error + _ = client.Logout() + + // Try to list folders with closed client + _, err := ListFolders(client) + if err == nil { + t.Error("Expected error when client is closed") + } + }) +} diff --git a/backend/internal/imap/parser.go b/backend/internal/imap/parser.go index 76c9bdd..8dc3ef5 100644 --- a/backend/internal/imap/parser.go +++ b/backend/internal/imap/parser.go @@ -11,6 +11,7 @@ import ( ) // ParseMessage converts an IMAP message to our Message model. +// Extracts headers, flags, and body (if available). Body parsing errors are logged but don't fail the parse. func ParseMessage(imapMsg *imap.Message, threadID, userID, folderName string) (*models.Message, error) { if imapMsg == nil { return nil, fmt.Errorf("imap message is nil") diff --git a/backend/internal/imap/parser_test.go b/backend/internal/imap/parser_test.go index 7c6f98c..3c82582 100644 --- a/backend/internal/imap/parser_test.go +++ b/backend/internal/imap/parser_test.go @@ -199,4 +199,125 @@ func TestParseMessage(t *testing.T) { t.Error("Expected message to not be marked as read") } }) + + t.Run("handles message without Message-ID", func(t *testing.T) { + imapMsg := &imap.Message{ + Uid: 300, + Flags: []string{}, + Envelope: &imap.Envelope{ + // No MessageId + Subject: "Test Subject", + }, + } + + msg, err := ParseMessage(imapMsg, "thread-id", "user-id", "INBOX") + if err != nil { + t.Fatalf("ParseMessage failed: %v", err) + } + + if msg.MessageIDHeader != "" { + t.Errorf("Expected empty MessageIDHeader, got %s", msg.MessageIDHeader) + } + if msg.Subject != "Test Subject" { + t.Errorf("Expected Subject 'Test Subject', got %s", msg.Subject) + } + }) + + t.Run("handles message with empty body", func(t *testing.T) { + imapMsg := &imap.Message{ + Uid: 400, + Flags: []string{}, + Envelope: &imap.Envelope{ + MessageId: "", + }, + // No Body or BodyStructure + } + + msg, err := ParseMessage(imapMsg, "thread-id", "user-id", "INBOX") + if err != nil { + t.Fatalf("ParseMessage failed: %v", err) + } + + if msg.UnsafeBodyHTML != "" { + t.Errorf("Expected empty body HTML, got %s", msg.UnsafeBodyHTML) + } + if msg.BodyText != "" { + t.Errorf("Expected empty body text, got %s", msg.BodyText) + } + }) + + t.Run("handles body parsing errors gracefully", func(t *testing.T) { + // Create a message with invalid body structure + imapMsg := &imap.Message{ + Uid: 500, + Flags: []string{}, + Envelope: &imap.Envelope{ + MessageId: "", + Subject: "Test Subject", + }, + BodyStructure: &imap.BodyStructure{ + MIMEType: "text", + MIMESubType: "plain", + }, + // Body is nil, which will cause parseBody to fail, but ParseMessage should continue + } + + msg, err := ParseMessage(imapMsg, "thread-id", "user-id", "INBOX") + if err != nil { + t.Fatalf("ParseMessage should not fail on body parsing errors: %v", err) + } + + // Should still have headers even if body parsing failed + if msg.Subject != "Test Subject" { + t.Errorf("Expected Subject 'Test Subject', got %s", msg.Subject) + } + if msg.MessageIDHeader != "" { + t.Errorf("Expected MessageIDHeader '', got %s", msg.MessageIDHeader) + } + }) + + t.Run("handles message with attachments", func(t *testing.T) { + // Note: Testing attachments requires a properly formatted MIME message + // For now, we test that the function handles messages with BodyStructure + // that indicates attachments. Full attachment parsing is tested through + // integration tests with real IMAP messages. + imapMsg := &imap.Message{ + Uid: 600, + Flags: []string{}, + Envelope: &imap.Envelope{ + MessageId: "", + Subject: "Test with Attachments", + }, + BodyStructure: &imap.BodyStructure{ + MIMEType: "multipart", + MIMESubType: "mixed", + Parts: []*imap.BodyStructure{ + { + MIMEType: "text", + MIMESubType: "plain", + }, + { + MIMEType: "application", + MIMESubType: "pdf", + Disposition: "attachment", + DispositionParams: map[string]string{ + "filename": "test.pdf", + }, + }, + }, + }, + } + + msg, err := ParseMessage(imapMsg, "thread-id", "user-id", "INBOX") + if err != nil { + t.Fatalf("ParseMessage failed: %v", err) + } + + // Message should be parsed successfully + if msg.Subject != "Test with Attachments" { + t.Errorf("Expected Subject 'Test with Attachments', got %s", msg.Subject) + } + // Attachments would be parsed from the body if Body is available + // This is tested through integration tests + }) } diff --git a/backend/internal/imap/pool.go b/backend/internal/imap/pool.go new file mode 100644 index 0000000..300022b --- /dev/null +++ b/backend/internal/imap/pool.go @@ -0,0 +1,116 @@ +package imap + +import ( + "context" + "log" + "sync" + "time" +) + +const ( + // workerIdleTimeout is the maximum time a worker connection can be idle before being closed. + workerIdleTimeout = 10 * time.Minute + // healthCheckThreshold is the idle time after which we perform a health check before reuse. + healthCheckThreshold = 1 * time.Minute +) + +// Pool manages IMAP connections per user. +// Supports two types of connections: +// - Worker connections: 1-3 connections per user for API handlers (SEARCH, FETCH, STORE) +// - Listener connections: 1 dedicated connection per user for IDLE command +// +// Thread safety: Each connection is wrapped with a mutex to ensure thread-safe access. +// Multiple goroutines can use different connections concurrently, but access to the same +// connection is serialized. +type Pool struct { + workerSets map[string]*workerClientSet // userID -> worker client set + listeners map[string]*threadSafeClient // userID -> listener connection + mu sync.RWMutex + maxWorkers int // Maximum worker connections per user (default: 3) + cleanupCtx context.Context + cleanupCancel context.CancelFunc +} + +// NewPool creates a new IMAP connection pool with the default worker limit. +func NewPool() *Pool { + return NewPoolWithMaxWorkers(3) +} + +// NewPoolWithMaxWorkers creates a new IMAP connection pool with a configurable +// maximum number of worker connections per user. +func NewPoolWithMaxWorkers(maxWorkers int) *Pool { + ctx, cancel := context.WithCancel(context.Background()) + p := &Pool{ + workerSets: make(map[string]*workerClientSet), + listeners: make(map[string]*threadSafeClient), + maxWorkers: maxWorkers, + cleanupCtx: ctx, + cleanupCancel: cancel, + } + go p.startCleanupGoroutine() + return p +} + +// WithClient gets an IMAP client for a user and calls the provided function with it. +// The client is automatically released when the function returns. +// Implements IMAPPool interface. +func (p *Pool) WithClient(userID, server, username, password string, fn func(IMAPClient) error) error { + tsClient, release, err := p.getWorkerConnection(userID, server, username, password) + if err != nil { + return err + } + defer release() + + client := &ClientWrapper{client: tsClient.GetClient()} + return fn(client) +} + +// RemoveClient removes all connections (worker and listener) for a user from the pool. +func (p *Pool) RemoveClient(userID string) { + p.mu.Lock() + defer p.mu.Unlock() + + // Remove worker set + if set, exists := p.workerSets[userID]; exists { + set.close() + delete(p.workerSets, userID) + } + + // Remove listener + if listener, exists := p.listeners[userID]; exists { + listener.Lock() + _ = listener.GetClient().Logout() + listener.Unlock() + delete(p.listeners, userID) + } +} + +// Close closes all connections in the pool and stops the cleanup goroutine. +func (p *Pool) Close() { + // Stop cleanup goroutine + p.cleanupCancel() + + p.mu.Lock() + defer p.mu.Unlock() + + // Close all worker sets + for userID, set := range p.workerSets { + set.close() + delete(p.workerSets, userID) + } + + // Close all listener connections + for userID, listener := range p.listeners { + // Try to lock - if we can't, the listener is in use + if listener.TryLock() { + if err := listener.GetClient().Logout(); err != nil { + log.Printf("Failed to logout listener connection for user %s: %v", userID, err) + } + listener.Unlock() + } else { + // Listener is locked (in use) - try to close anyway during shutdown + _ = listener.GetClient().Logout() + } + delete(p.listeners, userID) + } +} diff --git a/backend/internal/imap/pool_cleanup.go b/backend/internal/imap/pool_cleanup.go new file mode 100644 index 0000000..638d279 --- /dev/null +++ b/backend/internal/imap/pool_cleanup.go @@ -0,0 +1,58 @@ +package imap + +import ( + "time" +) + +// startCleanupGoroutine runs a background goroutine that periodically cleans up idle connections. +// The goroutine will stop when cleanupCtx is cancelled (via Pool.Close()). +func (p *Pool) startCleanupGoroutine() { + ticker := time.NewTicker(1 * time.Minute) + go func() { + defer ticker.Stop() + for { + select { + case <-p.cleanupCtx.Done(): + // Context cancelled - stop the ticker and exit + return + case <-ticker.C: + // Periodic cleanup + p.cleanupIdleConnections() + } + } + }() +} + +// cleanupIdleConnections removes worker connections that have been idle too long. +func (p *Pool) cleanupIdleConnections() { + p.mu.Lock() + defer p.mu.Unlock() + + now := time.Now() + for userID, set := range p.workerSets { + set.mu.Lock() + var toRemove []*threadSafeClient + for _, client := range set.clients { + if now.Sub(client.GetLastUsed()) > workerIdleTimeout { + toRemove = append(toRemove, client) + } + } + // Remove idle clients + for _, client := range toRemove { + for i, c := range set.clients { + if c == client { + set.clients = append(set.clients[:i], set.clients[i+1:]...) + client.Lock() + _ = client.GetClient().Logout() + client.Unlock() + break + } + } + } + // Remove empty sets + if len(set.clients) == 0 { + delete(p.workerSets, userID) + } + set.mu.Unlock() + } +} diff --git a/backend/internal/imap/pool_interface.go b/backend/internal/imap/pool_interface.go index cf5a1a4..86b2d41 100644 --- a/backend/internal/imap/pool_interface.go +++ b/backend/internal/imap/pool_interface.go @@ -21,8 +21,11 @@ type IMAPClient interface { // //goland:noinspection GoNameStartsWithPackageName type IMAPPool interface { - // GetClient gets or creates an IMAP client for a user. - GetClient(userID, server, username, password string) (IMAPClient, error) + // WithClient gets an IMAP client for a user and calls the provided function with it. + // The client is automatically released when the function returns, ensuring worker slots + // are freed promptly. This is the safe way to use the pool - it's impossible to forget + // to release the client. + WithClient(userID, server, username, password string, fn func(IMAPClient) error) error // RemoveClient removes a client from the pool (useful when a connection is broken). RemoveClient(userID string) @@ -41,5 +44,21 @@ func (w *ClientWrapper) ListFolders() ([]*models.Folder, error) { return ListFolders(w.client) } +// ListenerClient defines the interface for listener client operations. +// This allows the IDLE feature to work with the thread-safe wrapper +// without exposing implementation details. +type ListenerClient interface { + // Lock acquires the mutex for thread-safe access to the underlying client. + Lock() + // Unlock releases the mutex. + Unlock() + // GetClient returns the underlying IMAP client. + // Caller must hold the lock before calling this. + GetClient() *client.Client +} + // Ensure Pool implements IMAPPool interface var _ IMAPPool = (*Pool)(nil) + +// Ensure threadSafeClient implements ListenerClient interface +var _ ListenerClient = (*threadSafeClient)(nil) diff --git a/backend/internal/imap/pool_listener.go b/backend/internal/imap/pool_listener.go new file mode 100644 index 0000000..ba61bb3 --- /dev/null +++ b/backend/internal/imap/pool_listener.go @@ -0,0 +1,99 @@ +package imap + +import ( + "fmt" + "os" + "time" + + "github.com/emersion/go-imap" +) + +// GetListenerConnection gets or creates a listener client for a user. +// Listener clients are dedicated clients for IDLE command. +// Returns a locked client that must be unlocked by the caller. +// Thread-safe: uses double-check locking pattern. +func (p *Pool) GetListenerConnection(userID, server, username, password string) (ListenerClient, error) { + // First check without a lock + p.mu.RLock() + listener, exists := p.listeners[userID] + p.mu.RUnlock() + + if exists { + listener.Lock() + // Double-check after acquiring a lock + p.mu.RLock() + existingListener, stillExists := p.listeners[userID] + p.mu.RUnlock() + + if stillExists && existingListener == listener { + // Check if the connection is healthy + state := listener.GetClient().State() + if state == imap.AuthenticatedState || state == imap.SelectedState { + return listener, nil // Caller must unlock + } + // Connection is dead, unlock and remove it + listener.Unlock() + p.mu.Lock() + if p.listeners[userID] == listener { + delete(p.listeners, userID) + } + p.mu.Unlock() + // Close dead connection + _ = listener.GetClient().Logout() + } else { + // Another goroutine removed/recreated it + listener.Unlock() + // Retry with a new connection + return p.GetListenerConnection(userID, server, username, password) + } + } + + // Need to create a new listener connection + useTLS := os.Getenv("VMAIL_TEST_MODE") != "true" + c, err := ConnectToIMAP(server, useTLS) + if err != nil { + return nil, fmt.Errorf("failed to connect: %w", err) + } + + if err := Login(c, username, password); err != nil { + _ = c.Logout() + return nil, fmt.Errorf("failed to login: %w", err) + } + + // Wrap in threadSafeClient + listener = &threadSafeClient{ + client: c, + lastUsed: time.Now(), + role: roleListener, + } + + // Double-check before adding + p.mu.Lock() + if existingListener, exists := p.listeners[userID]; exists { + // Another goroutine created it - close ours and use existing + _ = c.Logout() + p.mu.Unlock() + listener = existingListener + listener.Lock() + return listener, nil + } + p.listeners[userID] = listener + p.mu.Unlock() + + listener.Lock() // Lock before returning + return listener, nil +} + +// RemoveListenerConnection removes a listener connection from the pool. +func (p *Pool) RemoveListenerConnection(userID string) { + p.mu.Lock() + defer p.mu.Unlock() + + listener, exists := p.listeners[userID] + if exists { + listener.Lock() + _ = listener.GetClient().Logout() + listener.Unlock() + delete(p.listeners, userID) + } +} diff --git a/backend/internal/imap/pool_test.go b/backend/internal/imap/pool_test.go new file mode 100644 index 0000000..6116046 --- /dev/null +++ b/backend/internal/imap/pool_test.go @@ -0,0 +1,253 @@ +package imap + +import ( + "fmt" + "os" + "testing" + + "github.com/emersion/go-imap" + "github.com/vdavid/vmail/backend/internal/testutil" +) + +func TestPool_GetClient(t *testing.T) { + pool := NewPool() + defer pool.Close() + + t.Run("creates new client when none exists", func(t *testing.T) { + // This test would require a real IMAP server or a mock. + // For now, we test that the pool structure works + if pool == nil { + t.Error("Expected pool to be created") + } + }) + + t.Run("removes client from pool", func(t *testing.T) { + pool.RemoveClient("test-user") + // Should not panic + }) + + t.Run("removes and recreates client when connection is dead", func(t *testing.T) { + // This test verifies the reconnection logic in GetClient. + // The logic should: + // 1. Check if the client exists and is in AuthenticatedState or SelectedState + // 2. If the client is in NotAuthenticatedState (or any other invalid state), remove it + // 3. Create a new connection + // + // To properly test this, we would need: + // - A mock IMAP client that can return different states + // - Or a real IMAP server that we can disconnect + // + // For now, we verify the pool structure and that RemoveClient works + userID := "test-reconnect-user" + pool.RemoveClient(userID) // Clean up if exists + + // The actual reconnection test would: + // 1. Manually add a mock client in NotAuthenticatedState to the pool + // 2. Call GetClient + // 3. Assert that the old client was removed and a new one was created + // + // This requires refactoring Pool to accept a client factory function + // or using interfaces to inject mock clients. + _ = userID + }) +} + +//goland:noinspection GoBoolExpressions +func TestPool_GetClient_ReconnectionLogic(t *testing.T) { + // This test documents the expected reconnection behavior: + // + // When GetClient is called and a client exists in the pool: + // 1. Check client.State() + // 2. If the state is imap.AuthenticatedState or imap.SelectedState, return the existing client + // 3. If the state is imap.NotAuthenticatedState (or any other state), remove the client from the pool + // 4. Create new connection and add to pool + // + // To test this properly, we need: + // - Interface for IMAP client with State() method + // - Mock client that can return different states + // - Ability to inject mock into pool + // + // Example test structure: + // mockClient := &MockClient{state: imap.NotAuthenticatedState} + // pool.clients["user"] = mockClient + // newClient, err := pool.GetClient("user", "server", "user", "pass") + // assert mockClient was removed + // assert newClient is different from mockClient + // assert newClient.State() is AuthenticatedState or SelectedState + + // Verify the state constants exist and are distinct + // The actual values may vary by go-imap version, but they should be distinct + if imap.NotAuthenticatedState == imap.AuthenticatedState { + t.Error("NotAuthenticatedState and AuthenticatedState should be different") + } + if imap.AuthenticatedState == imap.SelectedState { + t.Error("AuthenticatedState and SelectedState should be different") + } + if imap.NotAuthenticatedState == imap.SelectedState { + t.Error("NotAuthenticatedState and SelectedState should be different") + } + + // Log the actual values for reference + t.Logf("IMAP state constants: NotAuthenticatedState=%d, AuthenticatedState=%d, SelectedState=%d", + imap.NotAuthenticatedState, imap.AuthenticatedState, imap.SelectedState) +} + +func TestPool_ConcurrentAccess(t *testing.T) { + // Set test mode to use non-TLS connections + err := os.Setenv("VMAIL_TEST_MODE", "true") + if err != nil { + t.Fatalf("Failed to set VMAIL_TEST_MODE: %v", err) + } + defer func() { + err := os.Unsetenv("VMAIL_TEST_MODE") + if err != nil { + t.Fatalf("Failed to unset VMAIL_TEST_MODE: %v", err) + } + }() + + server := testutil.NewTestIMAPServer(t) + defer server.Close() + + pool := NewPool() + defer pool.Close() + + t.Run("multiple goroutines creating clients simultaneously", func(t *testing.T) { + const numGoroutines = 5 + const userID = "simultaneous-create-user" + + results := make(chan error, numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + err := pool.WithClient(userID, server.Address, server.Username(), server.Password(), func(client IMAPClient) error { + // Client is automatically released when this function returns + return nil + }) + results <- err + }() + } + + // All should succeed without errors + for i := 0; i < numGoroutines; i++ { + if err := <-results; err != nil { + t.Errorf("WithClient failed: %v", err) + } + } + }) + + t.Run("remove client while another goroutine is using it", func(t *testing.T) { + const userID = "remove-while-using-user" + + // Use WithClient to get a client + done := make(chan bool, 1) + go func() { + err := pool.WithClient(userID, server.Address, server.Username(), server.Password(), func(client IMAPClient) error { + // Simulate using the client + _ = client + done <- true + return nil + }) + if err != nil { + t.Errorf("WithClient failed: %v", err) + } + }() + + // Remove the client while it might be in use + pool.RemoveClient(userID) + + // Wait for the goroutine to finish + <-done + // Should not panic + }) +} + +func TestPool_EdgeCases(t *testing.T) { + // Set test mode to use non-TLS connections + err := os.Setenv("VMAIL_TEST_MODE", "true") + if err != nil { + t.Fatalf("Failed to set VMAIL_TEST_MODE: %v", err) + } + defer func() { + err := os.Unsetenv("VMAIL_TEST_MODE") + if err != nil { + t.Fatalf("Failed to unset VMAIL_TEST_MODE: %v", err) + } + }() + + server := testutil.NewTestIMAPServer(t) + defer server.Close() + + t.Run("pool with many users", func(t *testing.T) { + pool := NewPool() + defer pool.Close() + + const numUsers = 100 + for i := 0; i < numUsers; i++ { + userID := fmt.Sprintf("user-%d", i) + err := pool.WithClient(userID, server.Address, server.Username(), server.Password(), func(client IMAPClient) error { + // Client is automatically released when this function returns + return nil + }) + if err != nil { + t.Errorf("Failed to get client for user %s: %v", userID, err) + } + } + + // Verify all users have clients + for i := 0; i < numUsers; i++ { + userID := fmt.Sprintf("user-%d", i) + pool.RemoveClient(userID) + // Should not panic + } + }) + + t.Run("close while clients are in use", func(t *testing.T) { + pool := NewPool() + + // Use WithClient to get a client + err := pool.WithClient("close-user", server.Address, server.Username(), server.Password(), func(client IMAPClient) error { + // Client is automatically released when this function returns + return nil + }) + if err != nil { + t.Fatalf("Failed to get client: %v", err) + } + + // Close while the client might be in use + pool.Close() + + // Should not panic + }) + + t.Run("remove client while in use", func(t *testing.T) { + pool := NewPool() + defer pool.Close() + + userID := "remove-in-use-user" + err := pool.WithClient(userID, server.Address, server.Username(), server.Password(), func(client IMAPClient) error { + // Client is automatically released when this function returns + return nil + }) + if err != nil { + t.Fatalf("Failed to get client: %v", err) + } + + // Remove while might be in use + pool.RemoveClient(userID) + // Should not panic + }) +} + +func TestPool_Close(t *testing.T) { + pool := NewPool() + + t.Run("closes all clients", func(t *testing.T) { + pool.Close() + // Should not panic + }) + + t.Run("can be called multiple times safely", func(t *testing.T) { + pool := NewPool() + pool.Close() + pool.Close() // Should not panic + }) +} diff --git a/backend/internal/imap/pool_worker.go b/backend/internal/imap/pool_worker.go new file mode 100644 index 0000000..8e26fc0 --- /dev/null +++ b/backend/internal/imap/pool_worker.go @@ -0,0 +1,182 @@ +package imap + +import ( + "fmt" + "os" + "time" + + "github.com/emersion/go-imap" +) + +// getOrCreateWorkerSet gets or creates a worker client set for a user. +// Thread-safe: uses double-check locking pattern. +func (p *Pool) getOrCreateWorkerSet(userID string) *workerClientSet { + // First check without lock + p.mu.RLock() + set, exists := p.workerSets[userID] + p.mu.RUnlock() + + if exists { + return set + } + + // Need to create - acquire write lock + p.mu.Lock() + defer p.mu.Unlock() + + // Double-check: another goroutine might have created it + if set, exists := p.workerSets[userID]; exists { + return set + } + + // Create new set + set = &workerClientSet{ + clients: make([]*threadSafeClient, 0), + semaphore: make(chan struct{}, p.maxWorkers), + } + p.workerSets[userID] = set + return set +} + +// getWorkerConnection gets or creates a worker client for a user. +// Returns a locked client and a release function that must be called when done. +// Thread-safe: uses double-check locking and proper synchronization. +func (p *Pool) getWorkerConnection(userID, server, username, password string) (*threadSafeClient, func(), error) { + set := p.getOrCreateWorkerSet(userID) + + // Try to acquire an existing client + tsClient, release := set.acquire() + if tsClient != nil { + // Client is already locked from acquire() + // Check if client is healthy + state := tsClient.GetClient().State() + if state == imap.AuthenticatedState || state == imap.SelectedState { + // Check if we need a health check + lastUsed := tsClient.GetLastUsed() + if time.Since(lastUsed) > healthCheckThreshold { + if !p.checkConnectionHealth(tsClient) { + // Client is dead, unlock and remove it + tsClient.Unlock() + release() + // Remove from the set and create a new one + p.removeDeadClient(set, tsClient) + // Fall through to create a new client + } else { + // Client is healthy, update timestamp + tsClient.UpdateLastUsed() + return tsClient, release, nil // Caller must call release() when done + } + } else { + // Client is healthy and recently used + tsClient.UpdateLastUsed() + return tsClient, release, nil // Caller must call release() when done + } + } else { + // Client is dead + tsClient.Unlock() + release() + p.removeDeadClient(set, tsClient) + // Fall through to create a new client + } + } + + // Need to create a new client + // Acquire semaphore slot + set.semaphore <- struct{}{} + + // Use a flag to track if we should release in defer + // We'll manually release on error paths, so defer should not release in those cases + shouldReleaseInDefer := true + defer func() { + if shouldReleaseInDefer { + <-set.semaphore + } + }() + + // Double-check: another goroutine might have created a client while we were waiting + set.mu.Lock() + for _, existingClient := range set.clients { + if existingClient.mu.TryLock() { + state := existingClient.GetClient().State() + if state == imap.AuthenticatedState || state == imap.SelectedState { + existingClient.UpdateLastUsed() + set.mu.Unlock() + // Return with release function + // Don't release in defer since we're returning a client + shouldReleaseInDefer = false + release := func() { + existingClient.Unlock() + <-set.semaphore + } + return existingClient, release, nil // Caller must call release() when done + } + existingClient.mu.Unlock() + } + } + set.mu.Unlock() + + // Create new client + useTLS := os.Getenv("VMAIL_TEST_MODE") != "true" + c, err := ConnectToIMAP(server, useTLS) + if err != nil { + shouldReleaseInDefer = false // Don't release in defer, we'll do it manually + <-set.semaphore // Release semaphore on error + return nil, nil, fmt.Errorf("failed to connect: %w", err) + } + + if err := Login(c, username, password); err != nil { + shouldReleaseInDefer = false // Don't release in defer, we'll do it manually + _ = c.Logout() + <-set.semaphore // Release semaphore on error + return nil, nil, fmt.Errorf("failed to login: %w", err) + } + + // Wrap in threadSafeClient + newClient := &threadSafeClient{ + client: c, + lastUsed: time.Now(), + role: roleWorker, + } + tsClient = newClient + + // Add to set + set.addClient(tsClient) + tsClient.Lock() // Lock before returning + + // Don't release in defer - the release function will handle it + shouldReleaseInDefer = false + // Create release function for the new client + newRelease := func() { + tsClient.Unlock() + <-set.semaphore + } + return tsClient, newRelease, nil +} + +// removeDeadClient removes a dead client from the set. +func (p *Pool) removeDeadClient(set *workerClientSet, client *threadSafeClient) { + set.mu.Lock() + defer set.mu.Unlock() + + for i, c := range set.clients { + if c == client { + // Remove from slice + set.clients = append(set.clients[:i], set.clients[i+1:]...) + // Close client + client.Lock() + _ = client.client.Logout() + client.Unlock() + break + } + } +} + +// checkConnectionHealth performs a NOOP command to check if client is alive. +// The client must be locked before calling this. +func (p *Pool) checkConnectionHealth(client *threadSafeClient) bool { + // The caller has already locked the client + if err := client.client.Noop(); err != nil { + return false + } + return true +} diff --git a/backend/internal/imap/search.go b/backend/internal/imap/search.go index f8cf515..069edbf 100644 --- a/backend/internal/imap/search.go +++ b/backend/internal/imap/search.go @@ -10,11 +10,16 @@ import ( "time" "github.com/emersion/go-imap" + imapclient "github.com/emersion/go-imap/client" "github.com/vdavid/vmail/backend/internal/db" "github.com/vdavid/vmail/backend/internal/models" ) -// parseHeaderFilter processes from:, to:, or subject: filters. +// ErrInvalidSearchQuery is returned when a search query cannot be parsed. +var ErrInvalidSearchQuery = errors.New("invalid search query") + +// parseHeaderFilter processes header filters (from:, to:, subject:). +// Returns (handled, error) where handled indicates if the token matched this filter type. func parseHeaderFilter(token, prefix, headerName string, criteria *imap.SearchCriteria) (bool, error) { if !strings.HasPrefix(token, prefix) { return false, nil @@ -27,7 +32,9 @@ func parseHeaderFilter(token, prefix, headerName string, criteria *imap.SearchCr return true, nil } -// parseDateFilter processes after: or before: filters. +// parseDateFilter processes date filters (after:, before:). +// Returns (handled, error) where handled indicates if the token matched this filter type. +// For before: filters, sets the time to end of day (23:59:59.999999999). func parseDateFilter(token, prefix string, criteria *imap.SearchCriteria) (bool, error) { if !strings.HasPrefix(token, prefix) { return false, nil @@ -51,6 +58,8 @@ func parseDateFilter(token, prefix string, criteria *imap.SearchCriteria) (bool, } // parseFolderFilter processes folder: or label: filters. +// Only the first folder: or label: filter is extracted; subsequent ones are ignored. +// Returns (handled, folder, error) where handled indicates if the token matched this filter type. func parseFolderFilter(token string, folderFound *bool) (bool, string, error) { if !strings.HasPrefix(token, "folder:") && !strings.HasPrefix(token, "label:") { return false, "", nil @@ -170,6 +179,8 @@ func ParseSearchQuery(query string) (*imap.SearchCriteria, string, error) { } // tokenizeQuery splits a query into tokens, respecting quoted strings. +// Handles quoted strings (e.g., "John Doe") and combines filter prefixes with quoted values +// (e.g., from:"John Doe" becomes a single token). func tokenizeQuery(query string) []string { var tokens []string var current strings.Builder @@ -273,6 +284,8 @@ func parseFolderFromQuery(query string) (string, string) { } // buildThreadMapFromMessages processes IMAP messages and builds a map of threads. +// Returns a map from stable thread ID to thread, and a map from stable thread ID to latest sent_at time. +// Messages without Message-ID headers or not found in the database are skipped with warnings. func (s *Service) buildThreadMapFromMessages(ctx context.Context, userID string, messages []*imap.Message) (map[string]*models.Thread, map[string]*time.Time, error) { threadMap := make(map[string]*models.Thread) threadToLatestSentAt := make(map[string]*time.Time) @@ -285,7 +298,7 @@ func (s *Service) buildThreadMapFromMessages(ctx context.Context, userID string, messageID := imapMsg.Envelope.MessageId - msg, err := db.GetMessageByMessageID(ctx, s.pool, userID, messageID) + msg, err := db.GetMessageByMessageID(ctx, s.dbPool, userID, messageID) if err != nil { if errors.Is(err, db.ErrMessageNotFound) { log.Printf("Warning: Message with Message-ID %s not found in DB, skipping", messageID) @@ -294,7 +307,7 @@ func (s *Service) buildThreadMapFromMessages(ctx context.Context, userID string, return nil, nil, fmt.Errorf("failed to get message from DB: %w", err) } - thread, err := db.GetThreadByID(ctx, s.pool, msg.ThreadID) + thread, err := db.GetThreadByID(ctx, s.dbPool, msg.ThreadID) if err != nil { log.Printf("Warning: Failed to get thread %s: %v", msg.ThreadID, err) continue @@ -315,7 +328,8 @@ func (s *Service) buildThreadMapFromMessages(ctx context.Context, userID string, return threadMap, threadToLatestSentAt, nil } -// sortAndPaginateThreads sorts threads by latest sent_at and applies pagination. +// sortAndPaginateThreads sorts threads by latest sent_at (newest first) and applies pagination. +// Threads without sent_at are sorted to the end. Returns the paginated threads and total count. func sortAndPaginateThreads(threadMap map[string]*models.Thread, threadToLatestSentAt map[string]*time.Time, page, limit int) ([]*models.Thread, int) { threads := make([]*models.Thread, 0, len(threadMap)) for _, thread := range threadMap { @@ -350,13 +364,16 @@ func sortAndPaginateThreads(threadMap map[string]*models.Thread, threadToLatestS } // Search searches for threads matching the query in the specified folder. -// Supports Gmail-like syntax via ParseSearchQuery. -// Returns threads, total count, and error. +// Supports Gmail-like syntax via ParseSearchQuery (from:, to:, subject:, after:, before:, folder:, label:). +// If no folder is specified in the query, defaults to INBOX. +// Returns threads sorted by latest sent_at (newest first), total count, and error. +// Note: Error handling tests for getClientAndSelectFolder, UidSearch, and FetchMessageHeaders +// require complex IMAP server mocking and are covered through integration tests. func (s *Service) Search(ctx context.Context, userID string, query string, page, limit int) ([]*models.Thread, int, error) { // Parse the query using Gmail-like syntax criteria, extractedFolder, err := ParseSearchQuery(query) if err != nil { - return nil, 0, fmt.Errorf("invalid search query: %w", err) + return nil, 0, fmt.Errorf("%w: %v", ErrInvalidSearchQuery, err) } // Use extracted folder or default to INBOX @@ -365,36 +382,44 @@ func (s *Service) Search(ctx context.Context, userID string, query string, page, folder = "INBOX" } - client, _, err := s.getClientAndSelectFolder(ctx, userID, folder) - if err != nil { - return nil, 0, fmt.Errorf("failed to get IMAP client: %w", err) - } + var threads []*models.Thread + var totalCount int - uids, err := client.UidSearch(criteria) - if err != nil { - return nil, 0, fmt.Errorf("failed to search IMAP: %w", err) - } + err = s.withClientAndSelectFolder(ctx, userID, folder, func(client *imapclient.Client, _ *imap.MailboxStatus) error { + uids, err := client.UidSearch(criteria) + if err != nil { + return fmt.Errorf("failed to search IMAP: %w", err) + } - if len(uids) == 0 { - return nil, 0, nil - } + if len(uids) == 0 { + threads = nil + totalCount = 0 + return nil + } - messages, err := FetchMessageHeaders(client, uids) - if err != nil { - return nil, 0, fmt.Errorf("failed to fetch message headers: %w", err) - } + messages, err := FetchMessageHeaders(client, uids) + if err != nil { + return fmt.Errorf("failed to fetch message headers: %w", err) + } - threadMap, threadToLatestSentAt, err := s.buildThreadMapFromMessages(ctx, userID, messages) - if err != nil { - return nil, 0, err - } + threadMap, threadToLatestSentAt, err := s.buildThreadMapFromMessages(ctx, userID, messages) + if err != nil { + return err + } - threads, totalCount := sortAndPaginateThreads(threadMap, threadToLatestSentAt, page, limit) + threads, totalCount = sortAndPaginateThreads(threadMap, threadToLatestSentAt, page, limit) - // Enrich threads with first message's from_address for display - if err := db.EnrichThreadsWithFirstMessageFromAddress(ctx, s.pool, threads); err != nil { - log.Printf("Warning: Failed to enrich threads with first message from address: %v", err) - // Continue anyway - threads will work without the from_address + // Enrich threads with first message's from_address for display + if err := db.EnrichThreadsWithFirstMessageFromAddress(ctx, s.dbPool, threads); err != nil { + log.Printf("Warning: Failed to enrich threads with first message from address: %v", err) + // Continue anyway - threads will work without the from_address + } + + return nil + }) + + if err != nil { + return nil, 0, fmt.Errorf("failed to get IMAP client: %w", err) } return threads, totalCount, nil diff --git a/backend/internal/imap/search_test.go b/backend/internal/imap/search_test.go index 412bc68..fa2c0e0 100644 --- a/backend/internal/imap/search_test.go +++ b/backend/internal/imap/search_test.go @@ -1,11 +1,17 @@ package imap import ( + "context" + "encoding/base64" "strings" "testing" "time" + "github.com/emersion/go-imap" + "github.com/vdavid/vmail/backend/internal/crypto" + "github.com/vdavid/vmail/backend/internal/db" "github.com/vdavid/vmail/backend/internal/models" + "github.com/vdavid/vmail/backend/internal/testutil" ) func TestParseFolderFromQuery(t *testing.T) { @@ -294,3 +300,246 @@ func TestSortAndPaginateThreads(t *testing.T) { } }) } + +func TestTokenizeQuery(t *testing.T) { + t.Run("handles unclosed quotes", func(t *testing.T) { + // Unclosed quote should treat the rest as part of the token + tokens := tokenizeQuery(`from:"John Doe`) + // The tokenizer should handle this gracefully - the quote starts but never closes + // So "John Doe" (without closing quote) should be part of the token + if len(tokens) == 0 { + t.Error("Expected at least one token for unclosed quote") + } + // Verify the behavior: the unclosed quote should be included in the token + found := false + for _, token := range tokens { + if strings.Contains(token, "John Doe") { + found = true + } + } + if !found { + t.Errorf("Expected token to contain 'John Doe', got tokens: %v", tokens) + } + }) + + t.Run("handles empty quoted strings", func(t *testing.T) { + tokens := tokenizeQuery(`from:"" test`) + // Empty quoted strings are skipped (not tokenized) - this is the current behavior + // The tokenizer processes from: and test, skipping the empty quotes + if len(tokens) != 2 { + t.Errorf("Expected 2 tokens (from: and test), got %d: %v", len(tokens), tokens) + } + if tokens[0] != "from:" { + t.Errorf("Expected first token 'from:', got '%s'", tokens[0]) + } + if tokens[1] != "test" { + t.Errorf("Expected second token 'test', got '%s'", tokens[1]) + } + }) + + t.Run("handles multiple spaces between tokens", func(t *testing.T) { + tokens := tokenizeQuery("from:george to:alice") + // Multiple spaces should be collapsed (treated as single separator) + if len(tokens) != 2 { + t.Errorf("Expected 2 tokens, got %d: %v", len(tokens), tokens) + } + if tokens[0] != "from:george" { + t.Errorf("Expected first token 'from:george', got '%s'", tokens[0]) + } + if tokens[1] != "to:alice" { + t.Errorf("Expected second token 'to:alice', got '%s'", tokens[1]) + } + }) + + t.Run("handles nested quotes (quotes inside quotes)", func(t *testing.T) { + // The current implementation doesn't handle escaped quotes, but we test the behavior + tokens := tokenizeQuery(`from:"John "Doe" Smith"`) + // The tokenizer treats each quote as a toggle, so nested quotes will be tokenized + // This is expected behavior - the tokenizer doesn't handle escaped quotes + if len(tokens) == 0 { + t.Error("Expected at least one token for nested quotes") + } + }) + + t.Run("handles quoted strings with spaces", func(t *testing.T) { + tokens := tokenizeQuery(`from:"John Doe" test`) + if len(tokens) != 2 { + t.Errorf("Expected 2 tokens, got %d: %v", len(tokens), tokens) + } + // The quoted string should be combined with the prefix if applicable + // Check that "John Doe" is in one of the tokens + found := false + for _, token := range tokens { + if strings.Contains(token, "John Doe") { + found = true + } + } + if !found { + t.Errorf("Expected token to contain 'John Doe', got tokens: %v", tokens) + } + }) + + t.Run("handles filter prefix with quoted value", func(t *testing.T) { + tokens := tokenizeQuery(`from: "John Doe"`) + // The tokenizer should combine "from:" with the following quoted string + if len(tokens) == 0 { + t.Error("Expected at least one token") + } + // Check that from: and "John Doe" are combined + found := false + for _, token := range tokens { + if strings.Contains(token, "from:") && strings.Contains(token, "John Doe") { + found = true + } + } + if !found { + t.Errorf("Expected 'from:' and 'John Doe' to be combined, got tokens: %v", tokens) + } + }) +} + +func TestService_buildThreadMapFromMessages(t *testing.T) { + pool := testutil.NewTestDB(t) + defer pool.Close() + + encryptor := getTestEncryptorForSearch(t) + service := NewService(pool, NewPool(), encryptor) + defer service.Close() + + ctx := context.Background() + userID, err := db.GetOrCreateUser(ctx, pool, "build-thread-test@example.com") + if err != nil { + t.Fatalf("Failed to create user: %v", err) + } + + t.Run("returns error when GetMessageByMessageID returns non-NotFound error", func(t *testing.T) { + // Create a cancelled context to simulate a database error + cancelledCtx, cancel := context.WithCancel(ctx) + cancel() // Cancel immediately to cause context error + + imapMsg := &imap.Message{ + Uid: 1, + Envelope: &imap.Envelope{ + MessageId: "", + }, + } + + _, _, err := service.buildThreadMapFromMessages(cancelledCtx, userID, []*imap.Message{imapMsg}) + if err == nil { + t.Error("Expected error when GetMessageByMessageID returns non-NotFound error") + } + if !strings.Contains(err.Error(), "failed to get message from DB") { + t.Errorf("Expected error message about 'failed to get message from DB', got: %v", err) + } + }) + + t.Run("continues gracefully when GetThreadByID returns error", func(t *testing.T) { + // Create a thread and message + messageID := "" + thread := &models.Thread{ + UserID: userID, + StableThreadID: messageID, + Subject: "Test Thread", + } + if err := db.SaveThread(ctx, pool, thread); err != nil { + t.Fatalf("Failed to save thread: %v", err) + } + + // Create a message linked to this thread + message := &models.Message{ + ThreadID: thread.ID, + UserID: userID, + IMAPUID: 1, + IMAPFolderName: "INBOX", + MessageIDHeader: messageID, + FromAddress: "from@example.com", + Subject: "Test Subject", + } + if err := db.SaveMessage(ctx, pool, message); err != nil { + t.Fatalf("Failed to save message: %v", err) + } + + // Delete the thread to simulate GetThreadByID returning an error + _, err := pool.Exec(ctx, "DELETE FROM threads WHERE id = $1", thread.ID) + if err != nil { + t.Fatalf("Failed to delete thread: %v", err) + } + + // Now buildThreadMapFromMessages should skip this message and continue + imapMsg := &imap.Message{ + Uid: 1, + Envelope: &imap.Envelope{ + MessageId: messageID, + }, + } + + threadMap, sentAtMap, err := service.buildThreadMapFromMessages(ctx, userID, []*imap.Message{imapMsg}) + if err != nil { + t.Errorf("Expected no error (should skip message with missing thread), got: %v", err) + } + // The thread should not be in the map because GetThreadByID failed + if len(threadMap) != 0 { + t.Errorf("Expected empty thread map (thread was deleted), got: %v", threadMap) + } + if len(sentAtMap) != 0 { + t.Errorf("Expected empty sentAt map, got: %v", sentAtMap) + } + }) + + t.Run("skips messages not found in database", func(t *testing.T) { + // Message that doesn't exist in DB + imapMsg := &imap.Message{ + Uid: 999, + Envelope: &imap.Envelope{ + MessageId: "", + }, + } + + threadMap, sentAtMap, err := service.buildThreadMapFromMessages(ctx, userID, []*imap.Message{imapMsg}) + if err != nil { + t.Errorf("Expected no error (should skip message not found), got: %v", err) + } + if len(threadMap) != 0 { + t.Errorf("Expected empty thread map, got: %v", threadMap) + } + if len(sentAtMap) != 0 { + t.Errorf("Expected empty sentAt map, got: %v", sentAtMap) + } + }) + + t.Run("skips messages without Message-ID", func(t *testing.T) { + imapMsg := &imap.Message{ + Uid: 1, + Envelope: &imap.Envelope{ + // No MessageId + }, + } + + threadMap, sentAtMap, err := service.buildThreadMapFromMessages(ctx, userID, []*imap.Message{imapMsg}) + if err != nil { + t.Errorf("Expected no error (should skip message without Message-ID), got: %v", err) + } + if len(threadMap) != 0 { + t.Errorf("Expected empty thread map, got: %v", threadMap) + } + if len(sentAtMap) != 0 { + t.Errorf("Expected empty sentAt map, got: %v", sentAtMap) + } + }) +} + +func getTestEncryptorForSearch(t *testing.T) *crypto.Encryptor { + t.Helper() + + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + base64Key := base64.StdEncoding.EncodeToString(key) + + encryptor, err := crypto.NewEncryptor(base64Key) + if err != nil { + t.Fatalf("Failed to create encryptor: %v", err) + } + return encryptor +} diff --git a/backend/internal/imap/service.go b/backend/internal/imap/service.go index dd8ce25..2ca7da6 100644 --- a/backend/internal/imap/service.go +++ b/backend/internal/imap/service.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log" + "os" "time" "github.com/emersion/go-imap" @@ -17,26 +18,29 @@ import ( ) // Service handles IMAP operations and caching. +// The IMAP pool is injected so that a single shared pool can be used across +// handlers and services, ensuring per-user connection limits are enforced +// consistently. type Service struct { - pool *pgxpool.Pool - clientPool *Pool - encryptor *crypto.Encryptor - cacheTTL time.Duration + dbPool *pgxpool.Pool + imapPool IMAPPool + encryptor *crypto.Encryptor + cacheTTL time.Duration } // NewService creates a new IMAP service. -func NewService(pool *pgxpool.Pool, encryptor *crypto.Encryptor) *Service { +func NewService(dbPool *pgxpool.Pool, imapPool IMAPPool, encryptor *crypto.Encryptor) *Service { return &Service{ - pool: pool, - clientPool: NewPool(), - encryptor: encryptor, - cacheTTL: 5 * time.Minute, // Default cache TTL + dbPool: dbPool, + imapPool: imapPool, + encryptor: encryptor, + cacheTTL: 5 * time.Minute, // Default cache TTL } } // getSettingsAndPassword gets user settings and decrypts the IMAP password. func (s *Service) getSettingsAndPassword(ctx context.Context, userID string) (*models.UserSettings, string, error) { - settings, err := db.GetUserSettings(ctx, s.pool, userID) + settings, err := db.GetUserSettings(ctx, s.dbPool, userID) if err != nil { return nil, "", fmt.Errorf("failed to get user settings: %w", err) } @@ -49,27 +53,34 @@ func (s *Service) getSettingsAndPassword(ctx context.Context, userID string) (*m return settings, imapPassword, nil } -// getClientAndSelectFolder gets user settings, decrypts the password, gets the IMAP client, and selects the folder. -// Returns the client and mailbox status, or an error. -func (s *Service) getClientAndSelectFolder(ctx context.Context, userID, folderName string) (*imapclient.Client, *imap.MailboxStatus, error) { +// withClientAndSelectFolder gets user settings, gets an IMAP client, selects the folder, and calls the callback. +// The client is automatically released when the callback returns. +// Thread-safe: The connection is locked during folder selection to prevent concurrent folder selections +// from interfering with each other. +func (s *Service) withClientAndSelectFolder(ctx context.Context, userID, folderName string, fn func(*imapclient.Client, *imap.MailboxStatus) error) error { settings, imapPassword, err := s.getSettingsAndPassword(ctx, userID) if err != nil { - return nil, nil, err + return err } - // Get IMAP client (internal use - need concrete type) - client, err := s.clientPool.getClientConcrete(userID, settings.IMAPServerHostname, settings.IMAPUsername, imapPassword) - if err != nil { - return nil, nil, fmt.Errorf("failed to get IMAP client: %w", err) - } + // Use WithClient to ensure the client is always released + return s.imapPool.WithClient(userID, settings.IMAPServerHostname, settings.IMAPUsername, imapPassword, func(clientIface IMAPClient) error { + wrapper, ok := clientIface.(*ClientWrapper) + if !ok || wrapper.client == nil { + return fmt.Errorf("failed to unwrap IMAP client") + } + client := wrapper.client - // Select the folder - mbox, err := client.Select(folderName, false) - if err != nil { - return nil, nil, fmt.Errorf("failed to select folder %s: %w", folderName, err) - } + // Select the folder - connection is locked, so this is thread-safe + // Even if multiple goroutines call this concurrently, they will use different connections + // from the pool, or the same connection will be serialized by the lock + mbox, err := client.Select(folderName, false) + if err != nil { + return fmt.Errorf("failed to select folder %s: %w", folderName, err) + } - return client, mbox, nil + return fn(client, mbox) + }) } // threadMaps contains the maps needed for thread processing. @@ -156,7 +167,7 @@ func getStableThreadID(rootUID uint32, rootUIDToStableID map[uint32]string, uidT // getOrCreateThread gets an existing thread or creates a new one. func (s *Service) getOrCreateThread(ctx context.Context, userID, stableThreadID string, rootUID uint32, uidToMessageMap map[uint32]*imap.Message) (*models.Thread, error) { - threadModel, err := db.GetThreadByStableID(ctx, s.pool, userID, stableThreadID) + threadModel, err := db.GetThreadByStableID(ctx, s.dbPool, userID, stableThreadID) if err != nil { if !errors.Is(err, db.ErrThreadNotFound) { return nil, fmt.Errorf("failed to get thread: %w", err) @@ -175,7 +186,7 @@ func (s *Service) getOrCreateThread(ctx context.Context, userID, stableThreadID Subject: subject, } - if err := db.SaveThread(ctx, s.pool, threadModel); err != nil { + if err := db.SaveThread(ctx, s.dbPool, threadModel); err != nil { return nil, fmt.Errorf("failed to save thread: %w", err) } } @@ -196,7 +207,7 @@ func (s *Service) processMessage(ctx context.Context, imapMsg *imap.Message, roo return nil // Continue processing other messages } - if err := db.SaveMessage(ctx, s.pool, msg); err != nil { + if err := db.SaveMessage(ctx, s.dbPool, msg); err != nil { return fmt.Errorf("failed to save message: %w", err) } @@ -229,7 +240,7 @@ func (s *Service) tryIncrementalSync(ctx context.Context, client *imapclient.Cli if len(newUIDs) == 0 { log.Printf("No new messages to sync") // Update sync timestamp even though there's nothing new - if err := db.SetFolderSyncInfo(ctx, s.pool, userID, folderName, syncInfo.LastSyncedUID); err != nil { + if err := db.SetFolderSyncInfo(ctx, s.dbPool, userID, folderName, syncInfo.LastSyncedUID); err != nil { log.Printf("Warning: Failed to update folder sync timestamp: %v", err) } // Trigger background thread count update @@ -263,13 +274,22 @@ type fullSyncResult struct { } // performFullSync performs a full sync of all threads in the folder. -// Falls back to fetching all UIDs using SEARCH if THREAD command is not supported. +// In non-test environments, the IMAP server is required to support the THREAD +// extension (RFC 5256). In test mode (VMAIL_TEST_MODE=true), if THREAD is not +// supported we fall back to fetching all UIDs using SEARCH so that E2E tests +// can run against the in-memory IMAP server. func (s *Service) performFullSync(ctx context.Context, client *imapclient.Client, userID, folderName string) (fullSyncResult, error) { log.Printf("Full sync: fetching all threads") threads, err := RunThreadCommand(client) if err != nil { - // THREAD command not supported (e.g., by test IMAP server) - fall back to SEARCH - log.Printf("THREAD command not supported, falling back to SEARCH: %v", err) + // In non-test environments, missing THREAD support is a hard error. + if os.Getenv("VMAIL_TEST_MODE") != "true" { + return fullSyncResult{}, fmt.Errorf("IMAP server must support THREAD extension (RFC 5256): %w", err) + } + + // In test mode (used by E2E tests), THREAD is not supported by the + // in-memory test IMAP server, so we fall back to SEARCH. + log.Printf("THREAD command not supported in test mode, falling back to SEARCH, which is okay.") // Fetch all UIDs using SEARCH (starting from UID 1) uidsToSync, err := SearchUIDsSince(client, 1) if err != nil { @@ -279,7 +299,7 @@ func (s *Service) performFullSync(ctx context.Context, client *imapclient.Client if len(uidsToSync) == 0 { log.Printf("No messages found in folder %s", folderName) // Still update sync info - if err := db.SetFolderSyncInfo(ctx, s.pool, userID, folderName, nil); err != nil { + if err := db.SetFolderSyncInfo(ctx, s.dbPool, userID, folderName, nil); err != nil { log.Printf("Warning: Failed to set folder sync info: %v", err) } return fullSyncResult{shouldReturn: true}, nil @@ -310,7 +330,7 @@ func (s *Service) performFullSync(ctx context.Context, client *imapclient.Client if len(uidsToSync) == 0 { log.Printf("No messages found in folder %s", folderName) // Still update sync info - if err := db.SetFolderSyncInfo(ctx, s.pool, userID, folderName, nil); err != nil { + if err := db.SetFolderSyncInfo(ctx, s.dbPool, userID, folderName, nil); err != nil { log.Printf("Warning: Failed to set folder sync info: %v", err) } return fullSyncResult{shouldReturn: true}, nil @@ -371,84 +391,81 @@ func (s *Service) processFullSyncMessages(ctx context.Context, messages []*imap. // SyncThreadsForFolder syncs threads from IMAP for a specific folder. // Uses incremental sync if possible (only syncs new messages since last sync). func (s *Service) SyncThreadsForFolder(ctx context.Context, userID, folderName string) error { - client, mbox, err := s.getClientAndSelectFolder(ctx, userID, folderName) - if err != nil { - return err - } + return s.withClientAndSelectFolder(ctx, userID, folderName, func(client *imapclient.Client, mbox *imap.MailboxStatus) error { + log.Printf("Selected folder %s: %d messages", folderName, mbox.Messages) - log.Printf("Selected folder %s: %d messages", folderName, mbox.Messages) + // Check if we can do incremental sync + syncInfo, err := db.GetFolderSyncInfo(ctx, s.dbPool, userID, folderName) + if err != nil { + log.Printf("Warning: Failed to get folder sync info: %v", err) + syncInfo = nil // Fall back to full sync + } - // Check if we can do incremental sync - syncInfo, err := db.GetFolderSyncInfo(ctx, s.pool, userID, folderName) - if err != nil { - log.Printf("Warning: Failed to get folder sync info: %v", err) - syncInfo = nil // Fall back to full sync - } + // Try incremental sync first + incResult, isIncremental := s.tryIncrementalSync(ctx, client, userID, folderName, syncInfo) + if isIncremental { + if incResult.shouldReturn { + return nil + } + // Incremental sync path: process messages without thread structure + messages, err := FetchMessageHeaders(client, incResult.uidsToSync) + if err != nil { + return fmt.Errorf("failed to fetch message headers: %w", err) + } + log.Printf("Fetched %d message headers", len(messages)) + s.processIncrementalMessages(ctx, messages, userID, folderName) - // Try incremental sync first - incResult, isIncremental := s.tryIncrementalSync(ctx, client, userID, folderName, syncInfo) - if isIncremental { - if incResult.shouldReturn { + // Update sync info with the highest UID + highestUIDInt64 := int64(incResult.highestUID) + if err := db.SetFolderSyncInfo(ctx, s.dbPool, userID, folderName, &highestUIDInt64); err != nil { + log.Printf("Warning: Failed to set folder sync info: %v", err) + } + go s.updateThreadCountInBackground(userID, folderName) + return nil + } + + // Full sync path: get thread structure first + fullResult, err := s.performFullSync(ctx, client, userID, folderName) + if err != nil { + return err + } + if fullResult.shouldReturn { return nil } - // Incremental sync path: process messages without thread structure - messages, err := FetchMessageHeaders(client, incResult.uidsToSync) + + // Fetch message headers for UIDs we need to sync + messages, err := FetchMessageHeaders(client, fullResult.uidsToSync) if err != nil { return fmt.Errorf("failed to fetch message headers: %w", err) } + log.Printf("Fetched %d message headers", len(messages)) - s.processIncrementalMessages(ctx, messages, userID, folderName) + + // Process messages: use thread structure if available, otherwise use incremental processing + threadMaps := fullResult.threadMaps + if threadMaps == nil { + // THREAD command not supported - process messages without thread structure + // (same as incremental sync) + s.processIncrementalMessages(ctx, messages, userID, folderName) + } else { + // Process messages using thread structure + if err := s.processFullSyncMessages(ctx, messages, threadMaps, userID, folderName); err != nil { + return err + } + } // Update sync info with the highest UID - highestUIDInt64 := int64(incResult.highestUID) - if err := db.SetFolderSyncInfo(ctx, s.pool, userID, folderName, &highestUIDInt64); err != nil { + highestUIDInt64 := int64(fullResult.highestUID) + if err := db.SetFolderSyncInfo(ctx, s.dbPool, userID, folderName, &highestUIDInt64); err != nil { log.Printf("Warning: Failed to set folder sync info: %v", err) + // Don't fail the entire sync if timestamp update fails } + + // Trigger background thread count update go s.updateThreadCountInBackground(userID, folderName) - return nil - } - // Full sync path: get thread structure first - fullResult, err := s.performFullSync(ctx, client, userID, folderName) - if err != nil { - return err - } - if fullResult.shouldReturn { return nil - } - - // Fetch message headers for UIDs we need to sync - messages, err := FetchMessageHeaders(client, fullResult.uidsToSync) - if err != nil { - return fmt.Errorf("failed to fetch message headers: %w", err) - } - - log.Printf("Fetched %d message headers", len(messages)) - - // Process messages: use thread structure if available, otherwise use incremental processing - threadMaps := fullResult.threadMaps - if threadMaps == nil { - // THREAD command not supported - process messages without thread structure - // (same as incremental sync) - s.processIncrementalMessages(ctx, messages, userID, folderName) - } else { - // Process messages using thread structure - if err := s.processFullSyncMessages(ctx, messages, threadMaps, userID, folderName); err != nil { - return err - } - } - - // Update sync info with the highest UID - highestUIDInt64 := int64(fullResult.highestUID) - if err := db.SetFolderSyncInfo(ctx, s.pool, userID, folderName, &highestUIDInt64); err != nil { - log.Printf("Warning: Failed to set folder sync info: %v", err) - // Don't fail the entire sync if timestamp update fails - } - - // Trigger background thread count update - go s.updateThreadCountInBackground(userID, folderName) - - return nil + }) } // processIncrementalMessage processes a single message during incremental sync. @@ -471,17 +488,17 @@ func (s *Service) processIncrementalMessage(ctx context.Context, imapMsg *imap.M // Note: This is a simplification - full sync will correct threading using THREAD command // First, try to find the thread by Message-ID (this works for root messages) - threadModel, err := db.GetThreadByStableID(ctx, s.pool, userID, messageID) + threadModel, err := db.GetThreadByStableID(ctx, s.dbPool, userID, messageID) if err != nil { if !errors.Is(err, db.ErrThreadNotFound) { return fmt.Errorf("failed to get thread: %w", err) } // Thread was not found - check if this message already exists (might be a reply to an existing thread) - existingMsg, err := db.GetMessageByMessageID(ctx, s.pool, userID, messageID) + existingMsg, err := db.GetMessageByMessageID(ctx, s.dbPool, userID, messageID) if err == nil && existingMsg != nil { // Message already exists, get its thread - threadModel, err = db.GetThreadByID(ctx, s.pool, existingMsg.ThreadID) + threadModel, err = db.GetThreadByID(ctx, s.dbPool, existingMsg.ThreadID) if err != nil { return fmt.Errorf("failed to get existing message's thread: %w", err) } @@ -497,7 +514,7 @@ func (s *Service) processIncrementalMessage(ctx context.Context, imapMsg *imap.M if imapMsg.Envelope != nil { threadModel.Subject = imapMsg.Envelope.Subject } - if err := db.SaveThread(ctx, s.pool, threadModel); err != nil { + if err := db.SaveThread(ctx, s.dbPool, threadModel); err != nil { return fmt.Errorf("failed to save thread: %w", err) } } @@ -509,7 +526,7 @@ func (s *Service) processIncrementalMessage(ctx context.Context, imapMsg *imap.M return fmt.Errorf("failed to parse message: %w", err) } - if err := db.SaveMessage(ctx, s.pool, msg); err != nil { + if err := db.SaveMessage(ctx, s.dbPool, msg); err != nil { return fmt.Errorf("failed to save message: %w", err) } @@ -517,12 +534,13 @@ func (s *Service) processIncrementalMessage(ctx context.Context, imapMsg *imap.M } // updateThreadCountInBackground updates the thread count in the background. +// Uses a 30-second timeout to avoid hanging indefinitely. func (s *Service) updateThreadCountInBackground(userID, folderName string) { // Use a new context with timeout to avoid hanging bgCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - if err := db.UpdateThreadCount(bgCtx, s.pool, userID, folderName); err != nil { + if err := db.UpdateThreadCount(bgCtx, s.dbPool, userID, folderName); err != nil { log.Printf("Warning: Failed to update thread count in background for folder %s: %v", folderName, err) } else { log.Printf("Updated thread count for folder %s", folderName) @@ -531,16 +549,15 @@ func (s *Service) updateThreadCountInBackground(userID, folderName string) { // SyncFullMessage syncs the full message body from IMAP. func (s *Service) SyncFullMessage(ctx context.Context, userID, folderName string, imapUID int64) error { - client, _, err := s.getClientAndSelectFolder(ctx, userID, folderName) - if err != nil { - return err - } - - return s.syncSingleMessage(ctx, client, userID, folderName, imapUID) + return s.withClientAndSelectFolder(ctx, userID, folderName, func(client *imapclient.Client, _ *imap.MailboxStatus) error { + return s.syncSingleMessage(ctx, client, userID, folderName, imapUID) + }) } // SyncFullMessages syncs multiple message bodies from IMAP in a batch. // It groups messages by folder and syncs them efficiently to reduce network calls. +// Thread-safe: Each folder selection uses a locked connection from the pool, ensuring +// that concurrent syncs for the same user use different connections or are serialized. func (s *Service) SyncFullMessages(ctx context.Context, userID string, messages []MessageToSync) error { if len(messages) == 0 { return nil @@ -560,25 +577,36 @@ func (s *Service) SyncFullMessages(ctx context.Context, userID string, messages // Sync messages grouped by folder for folderName, uids := range folderToUIDs { - // Get IMAP client (internal use - need concrete type) - client, err := s.clientPool.getClientConcrete(userID, settings.IMAPServerHostname, settings.IMAPUsername, imapPassword) - if err != nil { - log.Printf("Warning: Failed to get IMAP client for folder %s: %v", folderName, err) - continue - } + // Use WithClient to ensure the client is always released + err := s.imapPool.WithClient(userID, settings.IMAPServerHostname, settings.IMAPUsername, imapPassword, func(clientIface IMAPClient) error { + wrapper, ok := clientIface.(*ClientWrapper) + if !ok || wrapper.client == nil { + log.Printf("Warning: Failed to unwrap IMAP client for folder %s", folderName) + return nil // Continue with next folder + } - // Select the folder once for all messages in this folder - if _, err := client.Select(folderName, false); err != nil { - log.Printf("Warning: Failed to select folder %s: %v", folderName, err) - continue - } + client := wrapper.client - // Sync each message in this folder - for _, imapUID := range uids { - if err := s.syncSingleMessage(ctx, client, userID, folderName, imapUID); err != nil { - log.Printf("Warning: Failed to sync message UID %d in folder %s: %v", imapUID, folderName, err) - // Continue with other messages + // Select the folder once for all messages in this folder + if _, err := client.Select(folderName, false); err != nil { + log.Printf("Warning: Failed to select folder %s: %v", folderName, err) + return nil // Continue with next folder } + + // Sync each message in this folder + for _, imapUID := range uids { + if err := s.syncSingleMessage(ctx, client, userID, folderName, imapUID); err != nil { + log.Printf("Warning: Failed to sync message UID %d in folder %s: %v", imapUID, folderName, err) + // Continue with other messages + } + } + + return nil + }) + + if err != nil { + log.Printf("Warning: Failed to get IMAP client for folder %s: %v", folderName, err) + // Continue with next folder } } @@ -594,7 +622,7 @@ func (s *Service) syncSingleMessage(ctx context.Context, client *imapclient.Clie } // Get existing message from DB - msg, err := db.GetMessageByUID(ctx, s.pool, userID, folderName, imapUID) + msg, err := db.GetMessageByUID(ctx, s.dbPool, userID, folderName, imapUID) if err != nil { return fmt.Errorf("failed to get message from DB: %w", err) } @@ -610,14 +638,14 @@ func (s *Service) syncSingleMessage(ctx context.Context, client *imapclient.Clie msg.BodyText = parsedMsg.BodyText // Save message with body - if err := db.SaveMessage(ctx, s.pool, msg); err != nil { + if err := db.SaveMessage(ctx, s.dbPool, msg); err != nil { return fmt.Errorf("failed to save message: %w", err) } // Save attachments for _, att := range parsedMsg.Attachments { att.MessageID = msg.ID - if err := db.SaveAttachment(ctx, s.pool, &att); err != nil { + if err := db.SaveAttachment(ctx, s.dbPool, &att); err != nil { log.Printf("Warning: Failed to save attachment: %v", err) } } @@ -627,7 +655,7 @@ func (s *Service) syncSingleMessage(ctx context.Context, client *imapclient.Clie // ShouldSyncFolder checks if we should sync the folder based on cache TTL. func (s *Service) ShouldSyncFolder(ctx context.Context, userID, folderName string) (bool, error) { - syncInfo, err := db.GetFolderSyncInfo(ctx, s.pool, userID, folderName) + syncInfo, err := db.GetFolderSyncInfo(ctx, s.dbPool, userID, folderName) if err != nil { return false, err } @@ -643,5 +671,5 @@ func (s *Service) ShouldSyncFolder(ctx context.Context, userID, folderName strin // Close closes the service and cleans up connections. func (s *Service) Close() { - s.clientPool.Close() + s.imapPool.Close() } diff --git a/backend/internal/imap/service_sync_test.go b/backend/internal/imap/service_sync_test.go index d8b7ae4..96ec949 100644 --- a/backend/internal/imap/service_sync_test.go +++ b/backend/internal/imap/service_sync_test.go @@ -106,7 +106,7 @@ func TestTryIncrementalSync(t *testing.T) { defer clientCleanup() encryptor := getTestEncryptor(t) - service := NewService(pool, encryptor) + service := NewService(pool, NewPool(), encryptor) defer service.Close() userID, err := db.GetOrCreateUser(ctx, pool, "incremental-test@example.com") @@ -272,7 +272,7 @@ func TestProcessIncrementalMessage(t *testing.T) { } encryptor := getTestEncryptor(t) - service := NewService(pool, encryptor) + service := NewService(pool, NewPool(), encryptor) defer service.Close() userID, err := db.GetOrCreateUser(ctx, pool, "process-test@example.com") @@ -373,6 +373,68 @@ func TestProcessIncrementalMessage(t *testing.T) { } }) + t.Run("matches existing thread by being a reply (message exists in DB)", func(t *testing.T) { + // Create a thread first + rootMessageID := "" + thread := &models.Thread{ + UserID: userID, + StableThreadID: rootMessageID, + Subject: "Root Thread", + } + err := db.SaveThread(ctx, pool, thread) + if err != nil { + t.Fatalf("Failed to save thread: %v", err) + } + + // Create a message in that thread (simulating a previous sync) + replyMessageID := "" + existingMsg := &models.Message{ + UserID: userID, + ThreadID: thread.ID, + MessageIDHeader: replyMessageID, + IMAPFolderName: folderName, + IMAPUID: 4, + Subject: "Re: Root Thread", + BodyText: "Original body", + UnsafeBodyHTML: "

Original body

", + } + err = db.SaveMessage(ctx, pool, existingMsg) + if err != nil { + t.Fatalf("Failed to save existing message: %v", err) + } + + // Now process the same message again (simulating incremental sync finding it) + imapMsg := &imap.Message{ + Uid: 4, + Envelope: &imap.Envelope{ + MessageId: replyMessageID, // Same Message-ID as existing message + Subject: "Re: Root Thread", + Date: time.Now(), + From: []*imap.Address{ + {MailboxName: "from", HostName: "test.com"}, + }, + To: []*imap.Address{ + {MailboxName: "to", HostName: "test.com"}, + }, + }, + Flags: []string{imap.SeenFlag}, + } + + err = service.processIncrementalMessage(ctx, imapMsg, userID, folderName) + if err != nil { + t.Fatalf("processIncrementalMessage failed: %v", err) + } + + // Verify message is still in the same thread + msg, err := db.GetMessageByMessageID(ctx, pool, userID, replyMessageID) + if err != nil { + t.Fatalf("Failed to get message: %v", err) + } + if msg.ThreadID != thread.ID { + t.Errorf("Message should be in existing thread %s, got %s", thread.ID, msg.ThreadID) + } + }) + t.Run("skips message without Message-ID", func(t *testing.T) { imapMsg := &imap.Message{ Uid: 3, diff --git a/backend/internal/imap/service_test.go b/backend/internal/imap/service_test.go index 7e5d652..d39338d 100644 --- a/backend/internal/imap/service_test.go +++ b/backend/internal/imap/service_test.go @@ -33,7 +33,7 @@ func TestShouldSyncFolder(t *testing.T) { } encryptor := getTestEncryptor(t) - service := NewService(pool, encryptor) + service := NewService(pool, NewPool(), encryptor) defer service.Close() userID, err := db.GetOrCreateUser(ctx, pool, "sync-test@example.com") @@ -193,3 +193,48 @@ func TestGetFolderSyncInfoWithUID(t *testing.T) { // - performFullSync: Requires mock IMAP client with THREAD command // - processIncrementalMessage: Can be tested with mock IMAP message // - SearchUIDsSince: Requires mock IMAP client + +func TestService_updateThreadCountInBackground(t *testing.T) { + pool := testutil.NewTestDB(t) + defer pool.Close() + + encryptor := getTestEncryptor(t) + service := NewService(pool, NewPool(), encryptor) + defer service.Close() + + ctx := context.Background() + userID, err := db.GetOrCreateUser(ctx, pool, "thread-count-test@example.com") + if err != nil { + t.Fatalf("Failed to create user: %v", err) + } + + folderName := "INBOX" + + t.Run("handles database error gracefully", func(t *testing.T) { + // Test that updateThreadCountInBackground handles database errors gracefully + // by using an invalid userID that will cause UpdateThreadCount to fail + // (it will try to update a non-existent folder_sync_timestamps row) + invalidUserID := "00000000-0000-0000-0000-000000000000" + + // The function should log a warning but not crash + service.updateThreadCountInBackground(invalidUserID, "NonExistentFolder") + + // Give the goroutine time to complete + time.Sleep(200 * time.Millisecond) + + // Test should complete without panicking + // If there's a panic, the test will fail + // The function logs a warning for database errors, which is the expected behavior + }) + + t.Run("succeeds with valid database connection", func(t *testing.T) { + // Test that the function works correctly with a valid connection + service.updateThreadCountInBackground(userID, folderName) + + // Give the goroutine time to complete + time.Sleep(100 * time.Millisecond) + + // Test should complete without panicking + // If there's a panic, the test will fail + }) +} diff --git a/backend/internal/imap/thread.go b/backend/internal/imap/thread.go index 456fb7b..ac28c0d 100644 --- a/backend/internal/imap/thread.go +++ b/backend/internal/imap/thread.go @@ -9,6 +9,7 @@ import ( ) // RunThreadCommand runs the THREAD command and returns the thread structure. +// Uses the REFERENCES algorithm to build thread relationships. func RunThreadCommand(c *client.Client) ([]*sortthread.Thread, error) { if c == nil { return nil, fmt.Errorf("client is nil") diff --git a/backend/internal/imap/thread_test.go b/backend/internal/imap/thread_test.go new file mode 100644 index 0000000..11472fc --- /dev/null +++ b/backend/internal/imap/thread_test.go @@ -0,0 +1,148 @@ +package imap + +import ( + "testing" + "time" + + "github.com/vdavid/vmail/backend/internal/testutil" +) + +func TestRunThreadCommand(t *testing.T) { + t.Run("returns error for nil client", func(t *testing.T) { + _, err := RunThreadCommand(nil) + if err == nil { + t.Error("Expected error for nil client") + } + if err.Error() != "client is nil" { + t.Errorf("Expected error 'client is nil', got: %v", err) + } + }) + + t.Run("handles empty mailbox", func(t *testing.T) { + server := testutil.NewTestIMAPServer(t) + defer server.Close() + + server.EnsureINBOX(t) + + client, cleanup := server.Connect(t) + defer cleanup() + + // Select INBOX (which is empty) + _, err := client.Select("INBOX", false) + if err != nil { + t.Fatalf("Failed to select INBOX: %v", err) + } + + // Check if server supports THREAD + caps, err := client.Capability() + if err != nil { + t.Fatalf("Failed to check capabilities: %v", err) + } + + // Run thread command on empty mailbox + threads, err := RunThreadCommand(client) + if !caps["THREAD"] { + // Server doesn't support THREAD, expect an error + if err == nil { + t.Error("Expected error for server without THREAD support") + } + return + } + + // Server supports THREAD, should succeed + if err != nil { + t.Fatalf("RunThreadCommand should succeed on empty mailbox: %v", err) + } + + if threads == nil { + t.Error("Expected empty slice, got nil") + } + if len(threads) != 0 { + t.Errorf("Expected empty threads slice, got %d threads", len(threads)) + } + }) + + t.Run("handles mailbox with unthreaded messages", func(t *testing.T) { + server := testutil.NewTestIMAPServer(t) + defer server.Close() + + server.EnsureINBOX(t) + + // Add some messages without threading relationships + now := time.Now() + server.AddMessage(t, "INBOX", "", "Subject 1", "from@test.com", "to@test.com", now) + server.AddMessage(t, "INBOX", "", "Subject 2", "from@test.com", "to@test.com", now.Add(-1*time.Hour)) + + client, cleanup := server.Connect(t) + defer cleanup() + + _, err := client.Select("INBOX", false) + if err != nil { + t.Fatalf("Failed to select INBOX: %v", err) + } + + // Run thread command + threads, err := RunThreadCommand(client) + if err != nil { + // Some servers may not support THREAD command + // In that case, we expect an error + if err.Error() == "" { + t.Error("Expected non-empty error message") + } + return + } + + // If successful, we should have threads (possibly one per message if unthreaded) + if threads == nil { + t.Error("Expected threads slice, got nil") + } + // Unthreaded messages might be returned as separate threads or as a single thread + // The exact behavior depends on the server implementation + }) + + t.Run("handles server without THREAD support", func(t *testing.T) { + server := testutil.NewTestIMAPServer(t) + defer server.Close() + + server.EnsureINBOX(t) + + client, cleanup := server.Connect(t) + defer cleanup() + + // Check if server supports THREAD + caps, err := client.Capability() + if err != nil { + t.Fatalf("Failed to check capabilities: %v", err) + } + + // The memory backend may or may not support THREAD + // If it doesn't, we should get an error + if !caps["THREAD"] { + _, err := RunThreadCommand(client) + if err == nil { + t.Error("Expected error for server without THREAD support") + } + } else { + // Server supports THREAD, so test should pass + _, err := RunThreadCommand(client) + if err != nil { + t.Fatalf("RunThreadCommand should succeed when THREAD is supported: %v", err) + } + } + }) + + t.Run("handles network errors during thread command", func(t *testing.T) { + server := testutil.NewTestIMAPServer(t) + defer server.Close() + + client, _ := server.Connect(t) + // Close the client to simulate network error + _ = client.Logout() + + // Try to run thread command with closed client + _, err := RunThreadCommand(client) + if err == nil { + t.Error("Expected error when client is closed") + } + }) +} diff --git a/backend/internal/imap/worker_client_set.go b/backend/internal/imap/worker_client_set.go new file mode 100644 index 0000000..f62581b --- /dev/null +++ b/backend/internal/imap/worker_client_set.go @@ -0,0 +1,75 @@ +package imap + +import ( + "log" + "sync" +) + +// workerClientSet manages multiple worker clients for a single user. +// Uses a semaphore to limit concurrent connections (max 3 by default). +type workerClientSet struct { + clients []*threadSafeClient + semaphore chan struct{} // Limits concurrent connections (max 3) + mu sync.Mutex +} + +// acquire gets a client from the set, blocking if at max capacity. +// Returns the client (locked) and a release function that must be called when done. +// If no client is available, returns nil and the caller should create a new one. +func (s *workerClientSet) acquire() (*threadSafeClient, func()) { + // Block until a slot is available + s.semaphore <- struct{}{} + + s.mu.Lock() + defer s.mu.Unlock() + + // Find an available client (not in use) + for _, client := range s.clients { + // Client is available if we can acquire its lock immediately + if client.mu.TryLock() { + client.UpdateLastUsed() + // Keep it locked - caller will unlock when done + return client, func() { + client.Unlock() + <-s.semaphore // Release semaphore slot + } + } + } + + // No available client - caller will need to create one + <-s.semaphore // Release semaphore slot temporarily + return nil, func() {} // No-op release function +} + +// addClient adds a new client to the set. +func (s *workerClientSet) addClient(client *threadSafeClient) { + s.mu.Lock() + defer s.mu.Unlock() + s.clients = append(s.clients, client) +} + +// close closes all clients in the set. +// If a client is currently locked (in use), it will be skipped. +// The auto-release goroutine will handle closing it when it sees the pool is closed. +func (s *workerClientSet) close() { + s.mu.Lock() + defer s.mu.Unlock() + + for _, client := range s.clients { + // Try to lock - if we can't, the client is in use and will be closed + // by the auto-release goroutine when it sees cleanupCtx.Done() + if client.TryLock() { + if err := client.client.Logout(); err != nil { + log.Printf("Failed to logout worker client: %v", err) + } + client.Unlock() + } else { + // Client is locked (in use) - skip it + // The auto-release goroutine will see cleanupCtx.Done() and won't release, + // but we should still try to close the underlying connection + // Note: This is not thread-safe, but we're shutting down so it's acceptable + _ = client.client.Logout() + } + } + s.clients = nil +} diff --git a/backend/internal/models/email.go b/backend/internal/models/email.go index 45b90d9..a1db1ac 100644 --- a/backend/internal/models/email.go +++ b/backend/internal/models/email.go @@ -57,3 +57,16 @@ type Attachment struct { IsInline bool `json:"is_inline"` ContentID string `json:"content_id,omitempty"` } + +// ThreadsResponse represents the paginated response for thread listings. +type ThreadsResponse struct { + Threads []*Thread `json:"threads"` + Pagination PaginationInfo `json:"pagination"` +} + +// PaginationInfo contains pagination metadata for list responses. +type PaginationInfo struct { + TotalCount int `json:"total_count"` + Page int `json:"page"` + PerPage int `json:"per_page"` +} diff --git a/backend/internal/models/user.go b/backend/internal/models/user.go index 715957c..03f1f36 100644 --- a/backend/internal/models/user.go +++ b/backend/internal/models/user.go @@ -58,6 +58,5 @@ type UserSettingsResponse struct { // AuthStatusResponse represents the authentication and setup status of a user. type AuthStatusResponse struct { - IsAuthenticated bool `json:"isAuthenticated"` IsSetupComplete bool `json:"isSetupComplete"` } diff --git a/backend/internal/testutil/imap.go b/backend/internal/testutil/imap.go index 0438921..9b574de 100644 --- a/backend/internal/testutil/imap.go +++ b/backend/internal/testutil/imap.go @@ -21,12 +21,12 @@ var _ server.Extension = (*specialUseExtension)(nil) type specialUseExtension struct{} // Capabilities returns the SPECIAL-USE capability. -func (e *specialUseExtension) Capabilities(c server.Conn) []string { +func (e *specialUseExtension) Capabilities(server.Conn) []string { return []string{"SPECIAL-USE"} } // Command returns nil (no custom commands needed for SPECIAL-USE). -func (e *specialUseExtension) Command(name string) server.HandlerFactory { +func (e *specialUseExtension) Command(string) server.HandlerFactory { return nil } @@ -177,6 +177,7 @@ func NewTestIMAPServer(t *testing.T) *TestIMAPServer { cleanup := func() { err := s.Close() if err != nil { + t.Logf("Failed to close IMAP server: %v", err) return } } @@ -451,7 +452,7 @@ Test message body. } // CreateFolderWithSpecialUse creates a folder with SPECIAL-USE attributes (non-test context). -func (s *TestIMAPServer) CreateFolderWithSpecialUse(folderName string, specialUseAttr string) error { +func (s *TestIMAPServer) CreateFolderWithSpecialUse(folderName string) error { client, err := s.ConnectForE2E() if err != nil { return fmt.Errorf("failed to connect: %w", err) diff --git a/backend/internal/testutil/smtp.go b/backend/internal/testutil/smtp.go index 15a5d96..70444a5 100644 --- a/backend/internal/testutil/smtp.go +++ b/backend/internal/testutil/smtp.go @@ -147,6 +147,7 @@ func NewTestSMTPServer(t *testing.T) *TestSMTPServer { cleanup := func() { err := s.Close() if err != nil { + t.Logf("Failed to close SMTP server: %v", err) return } } diff --git a/docs/architecture.md b/docs/architecture.md index d879716..7f8d3e7 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -73,30 +73,17 @@ The DB's role is **not** to be a full, permanent copy of the mailbox. Its primar The back end is a **Go** application providing a **REST API** for the front end. It communicates with the IMAP and the SMTP server and uses a **Postgres** database for caching and internal storage. -### Go libraries used - -* **IMAP Client:** [`github.com/emersion/go-imap`](https://github.com/emersion/go-imap) - * This seems to be the *de facto* standard library for client-side IMAP in Go. - It seems well-maintained and supports the necessary extensions like `THREAD`. -* **MIME Parsing:** [`github.com/jhillyerd/enmime`](https://github.com/jhillyerd/enmime) - * The Go standard library is not enough for real-world, complex emails. - * `enmime` robustly handles attachments, encodings, - and HTML/text parts. [Docs here.](https://pkg.go.dev/github.com/jhillyerd/enmime) -* **SMTP Sending:** Standard `net/smtp` (for transport) - with [`github.com/go-mail/mail`](https://github.com/go-mail/mail) - * `net/smtp` is the standard library for sending. - * `go-mail` is a popular and simple builder library for composing complex emails (HTML and attachments) - that `net/smtp` can then send. -* **HTTP Router:** [`http.ServeMux`](https://pkg.go.dev/net/http#ServeMux) - * It's part of the Go standard library, is battle-tested and well-documented. - * Selected based on [this guide](https://www.alexedwards.net/blog/which-go-router-should-i-use) -* **Postgres Driver:** [`github.com/jackc/pgx`](https://github.com/jackc/pgx) - * The modern, high-performance Postgres driver for Go. We need no full ORM (like [GORM](https://gorm.io/)) - for this project. -* **Encryption:** Standard `crypto/aes` and `crypto/cipher` - * For encrypting/decrypting user credentials in the DB using AES-GCM. -* **Testing:** [`github.com/ory/dockertest`](https://github.com/ory/dockertest) - * Useful for integration tests to spin up real Postgres containers. +### Features + +- [auth](backend/auth.md) +- [config](backend/config.md) +- [crypto](backend/crypto.md) +- [folders](backend/folders.md) +- [imap](backend/imap.md) +- [search](backend/search.md) +- [settings](backend/settings.md) +- [thread](backend/thread.md) +- [threads](backend/threads.md) ### REST API @@ -105,14 +92,28 @@ It communicates with the IMAP and the SMTP server and uses a **Postgres** databa **Thread ID:** The `thread_id` we use in the API (e.g., `/api/v1/thread/{thread_id}`) is a stable, unique identifier, such as the `Message-ID` header of the root/first message in the thread. -* [x] `GET /auth/status`: Checks the Authelia token and tells the front end if the user is authenticated, and has +(The checked items are implemented) + +* [x] `GET /auth/status`: Checks the Authelia token and tells the front end if the user has completed the setup/onboarding. - * Response: `{"isAuthenticated": true, "isSetupComplete": false}`. + * Response: `{"isSetupComplete": false}`. * `isSetupComplete: false` tells the React app to redirect to the `/settings` page for onboarding. -* [ ] `GET /folders`: List all IMAP folders (Inbox, Sent, etc.). -* [ ] `GET /threads?folder=Inbox&page=1&limit=100`: Get paginated threads for a folder. -* [ ] `GET /threads/search?q=from:george&page=1`: Get paginated search results. -* [ ] `GET /thread/{thread_id}`: Get all messages and content for one thread. +* [x] `GET /folders`: List all IMAP folders (Inbox, Sent, etc.). + * Response: Array of folder objects with `name` and `role` fields. + * Folders are sorted by role priority (inbox, sent, drafts, spam, trash, archive, other), then alphabetically within the same role. +* [x] `GET /threads?folder=Inbox&page=1&limit=100`: Get paginated threads for a folder. + * Response: `{"threads": [...], "pagination": {"total_count": 100, "page": 1, "per_page": 100}}`. + * Automatically syncs the folder from IMAP if the cache is stale. + * Uses user's pagination setting from settings if no limit is provided. +* [x] `GET /search?q=from:george&page=1&limit=100`: Get paginated search results. + * Response: `{"threads": [...], "pagination": {"total_count": 100, "page": 1, "per_page": 100}}`. + * Supports Gmail-like search syntax (from:, to:, subject:, after:, before:, folder:, label:). + * Empty query returns all emails in INBOX. + * Uses user's pagination setting from settings if no limit is provided. +* [x] `GET /thread/{thread_id}`: Get all messages and content for one thread. + * Response: Thread object with all messages, attachments, and bodies. + * Automatically syncs missing message bodies from IMAP in batch. + * Thread ID is URL-encoded Message-ID header. * [ ] `GET /message/{message_id}/attachment/{attachment_id}`: Download an attachment. * [x] `GET /settings`: Get user settings. * Response: `{"imap_server_hostname": "mail.example.com", "archive_folder_name": "Archive", ...}` @@ -135,29 +136,13 @@ unique identifier, such as the `Message-ID` header of the root/first message in For real-time updates (like new emails), the front end will open a WebSocket connection. -* `GET /api/v1/ws`: Upgrades the HTTP connection to a WebSocket. +* [ ] `GET /api/v1/ws`: Upgrades the HTTP connection to a WebSocket. The server uses this connection to push updates to the client. * **Server-to-client message example:** ```json {"type": "new_message", "folder": "INBOX"} ``` -## Front end - -### Tech - -* **Framework:** React 19+, with functional components and hooks. -* **Language:** TypeScript, using no classes, just modules. -* **Styling:** Tailwind 4, utility-first CSS. -* **Package manager:** pnpm. -* **State management:** - * `TanStack Query` (React Query): For server state (caching, invalidating, and refetching all data from our Go API). - * `Zustand`: For simple, global UI state (e.g., current selection, composer open/closed). -* **Routing:** `react-router` (for URL-based navigation, e.g., `/inbox`, `/thread/id`). -* **Linting/Formatting:** ESLint and Prettier. -* **Testing:** - * `Jest` + `React Testing Library`: For unit and integration tests. - * `Playwright`: For end-to-end tests. -* **Security:** [`DOMPurify`](https://github.com/cure53/DOMPurify) - * To sanitize all email HTML content before rendering it with `dangerouslySetInnerHTML`. - This is a **mandatory** security step. +### Technical decisions + +See [technical decisions](technical-decisions.md) \ No newline at end of file diff --git a/docs/backend/auth.md b/docs/backend/auth.md new file mode 100644 index 0000000..b165663 --- /dev/null +++ b/docs/backend/auth.md @@ -0,0 +1,36 @@ +# Auth + +The `auth` backend feature handles authentication and authorization for the V-Mail API. + +The feature set is not in a single package but rather a scattered bunch of files that provide auth. + +## Components + +* **`internal/api/auth_handler.go`**: HTTP handler for the `/api/v1/auth/status` endpoint. + * `GetAuthStatus`: Returns authentication and setup status for the current user. + * Checks if the user has completed onboarding by verifying user settings exist in the database. + +* **`internal/auth/middleware.go`**: Authentication middleware. + * `RequireAuth`: HTTP middleware that validates Bearer tokens in the Authorization header. + * `ValidateToken`: Validates Authelia JWT tokens and extracts the user's email (currently a stub for development). + * `GetUserEmailFromContext`: Helper to extract the authenticated user's email from the request context. + +* **`internal/db/user.go`**: Database operations for users. + * `GetOrCreateUser`: Gets or creates a user record by email address. + +* **`internal/db/user_settings.go`**: Database operations for user settings. + * `UserSettingsExist`: Checks if user settings exist for a given user ID. + +## Flow + +1. Frontend sends API requests with a Bearer token in the Authorization header. +2. `RequireAuth` middleware validates the token and extracts the user's email. +3. The email is stored in the request context for use by handlers. +4. Handlers use `GetUserEmailFromContext` to retrieve the authenticated user's email. +5. The auth handler checks if the user has completed setup by querying for user settings. + +## Current limitations + +* `ValidateToken` is a stub that always returns "test@example.com" in production mode. It must be implemented to + actually validate Authelia JWT tokens before deployment. +* In test mode (`VMAIL_TEST_MODE=true`), tokens can be prefixed with "email:" to specify the test user email. diff --git a/docs/backend/config.md b/docs/backend/config.md new file mode 100644 index 0000000..0de9f71 --- /dev/null +++ b/docs/backend/config.md @@ -0,0 +1,40 @@ +# Config + +The `config` package handles loading and validating application configuration from environment variables. + +## Components + +* **`internal/config/config.go`**: Configuration loading and validation. + * `Config`: Struct holding all application configuration values. + * `NewConfig`: Loads configuration from environment variables, with support for `.env` file in development mode. + * `Validate`: Validates that all required configuration values are set. + * `GetDatabaseURL`: Builds a PostgreSQL connection string from database configuration. + * `getEnvOrDefault`: Helper function to get environment variables with default values. + +## Configuration values + +### Required + +* `VMAIL_ENCRYPTION_KEY_BASE64`: Base64-encoded encryption key (32 bytes when decoded). +* `AUTHELIA_URL`: Base URL of the Authelia authentication server. +* `VMAIL_DB_PASSWORD`: PostgreSQL database password. + +### Optional (with defaults) + +* `VMAIL_ENV`: Deployment environment (defaults to "development"). +* `VMAIL_DB_HOST`: Database hostname (defaults to "localhost"). +* `VMAIL_DB_PORT`: Database port (defaults to "5432"). +* `VMAIL_DB_USER`: Database username (defaults to "vmail"). +* `VMAIL_DB_NAME`: Database name (defaults to "vmail"). +* `VMAIL_DB_SSLMODE`: SSL mode (defaults to "disable"). +* `PORT`: HTTP server port (defaults to "11764"). +* `TZ`: Application timezone (defaults to "UTC"). + +## Development mode + +* When `VMAIL_ENV` is "development" (or not set), the package attempts to load a `.env` file using `godotenv`. +* If the `.env` file is not found, it falls back to environment variables with a warning message. + +## Current limitations + +* None - all identified issues have been addressed. diff --git a/docs/backend/crypto.md b/docs/backend/crypto.md new file mode 100644 index 0000000..61fa36a --- /dev/null +++ b/docs/backend/crypto.md @@ -0,0 +1,33 @@ +# Crypto + +The `crypto` package provides encryption and decryption functionality for sensitive data like user passwords. + +## Components + +* **`internal/crypto/encryption.go`**: AES-GCM encryption implementation. + * `Encryptor`: Struct holding the encryption key. + * `NewEncryptor`: Creates a new encryptor from a base64-encoded 32-byte key. + * `Encrypt`: Encrypts plaintext using AES-GCM with a random nonce. + * `Decrypt`: Decrypts ciphertext, verifying authenticity and integrity. + +## Encryption scheme + +* **Algorithm:** AES-256-GCM (Galois/Counter Mode) +* **Key size:** 32 bytes (256 bits) +* **Nonce:** Randomly generated for each encryption (12 bytes for GCM) +* **Ciphertext format:** `[nonce][encrypted_data][auth_tag]` + * The nonce is prepended to the ciphertext for use during decryption. + * The authentication tag is appended by GCM to verify data integrity. + +## Security properties + +* **Confidentiality:** Data is encrypted and cannot be read without the key. +* **Authenticity:** GCM provides authentication, detecting tampering or corruption. +* **Nonce uniqueness:** Each encryption uses a random nonce, ensuring the same plaintext produces different ciphertexts. +* **Key storage:** The encryption key is stored in memory as plain bytes (standard practice for application-level encryption). + +## Usage + +* Used to encrypt/decrypt IMAP and SMTP passwords before storing them in the database. +* The encryption key is provided via the `VMAIL_ENCRYPTION_KEY_BASE64` environment variable. +* The same key must be used across all application instances to decrypt previously encrypted data. diff --git a/docs/backend/folders.md b/docs/backend/folders.md new file mode 100644 index 0000000..a092bad --- /dev/null +++ b/docs/backend/folders.md @@ -0,0 +1,43 @@ +# Folders + +The `folders` back end feature provides a way to list IMAP folders for the authenticated user. + +The feature is intentionally not organized into a single package so that API-level functions can share helpers, etc. + +## Components + +* **`internal/api/folders_handler.go`**: HTTP handler for the `/api/v1/folders` endpoint. + * `GetFolders`: Lists all IMAP folders for the current user, sorted by role priority. + * `getUserSettingsAndPassword`: Retrieves user settings and decrypts the IMAP password. + * `getIMAPClient`: Gets an IMAP client from the pool, with user-friendly error messages for timeouts. + * `listFoldersWithRetry`: Lists folders with automatic retry on connection errors. + * `retryListFolders`: Retries listing folders after removing a broken connection from the pool. + * `writeFoldersResponse`: Writes the sorted folders as JSON. + * `sortFoldersByRole`: Sorts folders by role priority (inbox, sent, drafts, spam, trash, archive, other), then alphabetically within the same role. + +* **`internal/imap/folder.go`**: IMAP folder listing implementation. + * `ListFolders`: Lists all folders on the IMAP server using SPECIAL-USE attributes (RFC 6154) to determine roles. + * `determineFolderRole`: Maps folder names and SPECIAL-USE attributes to role strings. + +## Flow + +1. Handler extracts user ID from request context. +2. Retrieves and decrypts user settings (IMAP credentials). +3. Gets an IMAP client from the connection pool. +4. Lists folders from the IMAP server. +5. If a connection error occurs (broken pipe, connection reset, EOF), removes the broken client from the pool and retries with a fresh connection. +6. Sorts folders by role priority and alphabetically. +7. Returns folders as JSON. + +## Error handling + +* Returns 404 if user settings are not found. +* Returns 400 if the IMAP server doesn't support SPECIAL-USE extension (required for V-Mail). +* Returns 503 (Service Unavailable) for connection timeout errors with a user-friendly message. +* Returns 500 for other connection or internal errors. +* Automatically retries on transient connection errors (broken pipe, connection reset, EOF). + +## Dependencies + +* Requires IMAP server support for SPECIAL-USE extension (RFC 6154) to identify folder roles. +* Uses the IMAP connection pool to manage client connections efficiently. diff --git a/docs/backend/imap.md b/docs/backend/imap.md new file mode 100644 index 0000000..3913676 --- /dev/null +++ b/docs/backend/imap.md @@ -0,0 +1,97 @@ +# IMAP + +The `imap` package handles all communication with IMAP servers, including connection pooling, folder listing, message +syncing, and searching. + +This is probably the trickiest part of the codebase. + +## Components + +* **`internal/imap/client.go`**: Connection pool implementation. + * `Pool`: Manages IMAP connections per user (one connection per user, reused across requests). + * `getClientConcrete`: Gets or creates an IMAP client, checking connection health. + * `GetClient`: Public interface that returns an `IMAPClient` wrapper. + * `RemoveClient`: Removes a broken connection from the pool. + * `ConnectToIMAP`: Establishes connection with 5-second timeout. + * `Login`: Authenticates with the IMAP server. + +* **`internal/imap/pool_interface.go`**: Interfaces for testability. + * `IMAPClient`: Interface for IMAP client operations (currently only `ListFolders`). + * `IMAPPool`: Interface for connection pool operations. + * `ClientWrapper`: Wraps go-imap client to implement `IMAPClient`. + +* **`internal/imap/service.go`**: Main IMAP service implementation. + * `Service`: Handles IMAP operations and caching. + * `SyncThreadsForFolder`: Syncs threads from IMAP (incremental or full sync). + * `SyncFullMessage`: Syncs a single message body. + * `SyncFullMessages`: Batch syncs multiple message bodies. + * `Search`: Searches for threads matching a query. + * `ShouldSyncFolder`: Checks if folder cache is stale. + +* **`internal/imap/fetch.go`**: Message fetching operations. + * `FetchMessageHeaders`: Fetches headers for multiple messages. + * `FetchFullMessage`: Fetches full message body. + * `SearchUIDsSince`: Searches for UIDs >= minUID (for incremental sync). + +* **`internal/imap/folder.go`**: Folder listing operations. + * `ListFolders`: Lists folders with SPECIAL-USE attributes. + * `determineFolderRole`: Maps folder names and attributes to roles. + +* **`internal/imap/thread.go`**: Thread structure operations. + * `RunThreadCommand`: Executes IMAP THREAD command. + +* **`internal/imap/parser.go`**: Message parsing. + * `ParseMessage`: Converts IMAP message to internal model. + * `parseBody`: Parses email body using enmime library. + +* **`internal/imap/search.go`**: Search query parsing and execution. + * `ParseSearchQuery`: Parses Gmail-like search queries. + * `Search`: Performs IMAP search and returns threads. + +## Connection Pooling + +The connection pool is a critical and complex part of the codebase. Key characteristics: + +* **Worker connections**: Each user has a pool of 1–3 worker connections for API handlers (SEARCH, FETCH, STORE). These + connections are reused across requests and managed by a semaphore to limit concurrent connections. +* **Listener connections**: Each user has one dedicated listener connection for the IDLE command (for real-time email + notifications via WebSocket). +* **Thread safety**: + * IMAP clients from `go-imap` are **NOT thread-safe**. Each connection is wrapped with a mutex (`clientWithMutex`) + to ensure thread-safe access. + * Multiple goroutines can use different connections concurrently, but access to the same connection is serialized by + the mutex. + * Folder selection is thread-safe because connections are locked during operations. +* **Connection lifecycle management**: + * **Idle timeout**: Worker connections are closed after 10 minutes of inactivity. Listener connections have no idle + timeout (IDLE keeps them alive). + * **Health checks**: Before reusing a connection that's been idle > 1 minute, a NOOP command is sent to verify the + connection is alive. + * **Automatic cleanup**: A background goroutine runs every minute to remove idle connections. +* **Connection limits**: Maximum of 3 worker connections per user (enforced by semaphore). One listener connection per + user. + +## Thread safety guarantees + +* **Per-connection mutexes**: Each connection has its own mutex, allowing concurrent access to different connections + while serializing access to the same connection. +* **Double-check locking**: Used when creating new connections to prevent race conditions where multiple goroutines + create connections simultaneously. +* **Semaphore-based limiting**: Worker connections are limited by a semaphore (max 3 per user), ensuring proper resource + management. + +## Sync behavior + +* **Incremental sync**: If a folder has been synced before, only new messages (UIDs > last synced UID) are fetched. +* **Full sync**: If no sync info exists or incremental sync fails, all messages are fetched using THREAD command (or + SEARCH as fallback). +* **Thread structure**: Full sync uses IMAP THREAD command to build thread relationships. If THREAD is not supported, + falls back to processing messages without threading. +* **Lazy loading**: Message bodies are not always synced immediately. They are synced on-demand when a thread is viewed. + +## Error handling + +* Sync errors are logged but don't fail requests (graceful degradation). +* Broken connections are removed from the pool and recreated on next use. +* Folder selection errors are propagated to the caller. +* Network errors during fetch are propagated to the caller. diff --git a/docs/backend/search.md b/docs/backend/search.md new file mode 100644 index 0000000..831df38 --- /dev/null +++ b/docs/backend/search.md @@ -0,0 +1,78 @@ +# Search + +The `search` feature provides search for email threads using a Gmail-like query syntax. + +The feature is intentionally not organized into a single package so that API-level functions can share helpers, etc. + +## Components + +* **`internal/api/search_handler.go`**: HTTP handler for the `/api/v1/search` endpoint. + * `Search`: Handles search requests with query parameter parsing and pagination. + * `getPaginationLimit`: Gets pagination limit from user settings or defaults. + +* **`internal/imap/search.go`**: IMAP search implementation and query parsing. + * `ParseSearchQuery`: Parses Gmail-like search queries into IMAP SearchCriteria. + * `Search`: Performs IMAP search and returns paginated threads. + * `buildThreadMapFromMessages`: Builds thread map from IMAP search results. + * `sortAndPaginateThreads`: Sorts threads by latest sent_at and applies pagination. + * `tokenizeQuery`: Tokenizes query string, respecting quoted strings. + * `parseHeaderFilter`: Parses header filters (from:, to:, subject:). + * `parseDateFilter`: Parses date filters (after:, before:). + * `parseFolderFilter`: Parses folder/label filters (folder:, label:). + +## Flow + +1. Handler extracts user ID from request context. +2. Gets query from `q` query parameter (empty query means return all emails). +3. Parses pagination parameters (page, limit) from query string. +4. Gets pagination limit from user settings if not provided in query. +5. Calls IMAP service to search for matching threads. +6. IMAP service parses query using Gmail-like syntax. +7. IMAP service searches the specified folder (or INBOX if not specified). +8. IMAP service fetches message headers for matching UIDs. +9. IMAP service builds thread map from messages in the database. +10. IMAP service sorts threads by latest sent_at and applies pagination. +11. IMAP service enriches threads with first message's from_address. +12. Returns paginated response with threads and pagination info. + +## Search syntax + +* **Header filters:** + * `from:george` - Search by sender + * `to:alice` - Search by recipient + * `subject:meeting` - Search by subject + * Quoted values: `from:"John Doe"` - Search with quoted strings + +* **Date filters:** + * `after:2025-01-01` - Messages after date (YYYY-MM-DD format) + * `before:2025-12-31` - Messages before date (end of day) + +* **Folder filters:** + * `folder:Inbox` - Search in specific folder + * `label:Sent` - Alias for folder: (Gmail compatibility) + +* **Plain text:** + * `cabbage` - Full-text search across message content + +* **Combinations:** + * `from:george after:2025-01-01 cabbage` - Multiple filters and text search + +## Pagination + +* Default page: 1 +* Default limit: User's setting from `PaginationThreadsPerPage`, or 100 if not set. +* Query parameters: `page` and `limit` can override defaults. +* Invalid values (non-positive numbers) fall back to defaults. + +## Error handling + +* Returns 400 for invalid query syntax (e.g., empty filter values, invalid date formats). +* Returns 500 for IMAP connection errors, search failures, or database errors. +* Returns 500 for JSON encoding errors. +* If thread enrichment fails, continues gracefully (threads work without from_address). + +## Current limitations + +* Search is limited to a single folder (defaults to INBOX if not specified). +* Full-text search uses IMAP's TEXT search criteria (server-dependent behavior). +* Threads are sorted by latest sent_at only (no other sort options). diff --git a/docs/backend/settings.md b/docs/backend/settings.md new file mode 100644 index 0000000..506d09f --- /dev/null +++ b/docs/backend/settings.md @@ -0,0 +1,48 @@ +# Settings + +The `settings` feature provides the user a way to save their settings and preferences, including +their IMAP/SMTP credentials and application preferences. + +## Components + +* **`internal/api/settings_handler.go`**: HTTP handlers for the `/api/v1/settings` endpoint. + * `GetSettings`: Returns user settings for the current user (passwords are never included, only a boolean indicating if they're set). + * `PostSettings`: Saves or updates user settings. Passwords are optional on update (empty passwords preserve existing ones), but required for initial setup. + * `validateSettingsRequest`: Validates that all required fields are present in the request. + +* **`internal/db/user_settings.go`**: Database operations for user settings. + * `GetUserSettings`: Retrieves user settings by user ID. + * `SaveUserSettings`: Saves or updates user settings (uses ON CONFLICT for upsert). + * `UserSettingsExist`: Checks if user settings exist for a given user ID. + +## Flow (GetSettings) + +1. Handler extracts user ID from request context. +2. Retrieves user settings from the database. +3. Returns 404 if settings don't exist. +4. Builds response without passwords (only indicates if they're set). +5. Returns settings as JSON. + +## Flow (PostSettings) + +1. Handler extracts user ID from request context. +2. Decodes and validates the request body. +3. Retrieves existing settings (if any) to preserve passwords. +4. Handles password encryption: + * If password is provided: encrypts and uses the new password. + * If password is empty and settings exist: preserves existing encrypted password. + * If password is empty and no settings exist: returns 400 (password required for initial setup). +5. Saves settings to the database. +6. Returns success response. + +## Security + +* Passwords are encrypted using AES-GCM before storage in the database. +* Passwords are never returned in API responses (only a boolean indicating if they're set). +* Passwords can be updated without re-entering other passwords. + +## Error handling + +* Returns 404 if settings are not found (GetSettings). +* Returns 400 for validation errors (missing required fields, empty passwords on initial setup). +* Returns 500 for database or encryption errors. diff --git a/docs/backend/thread.md b/docs/backend/thread.md new file mode 100644 index 0000000..aa9d012 --- /dev/null +++ b/docs/backend/thread.md @@ -0,0 +1,55 @@ +# Thread + +The `thread` feature provides a way to retrieve a single email thread with all its messages, attachments, and bodies. + +It's intentionally not organized into a single package so that its API-level functions can share helpers, etc. + +## Components + +* **`internal/api/thread_handler.go`**: HTTP handler for the `/api/v1/thread/{thread_id}` endpoint. + * `GetThread`: Returns a single thread with all messages, attachments, and bodies. + * `getStableThreadIDFromPath`: Extracts and URL-decodes the thread ID from the request path. + * `collectMessagesToSync`: Identifies messages that need body syncing (lazy loading). + * `syncMissingBodies`: Syncs missing message bodies from IMAP in batch. + * `assignAttachments`: Assigns batch-fetched attachments to messages. + * `convertMessagesToThreadMessages`: Converts messages for response, ensuring attachments are never nil. + +* **`internal/db/messages.go`**: Database operations for messages and attachments. + * `GetMessagesForThread`: Retrieves all messages for a thread, ordered by sent_at. + * `GetMessageByUID`: Retrieves a message by IMAP UID and folder. + * `GetAttachmentsForMessages`: Batch-fetches attachments for multiple messages (avoids N+1 queries). + +## Flow + +1. Handler extracts user ID from request context. +2. Extracts and URL-decodes thread ID from the request path. +3. Retrieves thread from database by stable thread ID. +4. Retrieves all messages for the thread. +5. Batch-fetches all attachments for the messages (single query). +6. Identifies messages with missing bodies (lazy loading optimization). +7. Syncs missing bodies from IMAP in batch if needed. +8. Re-fetches synced messages to get updated bodies. +9. Assigns attachments to messages and converts for response. +10. Returns thread with all messages, attachments, and bodies. + +## Lazy loading + +* Message bodies are not always synced immediately when threads are synced. +* Bodies are synced on-demand when a thread is viewed. +* This optimization reduces initial sync time and storage requirements. +* Bodies are synced in batch for efficiency. + +## Error handling + +* Returns 400 if thread_id is missing or invalid. +* Returns 404 if thread is not found. +* Returns 500 for database errors. +* If attachment fetching fails, continues with empty attachments. +* If body sync fails, continues with messages without bodies (graceful degradation). +* Returns 500 for JSON encoding errors. + +## Performance optimizations + +* Batch-fetches attachments in a single query (avoids N+1 queries). +* Batch-syncs missing message bodies. +* Uses efficient UID-to-index mapping for updating synced messages. diff --git a/docs/backend/threads.md b/docs/backend/threads.md new file mode 100644 index 0000000..baacd76 --- /dev/null +++ b/docs/backend/threads.md @@ -0,0 +1,50 @@ +# Threads + +The `threads` feature provides a way to list email threads for a folder with pagination support. + +It's intentionally not organized into a single package so that API-level functions can share helpers, etc. + +## Components + +* **`internal/api/threads_handler.go`**: HTTP handler for the `/api/v1/threads` endpoint. + * `GetThreads`: Returns a paginated list of email threads for a folder. + * `parsePaginationParams`: Parses page and limit query parameters with validation. + * `getPaginationLimit`: Gets pagination limit from user settings or defaults. + * `syncFolderIfNeeded`: Checks if folder needs syncing and syncs if necessary. + * `buildPaginationResponse`: Builds the paginated response structure. + +* **`internal/db/threads.go`**: Database operations for threads. + * `GetThreadsForFolder`: Retrieves paginated threads for a folder. + * `GetThreadCountForFolder`: Gets the total count of threads for pagination. + * `SaveThread`: Saves or updates a thread in the database. + +## Flow + +1. Handler extracts user ID from request context. +2. Validates that the `folder` query parameter is provided. +3. Parses pagination parameters (page, limit) from query string. +4. Gets pagination limit from user settings if not provided in query. +5. Checks if folder needs syncing and syncs from IMAP if stale. +6. Retrieves threads from the database with pagination. +7. Gets total thread count for pagination metadata. +8. Returns paginated response with threads and pagination info. + +## Pagination + +* Default page: 1 +* Default limit: User's setting from `PaginationThreadsPerPage`, or 100 if not set. +* Query parameters: `page` and `limit` can override defaults. +* Invalid values (non-positive numbers) fall back to defaults. + +## Sync behavior + +* Automatically checks if folder cache is stale before returning threads. +* If stale, syncs from IMAP server in the background. +* If sync fails, continues and returns cached data (graceful degradation). +* Sync errors are logged but don't fail the request. + +## Error handling + +* Returns 400 if folder parameter is missing. +* Returns 500 for database errors (getting threads or count). +* Returns 500 for JSON encoding errors. diff --git a/docs/technical-decisions.md b/docs/technical-decisions.md new file mode 100644 index 0000000..56dc42b --- /dev/null +++ b/docs/technical-decisions.md @@ -0,0 +1,46 @@ +## Back end + +### Go libraries used + +* **IMAP Client:** [`github.com/emersion/go-imap`](https://github.com/emersion/go-imap) + * This seems to be the *de facto* standard library for client-side IMAP in Go. + It seems well-maintained and supports the necessary extensions like `THREAD`. +* **MIME Parsing:** [`github.com/jhillyerd/enmime`](https://github.com/jhillyerd/enmime) + * The Go standard library is not enough for real-world, complex emails. + * `enmime` robustly handles attachments, encodings, + and HTML/text parts. [Docs here.](https://pkg.go.dev/github.com/jhillyerd/enmime) +* **SMTP Sending:** Standard `net/smtp` (for transport) + with [`github.com/go-mail/mail`](https://github.com/go-mail/mail) + * `net/smtp` is the standard library for sending. + * `go-mail` is a popular and simple builder library for composing complex emails (HTML and attachments) + that `net/smtp` can then send. +* **HTTP Router:** [`http.ServeMux`](https://pkg.go.dev/net/http#ServeMux) + * It's part of the Go standard library, is battle-tested and well-documented. + * Selected based on [this guide](https://www.alexedwards.net/blog/which-go-router-should-i-use) +* **Postgres Driver:** [`github.com/jackc/pgx`](https://github.com/jackc/pgx) + * The modern, high-performance Postgres driver for Go. We need no full ORM (like [GORM](https://gorm.io/)) + for this project. +* **Encryption:** Standard `crypto/aes` and `crypto/cipher` + * For encrypting/decrypting user credentials in the DB using AES-GCM. +* **Testing:** [`github.com/ory/dockertest`](https://github.com/ory/dockertest) + * Useful for integration tests to spin up real Postgres containers. + +## Front end + +### Tech + +* **Framework:** React 19+, with functional components and hooks. +* **Language:** TypeScript, using no classes, just modules. +* **Styling:** Tailwind 4, utility-first CSS. +* **Package manager:** pnpm. +* **State management:** + * `TanStack Query` (React Query): For server state (caching, invalidating, and refetching all data from our Go API). + * `Zustand`: For simple, global UI state (e.g., current selection, composer open/closed). +* **Routing:** `react-router` (for URL-based navigation, e.g., `/inbox`, `/thread/id`). +* **Linting/Formatting:** ESLint and Prettier. +* **Testing:** + * `Jest` + `React Testing Library`: For unit and integration tests. + * `Playwright`: For end-to-end tests. +* **Security:** [`DOMPurify`](https://github.com/cure53/DOMPurify) + * To sanitize all email HTML content before rendering it with `dangerouslySetInnerHTML`. + This is a **mandatory** security step. diff --git a/docs/testing.md b/docs/testing.md index b84fe8c..4cf75a6 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -1,4 +1,3 @@ - # Testing guidelines ## Unit tests diff --git a/frontend/src/components/AuthWrapper.test.tsx b/frontend/src/components/AuthWrapper.test.tsx index c019f78..5a06556 100644 --- a/frontend/src/components/AuthWrapper.test.tsx +++ b/frontend/src/components/AuthWrapper.test.tsx @@ -1,8 +1,8 @@ import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { render, screen, waitFor } from '@testing-library/react' import * as React from 'react' -import { BrowserRouter, Routes, Route } from 'react-router-dom' -import { describe, it, expect, vi, beforeEach } from 'vitest' +import { BrowserRouter, Route, Routes } from 'react-router-dom' +import { beforeEach, describe, expect, it, vi } from 'vitest' import * as apiModule from '../lib/api' import { useAuthStore } from '../store/auth.store' @@ -59,7 +59,6 @@ describe('AuthWrapper', () => { it('should render children when setup is complete', async () => { // eslint-disable-next-line @typescript-eslint/unbound-method vi.mocked(apiModule.api.getAuthStatus).mockResolvedValue({ - isAuthenticated: true, isSetupComplete: true, }) @@ -73,7 +72,6 @@ describe('AuthWrapper', () => { it('should redirect to settings when setup is not complete', async () => { // eslint-disable-next-line @typescript-eslint/unbound-method vi.mocked(apiModule.api.getAuthStatus).mockResolvedValue({ - isAuthenticated: true, isSetupComplete: false, }) diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 4126dd8..105077c 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -37,7 +37,6 @@ export function decodeThreadIdFromUrl(encoded: string): string { } export interface AuthStatus { - isAuthenticated: boolean isSetupComplete: boolean } diff --git a/playwright.config.ts b/playwright.config.ts index 9cdb32a..0659335 100644 --- a/playwright.config.ts +++ b/playwright.config.ts @@ -34,8 +34,10 @@ export default defineConfig({ reuseExistingServer: false, timeout: 120 * 1000, env: { + VMAIL_ENV: 'test', VMAIL_TEST_MODE: 'true', PORT: '11765', // Use different port for E2E tests + VMAIL_IMAP_MAX_WORKERS: '50', // Increase max workers for faster tests }, }, { diff --git a/scripts/loc-counter.go b/scripts/loc-counter.go index fb8b933..46d1640 100644 --- a/scripts/loc-counter.go +++ b/scripts/loc-counter.go @@ -29,10 +29,7 @@ func main() { // Get all commits on the main branch commits, err := getCommits() if err != nil { - _, err := fmt.Fprintf(os.Stderr, "Error getting commits: %v\n", err) - if err != nil { - return - } + _, _ = fmt.Fprintf(os.Stderr, "Error getting commits: %v\n", err) os.Exit(1) } @@ -61,6 +58,7 @@ func main() { defer writer.Flush() err = writer.Write([]string{"date", "total", "ts", "go", "go prod", "go test", "ts prod", "ts test", "docs", "other", "comments"}) if err != nil { + _, err = fmt.Fprintf(os.Stderr, "Error writing CSV header: %v\n", err) return } @@ -141,6 +139,7 @@ func main() { comments, }) if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "Error writing CSV row: %v\n", err) return } } diff --git a/scripts/roadmap-burndown.go b/scripts/roadmap-burndown.go index dfdb8ce..05fdc4b 100755 --- a/scripts/roadmap-burndown.go +++ b/scripts/roadmap-burndown.go @@ -32,28 +32,19 @@ type DailyData struct { func main() { // Validate we're in a git repository if err := validateGitRepo(); err != nil { - _, err := fmt.Fprintf(os.Stderr, "Error: %v\n", err) - if err != nil { - return - } + _, _ = fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } // Get all commits where ROADMAP.md changed commits, err := getCommitsForRoadmap() if err != nil { - _, err := fmt.Fprintf(os.Stderr, "Error getting commits: %v\n", err) - if err != nil { - return - } + _, _ = fmt.Fprintf(os.Stderr, "Error getting commits: %v\n", err) os.Exit(1) } if len(commits) == 0 { - _, err := fmt.Fprintf(os.Stderr, "No commits found where ROADMAP.md changed\n") - if err != nil { - return - } + _, _ = fmt.Fprintf(os.Stderr, "No commits found where ROADMAP.md changed\n") os.Exit(1) } @@ -290,10 +281,7 @@ func outputCSV(data []DailyData) { d.Message, } if err := writer.Write(row); err != nil { - _, err := fmt.Fprintf(os.Stderr, "Error writing CSV row: %v\n", err) - if err != nil { - return - } + _, _ = fmt.Fprintf(os.Stderr, "Error writing CSV row: %v\n", err) os.Exit(1) } }