diff --git a/internal/http/router.go b/internal/http/router.go index 6eb09bdb..a6571f29 100644 --- a/internal/http/router.go +++ b/internal/http/router.go @@ -50,14 +50,27 @@ func NewRouter(wikiInstance *wiki.Wiki, options RouterOptions) *gin.Engine { } router := gin.Default() - router.StaticFS("/assets", gin.Dir(wikiInstance.GetAssetService().GetAssetsDir(), true)) - authCookies := auth_middleware.NewAuthCookies(options.AllowInsecure, options.AccessTokenTimeout, options.RefreshTokenTimeout) csrfCookie := security.NewCSRFCookie(options.AllowInsecure, 3*24*time.Hour) loginRateLimiter := security.NewRateLimiter(10, 5*time.Minute, true) // limit to 10 login attempts per 5 minutes per IP - reset on success refreshRateLimiter := security.NewRateLimiter(30, time.Minute, false) // limit to 30 refresh attempts per minute per IP - do not reset on success + assetsFS := gin.Dir(wikiInstance.GetAssetService().GetAssetsDir(), false) // false = no directory listing + + if options.PublicAccess || options.AuthDisabled { + // public read access or auth disabled -> assets are publicly accessible + router.StaticFS("/assets", assetsFS) + } else { + // private mode -> assets only accessible with authentication + assetsGroup := router.Group("/assets") + assetsGroup.Use( + auth_middleware.InjectPublicEditor(options.AuthDisabled), + auth_middleware.RequireAuth(wikiInstance, authCookies, options.AuthDisabled), + ) + assetsGroup.StaticFS("/", assetsFS) + } + nonAuthApiGroup := router.Group("/api") { // Auth @@ -149,17 +162,17 @@ func NewRouter(wikiInstance *wiki.Wiki, options RouterOptions) *gin.Engine { // Serve branding assets (logos, favicons) with extension validation router.GET("/branding/:filename", func(c *gin.Context) { filename := c.Param("filename") - + // Sanitize filename to prevent directory traversal and malicious input // Only allow simple filenames (no path separators, no null bytes, no ..) - if strings.Contains(filename, "..") || - strings.Contains(filename, "/") || - strings.Contains(filename, "\\") || + if strings.Contains(filename, "..") || + strings.Contains(filename, "/") || + strings.Contains(filename, "\\") || strings.Contains(filename, "\x00") { c.Status(http.StatusForbidden) return } - + // Get allowed extensions from branding constraints constraints, err := wikiInstance.GetBrandingConstraints() if err != nil { @@ -167,7 +180,7 @@ func NewRouter(wikiInstance *wiki.Wiki, options RouterOptions) *gin.Engine { c.Status(http.StatusInternalServerError) return } - + // Build a combined set of allowed extensions for O(1) lookup allowedExts := make(map[string]bool) for _, ext := range constraints.LogoExts { @@ -176,22 +189,22 @@ func NewRouter(wikiInstance *wiki.Wiki, options RouterOptions) *gin.Engine { for _, ext := range constraints.FaviconExts { allowedExts[ext] = true } - + // Validate file extension against whitelist ext := strings.ToLower(filepath.Ext(filename)) if !allowedExts[ext] { c.Status(http.StatusForbidden) return } - + // Construct file path brandingDir := wikiInstance.GetBrandingService().GetBrandingAssetsDir() filePath := filepath.Join(brandingDir, filename) - + // Clean the path and verify it's within the branding directory cleanPath := filepath.Clean(filePath) cleanBrandingDir := filepath.Clean(brandingDir) - + // Ensure the resolved path is still within the branding directory // Use filepath.Rel to check the relative path doesn't escape the directory rel, err := filepath.Rel(cleanBrandingDir, cleanPath) @@ -199,7 +212,7 @@ func NewRouter(wikiInstance *wiki.Wiki, options RouterOptions) *gin.Engine { c.Status(http.StatusForbidden) return } - + // Check if file exists if _, err := os.Stat(cleanPath); os.IsNotExist(err) { c.Status(http.StatusNotFound) @@ -209,7 +222,7 @@ func NewRouter(wikiInstance *wiki.Wiki, options RouterOptions) *gin.Engine { c.Status(http.StatusInternalServerError) return } - + // Serve the file c.File(cleanPath) }) diff --git a/internal/http/router_test.go b/internal/http/router_test.go index cefcc288..a7812166 100644 --- a/internal/http/router_test.go +++ b/internal/http/router_test.go @@ -1453,3 +1453,236 @@ func TestIndexingStatusEndpoint(t *testing.T) { t.Errorf("Expected 'active' field in response, got: %v", status) } } + +// uploadTestAsset is a helper function that creates a page, uploads an asset, and returns the asset URL and auth cookies. +// If needsAuth is true, it will obtain authentication cookies; otherwise it will get CSRF token only (for AuthDisabled mode). +func uploadTestAsset(t *testing.T, router *gin.Engine, w *wiki.Wiki, content string, needsAuth bool) (assetURL string, cookies []*http.Cookie) { + // Create a page + page, err := w.CreatePage("system", nil, "Test Page", "test-page") + if err != nil { + t.Fatalf("Failed to create page: %v", err) + } + + // Prepare the file upload + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("file", "test.txt") + if err != nil { + t.Fatalf("Failed to create form file: %v", err) + } + if _, err := part.Write([]byte(content)); err != nil { + t.Fatalf("Failed to write file: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("Failed to close multipart writer: %v", err) + } + + var csrfToken string + + if needsAuth { + // Login to get auth cookies + loginBody := `{"identifier": "admin", "password": "admin"}` + loginReq := httptest.NewRequest(http.MethodPost, "/api/auth/login", strings.NewReader(loginBody)) + loginReq.Header.Set("Content-Type", "application/json") + loginRec := httptest.NewRecorder() + router.ServeHTTP(loginRec, loginReq) + + if loginRec.Code != http.StatusOK { + t.Fatalf("Expected 200 OK on login, got %d", loginRec.Code) + } + + cookies = loginRec.Result().Cookies() + csrfToken = loginRec.Header().Get("X-CSRF-Token") + if csrfToken == "" { + for _, c := range cookies { + if c.Name == "leafwiki_csrf" || c.Name == "__Host-leafwiki_csrf" { + csrfToken = c.Value + break + } + } + } + } else { + // Get CSRF token only (for AuthDisabled mode) + configReq := httptest.NewRequest(http.MethodGet, "/api/config", nil) + configRec := httptest.NewRecorder() + router.ServeHTTP(configRec, configReq) + + cookies = configRec.Result().Cookies() + csrfToken = configRec.Header().Get("X-CSRF-Token") + if csrfToken == "" { + for _, c := range cookies { + if c.Name == "leafwiki_csrf" || c.Name == "__Host-leafwiki_csrf" { + csrfToken = c.Value + break + } + } + } + } + + // Upload the asset + uploadReq := httptest.NewRequest(http.MethodPost, "/api/pages/"+page.ID+"/assets", body) + uploadReq.Header.Set("Content-Type", writer.FormDataContentType()) + for _, cookie := range cookies { + uploadReq.AddCookie(cookie) + } + uploadReq.Header.Set("X-CSRF-Token", csrfToken) + + uploadRec := httptest.NewRecorder() + router.ServeHTTP(uploadRec, uploadReq) + + if uploadRec.Code != http.StatusCreated { + t.Fatalf("Expected 201 Created on upload, got %d - %s", uploadRec.Code, uploadRec.Body.String()) + } + + var uploadResp map[string]string + if err := json.Unmarshal(uploadRec.Body.Bytes(), &uploadResp); err != nil { + t.Fatalf("Invalid upload JSON: %v", err) + } + + assetURL = uploadResp["file"] + if assetURL == "" { + t.Fatal("Expected file URL in upload response") + } + + return assetURL, cookies +} + +// TestAssetAccessControl tests the access control for static asset routes +func TestAssetAccessControl(t *testing.T) { + t.Run("PrivateMode_UnauthenticatedAccess_Returns401", func(t *testing.T) { + w := createWikiTestInstance(t) + defer w.Close() + + // Create router with PublicAccess=false and AuthDisabled=false + router := NewRouter(w, RouterOptions{ + PublicAccess: false, + InjectCodeInHeader: "", + AllowInsecure: true, + AccessTokenTimeout: 15 * time.Minute, + RefreshTokenTimeout: 7 * 24 * time.Hour, + HideLinkMetadataSection: false, + AuthDisabled: false, + }) + + // Upload an asset (with auth) + assetURL, _ := uploadTestAsset(t, router, w, "test content", true) + + // Try to access the asset without authentication + assetReq := httptest.NewRequest(http.MethodGet, assetURL, nil) + assetRec := httptest.NewRecorder() + router.ServeHTTP(assetRec, assetReq) + + // Should return 401 Unauthorized + if assetRec.Code != http.StatusUnauthorized { + t.Errorf("Expected 401 Unauthorized when accessing asset without auth in private mode, got %d", assetRec.Code) + } + }) + + t.Run("PrivateMode_AuthenticatedAccess_Returns200", func(t *testing.T) { + w := createWikiTestInstance(t) + defer w.Close() + + // Create router with PublicAccess=false and AuthDisabled=false + router := NewRouter(w, RouterOptions{ + PublicAccess: false, + InjectCodeInHeader: "", + AllowInsecure: true, + AccessTokenTimeout: 15 * time.Minute, + RefreshTokenTimeout: 7 * 24 * time.Hour, + HideLinkMetadataSection: false, + AuthDisabled: false, + }) + + // Upload an asset (with auth) and get cookies + assetURL, cookies := uploadTestAsset(t, router, w, "test content", true) + + // Access the asset with authentication + assetReq := httptest.NewRequest(http.MethodGet, assetURL, nil) + for _, cookie := range cookies { + assetReq.AddCookie(cookie) + } + assetRec := httptest.NewRecorder() + router.ServeHTTP(assetRec, assetReq) + + // Should return 200 OK + if assetRec.Code != http.StatusOK { + t.Errorf("Expected 200 OK when accessing asset with auth in private mode, got %d", assetRec.Code) + } + + // Verify content + content := assetRec.Body.String() + if content != "test content" { + t.Errorf("Expected 'test content', got '%s'", content) + } + }) + + t.Run("PublicAccessMode_UnauthenticatedAccess_Returns200", func(t *testing.T) { + w := createWikiTestInstance(t) + defer w.Close() + + // Create router with PublicAccess=true + router := NewRouter(w, RouterOptions{ + PublicAccess: true, + InjectCodeInHeader: "", + AllowInsecure: true, + AccessTokenTimeout: 15 * time.Minute, + RefreshTokenTimeout: 7 * 24 * time.Hour, + HideLinkMetadataSection: false, + AuthDisabled: false, + }) + + // Upload an asset (with auth) + assetURL, _ := uploadTestAsset(t, router, w, "test content public", true) + + // Try to access the asset without authentication + assetReq := httptest.NewRequest(http.MethodGet, assetURL, nil) + assetRec := httptest.NewRecorder() + router.ServeHTTP(assetRec, assetReq) + + // Should return 200 OK in public mode + if assetRec.Code != http.StatusOK { + t.Errorf("Expected 200 OK when accessing asset without auth in public mode, got %d", assetRec.Code) + } + + // Verify content + content := assetRec.Body.String() + if content != "test content public" { + t.Errorf("Expected 'test content public', got '%s'", content) + } + }) + + t.Run("AuthDisabledMode_UnauthenticatedAccess_Returns200", func(t *testing.T) { + w := createWikiTestInstance(t) + defer w.Close() + + // Create router with AuthDisabled=true + router := NewRouter(w, RouterOptions{ + PublicAccess: false, + InjectCodeInHeader: "", + AllowInsecure: true, + AccessTokenTimeout: 15 * time.Minute, + RefreshTokenTimeout: 7 * 24 * time.Hour, + HideLinkMetadataSection: false, + AuthDisabled: true, + }) + + // Upload an asset (no auth needed, but CSRF token still required) + assetURL, _ := uploadTestAsset(t, router, w, "test content no auth", false) + + // Try to access the asset without authentication + assetReq := httptest.NewRequest(http.MethodGet, assetURL, nil) + assetRec := httptest.NewRecorder() + router.ServeHTTP(assetRec, assetReq) + + // Should return 200 OK when auth is disabled + if assetRec.Code != http.StatusOK { + t.Errorf("Expected 200 OK when accessing asset without auth when AuthDisabled=true, got %d", assetRec.Code) + } + + // Verify content + content := assetRec.Body.String() + if content != "test content no auth" { + t.Errorf("Expected 'test content no auth', got '%s'", content) + } + }) +}