From bc90f38d402aa2b8e25d196d1af80ff19249e6ed Mon Sep 17 00:00:00 2001 From: Artem Streltsov Date: Sun, 15 Sep 2024 14:29:41 +0200 Subject: [PATCH] add rate limiting and loggin middleware --- internal/handlers/handlers.go | 101 +++++++++++++++++++++++++- internal/safebrowsing/safebrowsing.go | 2 +- internal/utils/utils.go | 2 +- main.go | 3 +- 4 files changed, 103 insertions(+), 5 deletions(-) diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 4a72a5e..3315c17 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -5,11 +5,14 @@ import ( "encoding/base64" "html/template" "log" + "net" "net/http" "os" "path/filepath" "strconv" "strings" + "sync" + "time" "github.com/artemstreltsov/url-shortener/internal/database" "github.com/artemstreltsov/url-shortener/internal/safebrowsing" @@ -24,6 +27,23 @@ type Handler struct { store *sessions.CookieStore } +type statusRecorder struct { + http.ResponseWriter + statusCode int +} + +type RateLimiter struct { + visitors map[string]*visitor + mu sync.Mutex + limit int + window time.Duration +} + +type visitor struct { + lastSeen time.Time + tokens int +} + func NewHandler(db *database.DB) *Handler { templatesDir := "./internal/templates" templates := template.Must(template.ParseGlob(filepath.Join(templatesDir, "*.html"))) @@ -54,7 +74,86 @@ func (h *Handler) Routes() http.Handler { mux.HandleFunc("/dashboard", h.dashboardHandler) mux.HandleFunc("/edit/", h.editURLHandler) mux.HandleFunc("/delete/", h.deleteURLHandler) - return mux + + rl := NewRateLimiter(100, time.Minute) + return LoggingMiddleware(RateLimitingMiddleware(rl)(mux)) +} + +func (rec *statusRecorder) WriteHeader(code int) { + rec.statusCode = code + rec.ResponseWriter.WriteHeader(code) +} + +func LoggingMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("Request: %s %s", r.Method, r.URL.Path) + + rec := &statusRecorder{ResponseWriter: w, statusCode: http.StatusOK} + next.ServeHTTP(rec, r) + + log.Printf("Response: %s %s %d", r.Method, r.URL.Path, rec.statusCode) + }) +} + +func NewRateLimiter(limit int, window time.Duration) *RateLimiter { + rl := &RateLimiter{ + visitors: make(map[string]*visitor), + limit: limit, + window: window, + } + go rl.cleanupVisitors() + return rl +} + +func (rl *RateLimiter) Allow(ip string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + v, exists := rl.visitors[ip] + if !exists || time.Since(v.lastSeen) > rl.window { + rl.visitors[ip] = &visitor{lastSeen: time.Now(), tokens: rl.limit - 1} + return true + } + + if v.tokens > 0 { + v.tokens-- + v.lastSeen = time.Now() + return true + } + + return false +} + +func (rl *RateLimiter) cleanupVisitors() { + for { + time.Sleep(time.Minute) + rl.mu.Lock() + for ip, v := range rl.visitors { + if time.Since(v.lastSeen) > rl.window { + delete(rl.visitors, ip) + } + } + rl.mu.Unlock() + } +} + +func RateLimitingMiddleware(rl *RateLimiter) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + http.Error(w, "Unable to determine IP", http.StatusInternalServerError) + return + } + + if !rl.Allow(ip) { + http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) + return + } + + next.ServeHTTP(w, r) + }) + } } func (h *Handler) indexHandler(w http.ResponseWriter, r *http.Request) { diff --git a/internal/safebrowsing/safebrowsing.go b/internal/safebrowsing/safebrowsing.go index d604e6f..5196e3e 100644 --- a/internal/safebrowsing/safebrowsing.go +++ b/internal/safebrowsing/safebrowsing.go @@ -48,4 +48,4 @@ func Close() { if sb != nil { sb.Close() } -} \ No newline at end of file +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index f0d9d82..3aa7145 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -23,4 +23,4 @@ func GenerateKey(url string) string { func IsValidURL(urlStr string) bool { u, err := url.Parse(urlStr) return err == nil && u.Scheme != "" && u.Host != "" -} \ No newline at end of file +} diff --git a/main.go b/main.go index 9458882..2855fa8 100644 --- a/main.go +++ b/main.go @@ -22,7 +22,7 @@ func init() { } func main() { - godotenv.Load() // Load .env file if it exists, ignore error if it doesn't + godotenv.Load() port := getEnvWithDefault("PORT", "8080") @@ -33,7 +33,6 @@ func main() { dbPath = "database/database.sqlite3" } - // Touch the database file if it doesn't exist if _, err := os.Stat(dbPath); os.IsNotExist(err) { dir := filepath.Dir(dbPath) if err := os.MkdirAll(dir, 0755); err != nil {