diff --git a/pkg/core/limiter/limiter.go b/pkg/core/limiter/limiter.go index c23de47..c7522e6 100644 --- a/pkg/core/limiter/limiter.go +++ b/pkg/core/limiter/limiter.go @@ -2,55 +2,99 @@ package limiter import ( "sync" + "time" "golang.org/x/time/rate" ) -// IPRateLimiter . +type limiterEntry struct { + limiter *rate.Limiter + lastSeen time.Time +} + +// IPRateLimiter handles rate limiting by IP address type IPRateLimiter struct { - ips map[string]*rate.Limiter - mu *sync.RWMutex - r rate.Limit - b int + ips map[string]*limiterEntry + mu *sync.RWMutex + r rate.Limit + b int + maxAge time.Duration } -// NewIPRateLimiter . +// NewIPRateLimiter creates a new IP rate limiter with cleanup routine func NewIPRateLimiter(r rate.Limit, b int) *IPRateLimiter { i := &IPRateLimiter{ - ips: make(map[string]*rate.Limiter), - mu: &sync.RWMutex{}, - r: r, - b: b, + ips: make(map[string]*limiterEntry), + mu: &sync.RWMutex{}, + r: r, + b: b, + maxAge: time.Hour, // Cleanup entries older than 1 hour } + // Start cleanup routine + go i.cleanupRoutine() + return i } -// AddIP creates a new rate limiter and adds it to the ips map, -// using the IP address as the key +// AddIP creates a new rate limiter and adds it to the ips map func (i *IPRateLimiter) AddIP(ip string) *rate.Limiter { i.mu.Lock() defer i.mu.Unlock() limiter := rate.NewLimiter(i.r, i.b) - - i.ips[ip] = limiter + entry := &limiterEntry{ + limiter: limiter, + lastSeen: time.Now(), + } + i.ips[ip] = entry return limiter } -// GetLimiter returns the rate limiter for the provided IP address if it exists. -// Otherwise calls AddIP to add IP address to the map +// GetLimiter returns the rate limiter for the provided IP address func (i *IPRateLimiter) GetLimiter(ip string) *rate.Limiter { - i.mu.Lock() - limiter, exists := i.ips[ip] + i.mu.RLock() + entry, exists := i.ips[ip] + i.mu.RUnlock() if !exists { - i.mu.Unlock() return i.AddIP(ip) } + // Update last seen time + i.mu.Lock() + entry.lastSeen = time.Now() i.mu.Unlock() - return limiter + return entry.limiter +} + +// cleanupRoutine removes old entries periodically +func (i *IPRateLimiter) cleanupRoutine() { + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + i.cleanup() + } +} + +// cleanup removes entries that haven't been used recently +func (i *IPRateLimiter) cleanup() { + i.mu.Lock() + defer i.mu.Unlock() + + for ip, entry := range i.ips { + if time.Since(entry.lastSeen) > i.maxAge { + delete(i.ips, ip) + } + } +} + +// GetIPCount returns the current number of IP limiters +func (i *IPRateLimiter) GetIPCount() int { + i.mu.RLock() + defer i.mu.RUnlock() + return len(i.ips) } \ No newline at end of file diff --git a/pkg/core/middleware/limiterMidleware.go b/pkg/core/middleware/limiterMidleware.go index b22f4cd..ba978f2 100644 --- a/pkg/core/middleware/limiterMidleware.go +++ b/pkg/core/middleware/limiterMidleware.go @@ -1,8 +1,13 @@ +// /Users/sookh.com/go/pkg/mod/github.com/!jubaer!hossain/rootx@v1.4.7/pkg/core/middleware/limiterMidleware.go package middleware import ( + "fmt" + "log" + "net" "net/http" "os" + "strings" "time" "github.com/JubaerHossain/rootx/pkg/core/config" @@ -11,23 +16,163 @@ import ( "golang.org/x/time/rate" ) -func LimiterMiddleware(next http.Handler) http.Handler { - is_limit_enabled := config.GlobalConfig.RateLimitEnabled - if !is_limit_enabled { - return next +// RateLimitConfig holds the configuration for rate limiting +type RateLimitConfig struct { + Enabled bool + Limit int + Duration time.Duration + WhitelistIPs []string + BlacklistIPs []string +} + +var ( + whitelistedIPs = make(map[string]bool) + blacklistedIPs = make(map[string]bool) +) + +// isIPWhitelisted checks if an IP is in the whitelist +func isIPWhitelisted(ip string) bool { + return whitelistedIPs[ip] +} + +// isIPBlacklisted checks if an IP is in the blacklist +func isIPBlacklisted(ip string) bool { + return blacklistedIPs[ip] +} + +// getClientIP extracts the real client IP considering various headers +func getClientIP(r *http.Request) string { + // Check CF-Connecting-IP (Cloudflare) + if cfIP := r.Header.Get("CF-Connecting-IP"); cfIP != "" { + return cfIP } - limit := config.GlobalConfig.RateLimit - duration, err := time.ParseDuration(os.Getenv("RATE_LIMIT_DURATION")) + + // Check X-Real-IP + if realIP := r.Header.Get("X-Real-IP"); realIP != "" { + return realIP + } + + // Check X-Forwarded-For + if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" { + // X-Forwarded-For can contain multiple IPs, take the first one + ips := strings.Split(forwardedFor, ",") + if len(ips) > 0 { + return strings.TrimSpace(ips[0]) + } + } + + // Fall back to RemoteAddr + ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { - duration = time.Second * 2 + return r.RemoteAddr + } + return ip +} + +// LimiterMiddleware creates a new rate limiting middleware +func LimiterMiddleware(next http.Handler) http.Handler { + // Initialize configuration + config := loadRateLimitConfig() + if !config.Enabled { + log.Println("Rate limiting is disabled.") + return next } - var limiter = limiter.NewIPRateLimiter(rate.Every(duration), limit) + + // Initialize IP lists + initializeIPLists(config) + + var rateLimiter = limiter.NewIPRateLimiter(rate.Every(config.Duration), config.Limit) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - limiter := limiter.GetLimiter(r.RemoteAddr) + clientIP := getClientIP(r) + + // Check whitelist + if isIPWhitelisted(clientIP) { + next.ServeHTTP(w, r) + return + } + + // Check blacklist + if isIPBlacklisted(clientIP) { + utils.WriteJSONError(w, http.StatusForbidden, "Access denied") + return + } + + // Get limiter for this IP + limiter := rateLimiter.GetLimiter(clientIP) + + // Set standard rate limit headers + setRateLimitHeaders(w, limiter, config) + + // Check if rate limit is exceeded if !limiter.Allow() { - utils.WriteJSONError(w, http.StatusTooManyRequests, "Too many requests") + handleRateLimitExceeded(w, r, clientIP, config) return } + + // Add custom headers for debugging/monitoring + w.Header().Set("X-Rate-Limit-IP", clientIP) + w.Header().Set("X-Rate-Limit-Active-IPs", fmt.Sprintf("%d", rateLimiter.GetIPCount())) + next.ServeHTTP(w, r) }) } + +func loadRateLimitConfig() RateLimitConfig { + duration, err := time.ParseDuration(os.Getenv("RATE_LIMIT_DURATION")) + if err != nil { + duration = time.Minute + } + + limit := config.GlobalConfig.RateLimit + if limit <= 0 { + limit = 100 + } + + return RateLimitConfig{ + Enabled: config.GlobalConfig.RateLimitEnabled, + Limit: limit, + Duration: duration, + WhitelistIPs: strings.Split(os.Getenv("RATE_LIMIT_WHITELIST"), ","), + BlacklistIPs: strings.Split(os.Getenv("RATE_LIMIT_BLACKLIST"), ","), + } +} + +func initializeIPLists(config RateLimitConfig) { + // Initialize whitelist + for _, ip := range config.WhitelistIPs { + if ip = strings.TrimSpace(ip); ip != "" { + whitelistedIPs[ip] = true + } + } + + // Initialize blacklist + for _, ip := range config.BlacklistIPs { + if ip = strings.TrimSpace(ip); ip != "" { + blacklistedIPs[ip] = true + } + } +} + +func setRateLimitHeaders(w http.ResponseWriter, limiter *rate.Limiter, config RateLimitConfig) { + tokens := limiter.Tokens() + w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", config.Limit)) + w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%.0f", tokens)) + w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Add(config.Duration).Unix())) +} + +func handleRateLimitExceeded(w http.ResponseWriter, r *http.Request, clientIP string, config RateLimitConfig) { + log.Printf("Rate limit exceeded for IP: %s, UA: %s, Path: %s", + clientIP, + r.UserAgent(), + r.URL.Path, + ) + + w.Header().Set("Retry-After", fmt.Sprintf("%d", config.Duration/time.Second)) + utils.WriteJSONError(w, + http.StatusTooManyRequests, + fmt.Sprintf("Rate limit exceeded. Please try again in %d seconds", + config.Duration/time.Second, + ), + ) +} diff --git a/pkg/create/create.go b/pkg/create/create.go index 02388fe..5a3a891 100644 --- a/pkg/create/create.go +++ b/pkg/create/create.go @@ -510,7 +510,7 @@ func ApplyMigrations(cmd *cobra.Command, args []string) error { migrationsDir := "migrations" if err := executeScriptsInDirectory(pool, migrationsDir); err != nil { - return fmt.Errorf("failed to execute migration scripts") + return fmt.Errorf("failed to execute migration scripts %w", err) } bar := CreateProgressBar("migration: ") diff --git a/template/route.stub b/template/route.stub index 03a13e0..244ad12 100644 --- a/template/route.stub +++ b/template/route.stub @@ -14,7 +14,7 @@ func {{SingularCapitalName}}Router(router *http.ServeMux, application *app.App) router.Handle("GET /{{PluralLowerName}}", middleware.LimiterMiddleware(http.HandlerFunc(handler.Get{{PluralCapitalName}}))) router.Handle("POST /{{PluralLowerName}}", middleware.LimiterMiddleware(http.HandlerFunc(handler.Create{{SingularCapitalName}}))) - router.Handle("GET /{{PluralLowerName}}/{id}", middleware.LimiterMiddleware(http.HandlerFunc(handler.Get{{SingularCapitalName}}Details))) + router.Handle("GET /{{PluralLowerName}}/{id}/details", middleware.LimiterMiddleware(http.HandlerFunc(handler.Get{{SingularCapitalName}}Details))) router.Handle("PUT /{{PluralLowerName}}/{id}", middleware.LimiterMiddleware(http.HandlerFunc(handler.Update{{SingularCapitalName}}))) router.Handle("DELETE /{{PluralLowerName}}/{id}", middleware.LimiterMiddleware(http.HandlerFunc(handler.Delete{{SingularCapitalName}})))