Skip to content

Commit

Permalink
rate limier updated
Browse files Browse the repository at this point in the history
  • Loading branch information
JubaerHossain committed Nov 27, 2024
1 parent 4491c37 commit 6ef4326
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 32 deletions.
84 changes: 64 additions & 20 deletions pkg/core/limiter/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,99 @@ package limiter

import (
"sync"
"time"

"golang.org/x/time/rate"
)

// IPRateLimiter .
type limiterEntry struct {
limiter *rate.Limiter
lastSeen time.Time
}

// IPRateLimiter handles rate limiting by IP address
type IPRateLimiter struct {
ips map[string]*rate.Limiter
mu *sync.RWMutex
r rate.Limit
b int
ips map[string]*limiterEntry
mu *sync.RWMutex
r rate.Limit
b int
maxAge time.Duration
}

// NewIPRateLimiter .
// NewIPRateLimiter creates a new IP rate limiter with cleanup routine
func NewIPRateLimiter(r rate.Limit, b int) *IPRateLimiter {
i := &IPRateLimiter{
ips: make(map[string]*rate.Limiter),
mu: &sync.RWMutex{},
r: r,
b: b,
ips: make(map[string]*limiterEntry),
mu: &sync.RWMutex{},
r: r,
b: b,
maxAge: time.Hour, // Cleanup entries older than 1 hour
}

// Start cleanup routine
go i.cleanupRoutine()

return i
}

// AddIP creates a new rate limiter and adds it to the ips map,
// using the IP address as the key
// AddIP creates a new rate limiter and adds it to the ips map
func (i *IPRateLimiter) AddIP(ip string) *rate.Limiter {
i.mu.Lock()
defer i.mu.Unlock()

limiter := rate.NewLimiter(i.r, i.b)

i.ips[ip] = limiter
entry := &limiterEntry{
limiter: limiter,
lastSeen: time.Now(),
}
i.ips[ip] = entry

return limiter
}

// GetLimiter returns the rate limiter for the provided IP address if it exists.
// Otherwise calls AddIP to add IP address to the map
// GetLimiter returns the rate limiter for the provided IP address
func (i *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
i.mu.Lock()
limiter, exists := i.ips[ip]
i.mu.RLock()
entry, exists := i.ips[ip]
i.mu.RUnlock()

if !exists {
i.mu.Unlock()
return i.AddIP(ip)
}

// Update last seen time
i.mu.Lock()
entry.lastSeen = time.Now()
i.mu.Unlock()

return limiter
return entry.limiter
}

// cleanupRoutine removes old entries periodically
func (i *IPRateLimiter) cleanupRoutine() {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()

for range ticker.C {
i.cleanup()
}
}

// cleanup removes entries that haven't been used recently
func (i *IPRateLimiter) cleanup() {
i.mu.Lock()
defer i.mu.Unlock()

for ip, entry := range i.ips {
if time.Since(entry.lastSeen) > i.maxAge {
delete(i.ips, ip)
}
}
}

// GetIPCount returns the current number of IP limiters
func (i *IPRateLimiter) GetIPCount() int {
i.mu.RLock()
defer i.mu.RUnlock()
return len(i.ips)
}
165 changes: 155 additions & 10 deletions pkg/core/middleware/limiterMidleware.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
// /Users/sookh.com/go/pkg/mod/github.com/!jubaer!hossain/rootx@v1.4.7/pkg/core/middleware/limiterMidleware.go
package middleware

import (
"fmt"
"log"
"net"
"net/http"
"os"
"strings"
"time"

"github.com/JubaerHossain/rootx/pkg/core/config"
Expand All @@ -11,23 +16,163 @@ import (
"golang.org/x/time/rate"
)

func LimiterMiddleware(next http.Handler) http.Handler {
is_limit_enabled := config.GlobalConfig.RateLimitEnabled
if !is_limit_enabled {
return next
// RateLimitConfig holds the configuration for rate limiting
type RateLimitConfig struct {
Enabled bool
Limit int
Duration time.Duration
WhitelistIPs []string
BlacklistIPs []string
}

var (
whitelistedIPs = make(map[string]bool)
blacklistedIPs = make(map[string]bool)
)

// isIPWhitelisted checks if an IP is in the whitelist
func isIPWhitelisted(ip string) bool {
return whitelistedIPs[ip]
}

// isIPBlacklisted checks if an IP is in the blacklist
func isIPBlacklisted(ip string) bool {
return blacklistedIPs[ip]
}

// getClientIP extracts the real client IP considering various headers
func getClientIP(r *http.Request) string {
// Check CF-Connecting-IP (Cloudflare)
if cfIP := r.Header.Get("CF-Connecting-IP"); cfIP != "" {
return cfIP
}
limit := config.GlobalConfig.RateLimit
duration, err := time.ParseDuration(os.Getenv("RATE_LIMIT_DURATION"))

// Check X-Real-IP
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
return realIP
}

// Check X-Forwarded-For
if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
// X-Forwarded-For can contain multiple IPs, take the first one
ips := strings.Split(forwardedFor, ",")
if len(ips) > 0 {
return strings.TrimSpace(ips[0])
}
}

// Fall back to RemoteAddr
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
duration = time.Second * 2
return r.RemoteAddr
}
return ip
}

// LimiterMiddleware creates a new rate limiting middleware
func LimiterMiddleware(next http.Handler) http.Handler {
// Initialize configuration
config := loadRateLimitConfig()
if !config.Enabled {
log.Println("Rate limiting is disabled.")
return next
}
var limiter = limiter.NewIPRateLimiter(rate.Every(duration), limit)

// Initialize IP lists
initializeIPLists(config)

var rateLimiter = limiter.NewIPRateLimiter(rate.Every(config.Duration), config.Limit)

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
limiter := limiter.GetLimiter(r.RemoteAddr)
clientIP := getClientIP(r)

// Check whitelist
if isIPWhitelisted(clientIP) {
next.ServeHTTP(w, r)
return
}

// Check blacklist
if isIPBlacklisted(clientIP) {
utils.WriteJSONError(w, http.StatusForbidden, "Access denied")
return
}

// Get limiter for this IP
limiter := rateLimiter.GetLimiter(clientIP)

// Set standard rate limit headers
setRateLimitHeaders(w, limiter, config)

// Check if rate limit is exceeded
if !limiter.Allow() {
utils.WriteJSONError(w, http.StatusTooManyRequests, "Too many requests")
handleRateLimitExceeded(w, r, clientIP, config)
return
}

// Add custom headers for debugging/monitoring
w.Header().Set("X-Rate-Limit-IP", clientIP)
w.Header().Set("X-Rate-Limit-Active-IPs", fmt.Sprintf("%d", rateLimiter.GetIPCount()))

next.ServeHTTP(w, r)
})
}

func loadRateLimitConfig() RateLimitConfig {
duration, err := time.ParseDuration(os.Getenv("RATE_LIMIT_DURATION"))
if err != nil {
duration = time.Minute
}

limit := config.GlobalConfig.RateLimit
if limit <= 0 {
limit = 100
}

return RateLimitConfig{
Enabled: config.GlobalConfig.RateLimitEnabled,
Limit: limit,
Duration: duration,
WhitelistIPs: strings.Split(os.Getenv("RATE_LIMIT_WHITELIST"), ","),
BlacklistIPs: strings.Split(os.Getenv("RATE_LIMIT_BLACKLIST"), ","),
}
}

func initializeIPLists(config RateLimitConfig) {
// Initialize whitelist
for _, ip := range config.WhitelistIPs {
if ip = strings.TrimSpace(ip); ip != "" {
whitelistedIPs[ip] = true
}
}

// Initialize blacklist
for _, ip := range config.BlacklistIPs {
if ip = strings.TrimSpace(ip); ip != "" {
blacklistedIPs[ip] = true
}
}
}

func setRateLimitHeaders(w http.ResponseWriter, limiter *rate.Limiter, config RateLimitConfig) {
tokens := limiter.Tokens()
w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", config.Limit))
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%.0f", tokens))
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Add(config.Duration).Unix()))
}

func handleRateLimitExceeded(w http.ResponseWriter, r *http.Request, clientIP string, config RateLimitConfig) {
log.Printf("Rate limit exceeded for IP: %s, UA: %s, Path: %s",
clientIP,
r.UserAgent(),
r.URL.Path,
)

w.Header().Set("Retry-After", fmt.Sprintf("%d", config.Duration/time.Second))
utils.WriteJSONError(w,
http.StatusTooManyRequests,
fmt.Sprintf("Rate limit exceeded. Please try again in %d seconds",
config.Duration/time.Second,
),
)
}
2 changes: 1 addition & 1 deletion pkg/create/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ func ApplyMigrations(cmd *cobra.Command, args []string) error {

migrationsDir := "migrations"
if err := executeScriptsInDirectory(pool, migrationsDir); err != nil {
return fmt.Errorf("failed to execute migration scripts")
return fmt.Errorf("failed to execute migration scripts %w", err)
}

bar := CreateProgressBar("migration: ")
Expand Down
2 changes: 1 addition & 1 deletion template/route.stub
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func {{SingularCapitalName}}Router(router *http.ServeMux, application *app.App)

router.Handle("GET /{{PluralLowerName}}", middleware.LimiterMiddleware(http.HandlerFunc(handler.Get{{PluralCapitalName}})))
router.Handle("POST /{{PluralLowerName}}", middleware.LimiterMiddleware(http.HandlerFunc(handler.Create{{SingularCapitalName}})))
router.Handle("GET /{{PluralLowerName}}/{id}", middleware.LimiterMiddleware(http.HandlerFunc(handler.Get{{SingularCapitalName}}Details)))
router.Handle("GET /{{PluralLowerName}}/{id}/details", middleware.LimiterMiddleware(http.HandlerFunc(handler.Get{{SingularCapitalName}}Details)))
router.Handle("PUT /{{PluralLowerName}}/{id}", middleware.LimiterMiddleware(http.HandlerFunc(handler.Update{{SingularCapitalName}})))
router.Handle("DELETE /{{PluralLowerName}}/{id}", middleware.LimiterMiddleware(http.HandlerFunc(handler.Delete{{SingularCapitalName}})))

Expand Down

0 comments on commit 6ef4326

Please sign in to comment.