diff --git a/internal/admin/handlers.go b/internal/admin/handlers.go index 958c730..b56fa52 100644 --- a/internal/admin/handlers.go +++ b/internal/admin/handlers.go @@ -45,8 +45,9 @@ func (a *AdminAPI) Handler() http.Handler { var middlewares []middleware.Middleware middlewares = append(middlewares, logger, - admin.NewAdminAccessLogMiddleware(a.logger), + admin.NewAccessLogMiddleware(a.logger), admin.NewHostnameMiddleware(adminApiHost, a.logger), + admin.NewIPRestrictionMiddleware(a.config.AdminAPI.AllowedIPs, a.logger), ) chain := middleware.NewMiddlewareChain(middlewares...) diff --git a/internal/admin/middleware/ip.go b/internal/admin/middleware/ip.go new file mode 100644 index 0000000..c9a40b0 --- /dev/null +++ b/internal/admin/middleware/ip.go @@ -0,0 +1,92 @@ +package admin + +import ( + "net" + "net/http" + "strings" + + "github.com/unkn0wn-root/terraster/internal/middleware" + "go.uber.org/zap" +) + +// IPRestrictionMiddleware validates incoming requests against configured allowed IPs +type IPRestrictionMiddleware struct { + allowedIPs []string + logger *zap.Logger +} + +// NewIPRestrictionMiddleware creates a new middleware for IP-based access control +func NewIPRestrictionMiddleware(allowedIPs []string, logger *zap.Logger) middleware.Middleware { + return &IPRestrictionMiddleware{ + allowedIPs: allowedIPs, + logger: logger, + } +} + +// This middleware provides IP-based access control. +// It validates the client's IP address against a configured list of allowed IPs. +// +// The middleware follows these rules: +// - If no IPs are configured (allowedIPs is empty), all requests are allowed +// - If IPs are configured, only requests from those IPs are allowed +// - Client IP is extracted from X-Forwarded-For header first, then X-Real-IP, finally falling back to RemoteAddr +// +// The function will return an HTTP 403 Forbidden status if the IP is not allowed, +// or HTTP 500 Internal Server Error if the client IP cannot be determined. +func (m *IPRestrictionMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // if no ip configured - assume allow all + if len(m.allowedIPs) == 0 { + next.ServeHTTP(w, r) + return + } + + // extract client IP from request + clientIP := extractIP(r) + if clientIP == "" { + http.Error(w, "Could not verify client IP", http.StatusInternalServerError) + return + } + + // check if client IP is allowed + for _, allowedIP := range m.allowedIPs { + if clientIP == allowedIP { + next.ServeHTTP(w, r) + return + } + } + + // if we get here, the IP is not allowed + m.logger.Warn("Access denied: IP not allowed", + zap.String("client_ip", clientIP), + zap.Strings("allowed_ips", m.allowedIPs), + ) + http.Error(w, "Access denied", http.StatusForbidden) + }) +} + +// extractIP gets the real client IP, taking into account X-Forwarded-For and X-Real-IP headers +func extractIP(r *http.Request) string { + forwardedFor := r.Header.Get("X-Forwarded-For") + if forwardedFor != "" { + // X-Forwarded-For can contain multiple IPs; take the first one + ips := strings.Split(forwardedFor, ",") + if len(ips) > 0 { + return strings.TrimSpace(ips[0]) + } + } + + // check X-Real-IP header if no X-Forwarded-For + realIP := r.Header.Get("X-Real-IP") + if realIP != "" { + return realIP + } + + // fall back to RemoteAddr + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + // if SplitHostPort fails, try using RemoteAddr directly + return r.RemoteAddr + } + return ip +} diff --git a/internal/admin/middleware/log.go b/internal/admin/middleware/log.go index 6baf1ee..fb8b3e1 100644 --- a/internal/admin/middleware/log.go +++ b/internal/admin/middleware/log.go @@ -7,17 +7,17 @@ import ( "go.uber.org/zap" ) -type AdminAccessLogMiddleware struct { +type AccessLogMiddleware struct { logger *zap.Logger } -func NewAdminAccessLogMiddleware(logger *zap.Logger) middleware.Middleware { - return &AdminAccessLogMiddleware{ +func NewAccessLogMiddleware(logger *zap.Logger) middleware.Middleware { + return &AccessLogMiddleware{ logger: logger, } } -func (m *AdminAccessLogMiddleware) Middleware(next http.Handler) http.Handler { +func (m *AccessLogMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m.logger.Info("Request to Admin API", zap.String("method", r.Method), diff --git a/internal/config/api.config.go b/internal/config/api.config.go index dbdb135..24d5475 100644 --- a/internal/config/api.config.go +++ b/internal/config/api.config.go @@ -15,11 +15,12 @@ type APIConfig struct { } type API struct { - Enabled bool `yaml:"enabled"` - Host string `yaml:"host"` - Port int `yaml:"port"` - TLS *TLSConfig `yaml:"tls"` - Insecure bool `yaml:"insecure"` + Enabled bool `yaml:"enabled"` + Host string `yaml:"host"` + Port int `yaml:"port"` + TLS *TLSConfig `yaml:"tls"` + Insecure bool `yaml:"insecure"` + AllowedIPs []string `yaml:"allowed_ips"` } type DatabaseConfig struct {