From 073f0f3b3d03333a04ab572d09b099388fad2e19 Mon Sep 17 00:00:00 2001 From: Kirill Sysoev Date: Mon, 15 Jan 2024 21:41:30 +0800 Subject: [PATCH] Adds middleware for detecting client ip --- middleware/http/clientip.go | 78 +++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 middleware/http/clientip.go diff --git a/middleware/http/clientip.go b/middleware/http/clientip.go new file mode 100644 index 0000000..c2062d4 --- /dev/null +++ b/middleware/http/clientip.go @@ -0,0 +1,78 @@ +package http + +import ( + "context" + "net" + "net/http" + "strings" +) + +type ContextKey uint + +const ( + ClientIP ContextKey = iota +) + +type Provider uint8 + +const ( + NotProvided Provider = iota + Cloudflare + CloudFront +) + +func NewClientIPMiddleware(provider Provider) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := getIPFromRequest(provider, r) + + ctx := r.Context() + ctx = context.WithValue(ctx, ClientIP, ip) + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + }) + } +} + +func getIPFromRequest(provider Provider, r *http.Request) string { + var ip string + + switch provider { + case Cloudflare: + ip = r.Header.Get("True-Client-IP") + if ip == "" { + ip = r.Header.Get("CF-Connecting-IP") + } + case CloudFront: + ip = r.Header.Get("CloudFront-Viewer-Address") + if ip != "" { + parts := strings.Split(ip, ":") + if len(parts) > 0 { + ip = parts[0] + } + } + default: + } + + if ip != "" { + return ip + } + + if ip := r.Header.Get("X-Real-Ip"); ip != "" { + return ip + } + + if ip := r.Header.Get("X-Forwarded-For"); ip != "" { + parts := strings.Split(ip, ",") + if len(parts) > 0 { + return parts[0] + } + } + + if ip, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + return ip + } + + return "" +}