Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions internal/http/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,27 @@ func NewRouter(wikiInstance *wiki.Wiki, options RouterOptions) *gin.Engine {
}

router := gin.Default()
router.StaticFS("/assets", gin.Dir(wikiInstance.GetAssetService().GetAssetsDir(), true))

authCookies := auth_middleware.NewAuthCookies(options.AllowInsecure, options.AccessTokenTimeout, options.RefreshTokenTimeout)
csrfCookie := security.NewCSRFCookie(options.AllowInsecure, 3*24*time.Hour)

loginRateLimiter := security.NewRateLimiter(10, 5*time.Minute, true) // limit to 10 login attempts per 5 minutes per IP - reset on success
refreshRateLimiter := security.NewRateLimiter(30, time.Minute, false) // limit to 30 refresh attempts per minute per IP - do not reset on success

assetsFS := gin.Dir(wikiInstance.GetAssetService().GetAssetsDir(), false) // false = no directory listing

if options.PublicAccess || options.AuthDisabled {
// public read access or auth disabled -> assets are publicly accessible
router.StaticFS("/assets", assetsFS)
} else {
// private mode -> assets only accessible with authentication
assetsGroup := router.Group("/assets")
assetsGroup.Use(
auth_middleware.InjectPublicEditor(options.AuthDisabled),
auth_middleware.RequireAuth(wikiInstance, authCookies, options.AuthDisabled),
)
assetsGroup.StaticFS("/", assetsFS)
}

nonAuthApiGroup := router.Group("/api")
{
// Auth
Expand Down Expand Up @@ -149,25 +162,25 @@ func NewRouter(wikiInstance *wiki.Wiki, options RouterOptions) *gin.Engine {
// Serve branding assets (logos, favicons) with extension validation
router.GET("/branding/:filename", func(c *gin.Context) {
filename := c.Param("filename")

// Sanitize filename to prevent directory traversal and malicious input
// Only allow simple filenames (no path separators, no null bytes, no ..)
if strings.Contains(filename, "..") ||
strings.Contains(filename, "/") ||
strings.Contains(filename, "\\") ||
if strings.Contains(filename, "..") ||
strings.Contains(filename, "/") ||
strings.Contains(filename, "\\") ||
strings.Contains(filename, "\x00") {
c.Status(http.StatusForbidden)
return
}

// Get allowed extensions from branding constraints
constraints, err := wikiInstance.GetBrandingConstraints()
if err != nil {
log.Printf("Failed to get branding constraints: %v", err)
c.Status(http.StatusInternalServerError)
return
}

// Build a combined set of allowed extensions for O(1) lookup
allowedExts := make(map[string]bool)
for _, ext := range constraints.LogoExts {
Expand All @@ -176,30 +189,30 @@ func NewRouter(wikiInstance *wiki.Wiki, options RouterOptions) *gin.Engine {
for _, ext := range constraints.FaviconExts {
allowedExts[ext] = true
}

// Validate file extension against whitelist
ext := strings.ToLower(filepath.Ext(filename))
if !allowedExts[ext] {
c.Status(http.StatusForbidden)
return
}

// Construct file path
brandingDir := wikiInstance.GetBrandingService().GetBrandingAssetsDir()
filePath := filepath.Join(brandingDir, filename)

// Clean the path and verify it's within the branding directory
cleanPath := filepath.Clean(filePath)
cleanBrandingDir := filepath.Clean(brandingDir)

// Ensure the resolved path is still within the branding directory
// Use filepath.Rel to check the relative path doesn't escape the directory
rel, err := filepath.Rel(cleanBrandingDir, cleanPath)
if err != nil || strings.HasPrefix(rel, "..") {
c.Status(http.StatusForbidden)
return
}

// Check if file exists
if _, err := os.Stat(cleanPath); os.IsNotExist(err) {
c.Status(http.StatusNotFound)
Expand All @@ -209,7 +222,7 @@ func NewRouter(wikiInstance *wiki.Wiki, options RouterOptions) *gin.Engine {
c.Status(http.StatusInternalServerError)
return
}

// Serve the file
c.File(cleanPath)
})
Expand Down
233 changes: 233 additions & 0 deletions internal/http/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1453,3 +1453,236 @@ func TestIndexingStatusEndpoint(t *testing.T) {
t.Errorf("Expected 'active' field in response, got: %v", status)
}
}

// uploadTestAsset is a helper function that creates a page, uploads an asset, and returns the asset URL and auth cookies.
// If needsAuth is true, it will obtain authentication cookies; otherwise it will get CSRF token only (for AuthDisabled mode).
func uploadTestAsset(t *testing.T, router *gin.Engine, w *wiki.Wiki, content string, needsAuth bool) (assetURL string, cookies []*http.Cookie) {
// Create a page
page, err := w.CreatePage("system", nil, "Test Page", "test-page")
if err != nil {
t.Fatalf("Failed to create page: %v", err)
}

// Prepare the file upload
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("file", "test.txt")
if err != nil {
t.Fatalf("Failed to create form file: %v", err)
}
if _, err := part.Write([]byte(content)); err != nil {
t.Fatalf("Failed to write file: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("Failed to close multipart writer: %v", err)
}

var csrfToken string

if needsAuth {
// Login to get auth cookies
loginBody := `{"identifier": "admin", "password": "admin"}`
loginReq := httptest.NewRequest(http.MethodPost, "/api/auth/login", strings.NewReader(loginBody))
loginReq.Header.Set("Content-Type", "application/json")
loginRec := httptest.NewRecorder()
router.ServeHTTP(loginRec, loginReq)

if loginRec.Code != http.StatusOK {
t.Fatalf("Expected 200 OK on login, got %d", loginRec.Code)
}

cookies = loginRec.Result().Cookies()
csrfToken = loginRec.Header().Get("X-CSRF-Token")
if csrfToken == "" {
for _, c := range cookies {
if c.Name == "leafwiki_csrf" || c.Name == "__Host-leafwiki_csrf" {
csrfToken = c.Value
break
}
}
}
} else {
// Get CSRF token only (for AuthDisabled mode)
configReq := httptest.NewRequest(http.MethodGet, "/api/config", nil)
configRec := httptest.NewRecorder()
router.ServeHTTP(configRec, configReq)

cookies = configRec.Result().Cookies()
csrfToken = configRec.Header().Get("X-CSRF-Token")
if csrfToken == "" {
for _, c := range cookies {
if c.Name == "leafwiki_csrf" || c.Name == "__Host-leafwiki_csrf" {
csrfToken = c.Value
break
}
}
}
}

// Upload the asset
uploadReq := httptest.NewRequest(http.MethodPost, "/api/pages/"+page.ID+"/assets", body)
uploadReq.Header.Set("Content-Type", writer.FormDataContentType())
for _, cookie := range cookies {
uploadReq.AddCookie(cookie)
}
uploadReq.Header.Set("X-CSRF-Token", csrfToken)

uploadRec := httptest.NewRecorder()
router.ServeHTTP(uploadRec, uploadReq)

if uploadRec.Code != http.StatusCreated {
t.Fatalf("Expected 201 Created on upload, got %d - %s", uploadRec.Code, uploadRec.Body.String())
}

var uploadResp map[string]string
if err := json.Unmarshal(uploadRec.Body.Bytes(), &uploadResp); err != nil {
t.Fatalf("Invalid upload JSON: %v", err)
}

assetURL = uploadResp["file"]
if assetURL == "" {
t.Fatal("Expected file URL in upload response")
}

return assetURL, cookies
}

// TestAssetAccessControl tests the access control for static asset routes
func TestAssetAccessControl(t *testing.T) {
t.Run("PrivateMode_UnauthenticatedAccess_Returns401", func(t *testing.T) {
w := createWikiTestInstance(t)
defer w.Close()

// Create router with PublicAccess=false and AuthDisabled=false
router := NewRouter(w, RouterOptions{
PublicAccess: false,
InjectCodeInHeader: "",
AllowInsecure: true,
AccessTokenTimeout: 15 * time.Minute,
RefreshTokenTimeout: 7 * 24 * time.Hour,
HideLinkMetadataSection: false,
AuthDisabled: false,
})

// Upload an asset (with auth)
assetURL, _ := uploadTestAsset(t, router, w, "test content", true)

// Try to access the asset without authentication
assetReq := httptest.NewRequest(http.MethodGet, assetURL, nil)
assetRec := httptest.NewRecorder()
router.ServeHTTP(assetRec, assetReq)

// Should return 401 Unauthorized
if assetRec.Code != http.StatusUnauthorized {
t.Errorf("Expected 401 Unauthorized when accessing asset without auth in private mode, got %d", assetRec.Code)
}
})

t.Run("PrivateMode_AuthenticatedAccess_Returns200", func(t *testing.T) {
w := createWikiTestInstance(t)
defer w.Close()

// Create router with PublicAccess=false and AuthDisabled=false
router := NewRouter(w, RouterOptions{
PublicAccess: false,
InjectCodeInHeader: "",
AllowInsecure: true,
AccessTokenTimeout: 15 * time.Minute,
RefreshTokenTimeout: 7 * 24 * time.Hour,
HideLinkMetadataSection: false,
AuthDisabled: false,
})

// Upload an asset (with auth) and get cookies
assetURL, cookies := uploadTestAsset(t, router, w, "test content", true)

// Access the asset with authentication
assetReq := httptest.NewRequest(http.MethodGet, assetURL, nil)
for _, cookie := range cookies {
assetReq.AddCookie(cookie)
}
assetRec := httptest.NewRecorder()
router.ServeHTTP(assetRec, assetReq)

// Should return 200 OK
if assetRec.Code != http.StatusOK {
t.Errorf("Expected 200 OK when accessing asset with auth in private mode, got %d", assetRec.Code)
}

// Verify content
content := assetRec.Body.String()
if content != "test content" {
t.Errorf("Expected 'test content', got '%s'", content)
}
})

t.Run("PublicAccessMode_UnauthenticatedAccess_Returns200", func(t *testing.T) {
w := createWikiTestInstance(t)
defer w.Close()

// Create router with PublicAccess=true
router := NewRouter(w, RouterOptions{
PublicAccess: true,
InjectCodeInHeader: "",
AllowInsecure: true,
AccessTokenTimeout: 15 * time.Minute,
RefreshTokenTimeout: 7 * 24 * time.Hour,
HideLinkMetadataSection: false,
AuthDisabled: false,
})

// Upload an asset (with auth)
assetURL, _ := uploadTestAsset(t, router, w, "test content public", true)

// Try to access the asset without authentication
assetReq := httptest.NewRequest(http.MethodGet, assetURL, nil)
assetRec := httptest.NewRecorder()
router.ServeHTTP(assetRec, assetReq)

// Should return 200 OK in public mode
if assetRec.Code != http.StatusOK {
t.Errorf("Expected 200 OK when accessing asset without auth in public mode, got %d", assetRec.Code)
}

// Verify content
content := assetRec.Body.String()
if content != "test content public" {
t.Errorf("Expected 'test content public', got '%s'", content)
}
})

t.Run("AuthDisabledMode_UnauthenticatedAccess_Returns200", func(t *testing.T) {
w := createWikiTestInstance(t)
defer w.Close()

// Create router with AuthDisabled=true
router := NewRouter(w, RouterOptions{
PublicAccess: false,
InjectCodeInHeader: "",
AllowInsecure: true,
AccessTokenTimeout: 15 * time.Minute,
RefreshTokenTimeout: 7 * 24 * time.Hour,
HideLinkMetadataSection: false,
AuthDisabled: true,
})

// Upload an asset (no auth needed, but CSRF token still required)
assetURL, _ := uploadTestAsset(t, router, w, "test content no auth", false)

// Try to access the asset without authentication
assetReq := httptest.NewRequest(http.MethodGet, assetURL, nil)
assetRec := httptest.NewRecorder()
router.ServeHTTP(assetRec, assetReq)

// Should return 200 OK when auth is disabled
if assetRec.Code != http.StatusOK {
t.Errorf("Expected 200 OK when accessing asset without auth when AuthDisabled=true, got %d", assetRec.Code)
}

// Verify content
content := assetRec.Body.String()
if content != "test content no auth" {
t.Errorf("Expected 'test content no auth', got '%s'", content)
}
})
}
Loading