diff --git a/cmd/server.go b/cmd/server.go index 27ec393..37e7bd0 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -41,5 +41,5 @@ func init() { rootCmd.AddCommand(serverCmd) serverCmd.Flags().StringVarP(&listenPort, "port", "P", "80", "Listen port") - serverCmd.Flags().StringArrayVarP(&allowIP, "allow-ip", "I", []string{}, "ip to allow") + serverCmd.Flags().StringArrayVarP(&allowIP, "allow-ip", "I", []string{}, "IP addresses and CIDR blocks to allow; example: 192.168.0.1 or 0.0.0.0/0, 10.0.0.0/8") } diff --git a/websrc/serve/serve.go b/websrc/serve/serve.go index 8d95e77..2be2b70 100644 --- a/websrc/serve/serve.go +++ b/websrc/serve/serve.go @@ -19,6 +19,7 @@ import ( "fmt" "html/template" "io" + "net" "net/http" "strings" @@ -62,10 +63,22 @@ func (t *TemplateRenderer) Render(w io.Writer, name string, data interface{}, c func TrustedProxiesMiddleware(trustedProxies []string) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - clientIP := c.RealIP() // Echo gets the real IP of the client + clientIP := net.ParseIP(c.RealIP()) // Parse the real IP of the client + + if clientIP == nil { + return echo.NewHTTPError(http.StatusForbidden, "Invalid IP address") + } for _, proxy := range trustedProxies { - if strings.HasPrefix(clientIP, proxy) { + // Append /32 if no subnet mask is specified + if !strings.Contains(proxy, "/") { + proxy += "/32" + } + _, cidr, err := net.ParseCIDR(proxy) + if err != nil { + continue + } + if cidr.Contains(clientIP) { // Request is from a trusted proxy return next(c) }