Skip to content

Commit

Permalink
add rate limiting and loggin middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
artem-streltsov committed Sep 15, 2024
1 parent 61b12f0 commit bc90f38
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 5 deletions.
101 changes: 100 additions & 1 deletion internal/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ import (
"encoding/base64"
"html/template"
"log"
"net"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"

"github.com/artemstreltsov/url-shortener/internal/database"
"github.com/artemstreltsov/url-shortener/internal/safebrowsing"
Expand All @@ -24,6 +27,23 @@ type Handler struct {
store *sessions.CookieStore
}

type statusRecorder struct {
http.ResponseWriter
statusCode int
}

type RateLimiter struct {
visitors map[string]*visitor
mu sync.Mutex
limit int
window time.Duration
}

type visitor struct {
lastSeen time.Time
tokens int
}

func NewHandler(db *database.DB) *Handler {
templatesDir := "./internal/templates"
templates := template.Must(template.ParseGlob(filepath.Join(templatesDir, "*.html")))
Expand Down Expand Up @@ -54,7 +74,86 @@ func (h *Handler) Routes() http.Handler {
mux.HandleFunc("/dashboard", h.dashboardHandler)
mux.HandleFunc("/edit/", h.editURLHandler)
mux.HandleFunc("/delete/", h.deleteURLHandler)
return mux

rl := NewRateLimiter(100, time.Minute)
return LoggingMiddleware(RateLimitingMiddleware(rl)(mux))
}

func (rec *statusRecorder) WriteHeader(code int) {
rec.statusCode = code
rec.ResponseWriter.WriteHeader(code)
}

func LoggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("Request: %s %s", r.Method, r.URL.Path)

rec := &statusRecorder{ResponseWriter: w, statusCode: http.StatusOK}
next.ServeHTTP(rec, r)

log.Printf("Response: %s %s %d", r.Method, r.URL.Path, rec.statusCode)
})
}

func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
rl := &RateLimiter{
visitors: make(map[string]*visitor),
limit: limit,
window: window,
}
go rl.cleanupVisitors()
return rl
}

func (rl *RateLimiter) Allow(ip string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()

v, exists := rl.visitors[ip]
if !exists || time.Since(v.lastSeen) > rl.window {
rl.visitors[ip] = &visitor{lastSeen: time.Now(), tokens: rl.limit - 1}
return true
}

if v.tokens > 0 {
v.tokens--
v.lastSeen = time.Now()
return true
}

return false
}

func (rl *RateLimiter) cleanupVisitors() {
for {
time.Sleep(time.Minute)
rl.mu.Lock()
for ip, v := range rl.visitors {
if time.Since(v.lastSeen) > rl.window {
delete(rl.visitors, ip)
}
}
rl.mu.Unlock()
}
}

func RateLimitingMiddleware(rl *RateLimiter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(w, "Unable to determine IP", http.StatusInternalServerError)
return
}

if !rl.Allow(ip) {
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}

next.ServeHTTP(w, r)
})
}
}

func (h *Handler) indexHandler(w http.ResponseWriter, r *http.Request) {
Expand Down
2 changes: 1 addition & 1 deletion internal/safebrowsing/safebrowsing.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ func Close() {
if sb != nil {
sb.Close()
}
}
}
2 changes: 1 addition & 1 deletion internal/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ func GenerateKey(url string) string {
func IsValidURL(urlStr string) bool {
u, err := url.Parse(urlStr)
return err == nil && u.Scheme != "" && u.Host != ""
}
}
3 changes: 1 addition & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func init() {
}

func main() {
godotenv.Load() // Load .env file if it exists, ignore error if it doesn't
godotenv.Load()

port := getEnvWithDefault("PORT", "8080")

Expand All @@ -33,7 +33,6 @@ func main() {
dbPath = "database/database.sqlite3"
}

// Touch the database file if it doesn't exist
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
dir := filepath.Dir(dbPath)
if err := os.MkdirAll(dir, 0755); err != nil {
Expand Down

0 comments on commit bc90f38

Please sign in to comment.