Skip to content

Commit

Permalink
add proxy polling restriction
Browse files Browse the repository at this point in the history
Signed-off-by: peekjef72 <jfpik78@gmail.com>
  • Loading branch information
peekjef72 committed Jun 1, 2024
1 parent 83cae52 commit 5302694
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 43 deletions.
27 changes: 21 additions & 6 deletions cmd/proxy/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ func (c *Coordinator) getRequestChannel(fqdn string) chan *http.Request {
return ch
}

func (c *Coordinator) checkRequestChannel(fqdn string) bool {
c.mu.Lock()
defer c.mu.Unlock()
_, ok := c.waiting[fqdn]
return ok
}

func (c *Coordinator) getResponseChannel(id string) chan *http.Response {
c.mu.Lock()
defer c.mu.Unlock()
Expand Down Expand Up @@ -116,7 +123,7 @@ func (c *Coordinator) DoScrape(ctx context.Context, r *http.Request) (*http.Resp
r.Header.Add("Id", id)
select {
case <-ctx.Done():
return nil, fmt.Errorf("Timeout reached for %q: %s", r.URL.String(), ctx.Err())
return nil, fmt.Errorf("timeout reached for %q: %s", r.URL.String(), ctx.Err())
case c.getRequestChannel(r.URL.Hostname()) <- r:
}

Expand Down Expand Up @@ -189,15 +196,23 @@ func (c *Coordinator) addKnownClient(fqdn string) {
}

// KnownClients returns a list of alive clients
func (c *Coordinator) KnownClients() []string {
func (c *Coordinator) KnownClients(client string) []string {
c.mu.Lock()
defer c.mu.Unlock()

var known []string
limit := time.Now().Add(-*registrationTimeout)
known := make([]string, 0, len(c.known))
for k, t := range c.known {
if limit.Before(t) {
known = append(known, k)
if client != "" {
known = make([]string, 0, 1)
if t, ok := c.known[client]; ok && limit.Before(t) {
known = append(known, client)
}
} else {
known = make([]string, 0, len(c.known))
for k, t := range c.known {
if limit.Before(t) {
known = append(known, k)
}
}
}
return known
Expand Down
239 changes: 202 additions & 37 deletions cmd/proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"regexp"
"strings"

"github.com/alecthomas/kingpin/v2"
Expand All @@ -43,6 +45,7 @@ var (
listenAddress = kingpin.Flag("web.listen-address", "Address to listen on for proxy and client requests.").Default(":8080").String()
maxScrapeTimeout = kingpin.Flag("scrape.max-timeout", "Any scrape with a timeout higher than this will have to be clamped to this.").Default("5m").Duration()
defaultScrapeTimeout = kingpin.Flag("scrape.default-timeout", "If a scrape lacks a timeout, use this value.").Default("15s").Duration()
authorizedPollers = kingpin.Flag("scrape.pollers-ip", "Comma separeted list of ips addresses or networks authorized to scrap via the proxy.").Default("").String()
)

var (
Expand All @@ -63,7 +66,10 @@ var (
prometheus.HistogramOpts{
Name: "pushprox_http_duration_seconds",
Help: "Time taken by path",
}, []string{"path"})
}, []string{"path"},
)

// hasPollersNet = false
)

func init() {
Expand All @@ -83,38 +89,86 @@ type targetGroup struct {
Labels map[string]string `json:"labels"`
}

const (
OpEgals = 1
OpMatch = 2
)

type route struct {
path string
regex *regexp.Regexp
handler http.HandlerFunc
}

func newRoute(op int, path string, handler http.HandlerFunc) *route {
if op == OpEgals {
return &route{path, nil, handler}
} else if op == OpMatch {
return &route{"", regexp.MustCompile("^" + path + "$"), handler}

} else {
return nil
}

}

type httpHandler struct {
logger log.Logger
coordinator *Coordinator
mux http.Handler
proxy http.Handler
pollersNet map[*net.IPNet]int
}

func newHTTPHandler(logger log.Logger, coordinator *Coordinator, mux *http.ServeMux) *httpHandler {
h := &httpHandler{logger: logger, coordinator: coordinator, mux: mux}

// api handlers
handlers := map[string]http.HandlerFunc{
"/push": h.handlePush,
"/poll": h.handlePoll,
"/clients": h.handleListClients,
"/metrics": promhttp.Handler().ServeHTTP,
}
for path, handlerFunc := range handlers {
counter := httpAPICounter.MustCurryWith(prometheus.Labels{"path": path})
handler := promhttp.InstrumentHandlerCounter(counter, http.HandlerFunc(handlerFunc))
histogram := httpPathHistogram.MustCurryWith(prometheus.Labels{"path": path})
handler = promhttp.InstrumentHandlerDuration(histogram, handler)
mux.Handle(path, handler)
counter.WithLabelValues("200")
if path == "/push" {
counter.WithLabelValues("500")
}
if path == "/poll" {
counter.WithLabelValues("408")
}
func newHTTPHandler(logger log.Logger, coordinator *Coordinator, mux *http.ServeMux, pollers map[*net.IPNet]int) *httpHandler {
h := &httpHandler{logger: logger, coordinator: coordinator, mux: mux, pollersNet: pollers}

var routes = []*route{
newRoute(OpEgals, "/push", h.handlePush),
newRoute(OpEgals, "/poll", h.handlePoll),
newRoute(OpMatch, "/clients(/.*)?", h.handleListClients),
newRoute(OpEgals, "/metrics", promhttp.Handler().ServeHTTP),
}
hf := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
for _, route := range routes {
var path string

if route == nil {
continue
}
if route.regex != nil {
if strings.HasPrefix(route.path, "/clients") {
path = "/clients"
}
} else if req.URL.Path == route.path {
path = route.path
}
counter := httpAPICounter.MustCurryWith(prometheus.Labels{"path": path})
handler := promhttp.InstrumentHandlerCounter(counter, route.handler)
histogram := httpPathHistogram.MustCurryWith(prometheus.Labels{"path": path})
route.handler = promhttp.InstrumentHandlerDuration(histogram, handler)
// mux.Handle(route.path, handler)
counter.WithLabelValues("200")
if route.path == "/push" {
counter.WithLabelValues("500")
}
if route.path == "/poll" {
counter.WithLabelValues("408")
}
if route.regex != nil {
if route.regex != nil {
if route.regex.MatchString(req.URL.Path) {
route.handler(w, req)
return
}
}
} else if req.URL.Path == route.path {
route.handler(w, req)
return
}
}
})
h.mux = hf
// proxy handler
h.proxy = promhttp.InstrumentHandlerCounter(httpProxyCounter, http.HandlerFunc(h.handleProxy))

Expand All @@ -128,15 +182,15 @@ func (h *httpHandler) handlePush(w http.ResponseWriter, r *http.Request) {
scrapeResult, err := http.ReadResponse(bufio.NewReader(buf), nil)
if err != nil {
level.Error(h.logger).Log("msg", "Error reading pushed response:", "err", err)
http.Error(w, fmt.Sprintf("Error pushing: %s", err.Error()), 500)
http.Error(w, fmt.Sprintf("Error pushing: %s", err.Error()), http.StatusInternalServerError)
return
}
scrapeId := scrapeResult.Header.Get("Id")
level.Info(h.logger).Log("msg", "Got /push", "scrape_id", scrapeId)
err = h.coordinator.ScrapeResult(scrapeResult)
if err != nil {
level.Error(h.logger).Log("msg", "Error pushing:", "err", err, "scrape_id", scrapeId)
http.Error(w, fmt.Sprintf("Error pushing: %s", err.Error()), 500)
http.Error(w, fmt.Sprintf("Error pushing: %s", err.Error()), http.StatusInternalServerError)
}
}

Expand All @@ -146,29 +200,105 @@ func (h *httpHandler) handlePoll(w http.ResponseWriter, r *http.Request) {
request, err := h.coordinator.WaitForScrapeInstruction(strings.TrimSpace(string(fqdn)))
if err != nil {
level.Info(h.logger).Log("msg", "Error WaitForScrapeInstruction:", "err", err)
http.Error(w, fmt.Sprintf("Error WaitForScrapeInstruction: %s", err.Error()), 408)
http.Error(w, fmt.Sprintf("Error WaitForScrapeInstruction: %s", err.Error()), http.StatusRequestTimeout)
return
}
//nolint:errcheck // https://github.com/prometheus-community/PushProx/issues/111
request.WriteProxy(w) // Send full request as the body of the response.
level.Info(h.logger).Log("msg", "Responded to /poll", "url", request.URL.String(), "scrape_id", request.Header.Get("Id"))
}

// isPoller checks if caller has an IP addr in authorized nets (if any defined). It uses RemoteAddr field
// from http.Request.
// RETURNS:
// - true and "" if no restriction is defined
// - true and clientip if @ip from RemoteAddr is found in allowed nets
// - false and "" else
func (h *httpHandler) isPoller(r *http.Request) (bool, string) {
var (
ispoller = false
clientip string
)

if len(h.pollersNet) > 0 {
if i := strings.Index(r.RemoteAddr, ":"); i != -1 {
clientip = r.RemoteAddr[0:i]
}
for key := range h.pollersNet {
ip := net.ParseIP(clientip)
if key.Contains(ip) {
ispoller = true
break
}
}
} else {
ispoller = true
}
return ispoller, clientip
}

// handleListClients handles requests to list available clients as a JSON array.
func (h *httpHandler) handleListClients(w http.ResponseWriter, r *http.Request) {
known := h.coordinator.KnownClients()
targets := make([]*targetGroup, 0, len(known))
for _, k := range known {
targets = append(targets, &targetGroup{Targets: []string{k}})
var (
targets []*targetGroup
lknown int
client string
)

ispoller, clientip := h.isPoller(r)
// if not a poller we are not authorized to get all clients, restrict query to itself hostname
if !ispoller {
hosts, err := net.LookupAddr(clientip)
if err != nil {
level.Error(h.logger).Log("msg", "can't reverse client address", "err", err.Error())
}
if len(hosts) > 0 {
// level.Info(h.logger).Log("hosts", fmt.Sprintf("%v", hosts))
client = strings.ToLower(strings.TrimSuffix(hosts[0], "."))
} else {
client = "_not_found_hostname_"
}
} else {
if len(r.URL.Path) > 9 {
client = r.URL.Path[9:]
}
}
w.Header().Set("Content-Type", "application/json")
//nolint:errcheck // https://github.com/prometheus-community/PushProx/issues/111
json.NewEncoder(w).Encode(targets)
level.Info(h.logger).Log("msg", "Responded to /clients", "client_count", len(known))
known := h.coordinator.KnownClients(client)
lknown = len(known)
if client != "" && lknown == 0 {
http.Error(w, "", http.StatusNotFound)
} else {
targets = make([]*targetGroup, 0, lknown)
for _, k := range known {
targets = append(targets, &targetGroup{Targets: []string{k}})
}
w.Header().Set("Content-Type", "application/json")
//nolint:errcheck // https://github.com/prometheus-community/PushProx/issues/111
json.NewEncoder(w).Encode(targets)
}
level.Info(h.logger).Log("msg", "Responded to /clients", "client_count", lknown)
}

// handleProxy handles proxied scrapes from Prometheus.
func (h *httpHandler) handleProxy(w http.ResponseWriter, r *http.Request) {
if ok, clientip := h.isPoller(r); !ok {
var clientfqdn string
hosts, err := net.LookupAddr(clientip)
if err != nil {
level.Error(h.logger).Log("msg", "can't reverse client address", "err", err.Error())
}
if len(hosts) > 0 {
// level.Info(h.logger).Log("hosts", fmt.Sprintf("%v", hosts))
clientfqdn = strings.ToLower(strings.TrimSuffix(hosts[0], "."))
} else {
clientfqdn = "_not_found_hostname_"
}
if !h.coordinator.checkRequestChannel(clientfqdn) {
http.Error(w, "Not an authorized poller", http.StatusForbidden)
return
}
}

ctx, cancel := context.WithTimeout(r.Context(), util.GetScrapeTimeout(maxScrapeTimeout, defaultScrapeTimeout, r.Header))
defer cancel()
request := r.WithContext(ctx)
Expand All @@ -177,7 +307,7 @@ func (h *httpHandler) handleProxy(w http.ResponseWriter, r *http.Request) {
resp, err := h.coordinator.DoScrape(ctx, request)
if err != nil {
level.Error(h.logger).Log("msg", "Error scraping:", "err", err, "url", request.URL.String())
http.Error(w, fmt.Sprintf("Error scraping %q: %s", request.URL.String(), err.Error()), 500)
http.Error(w, fmt.Sprintf("Error scraping %q: %s", request.URL.String(), err.Error()), http.StatusInternalServerError)
return
}
defer resp.Body.Close()
Expand All @@ -193,6 +323,18 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

// return list of network addresses from the httpHandlet.pollersNet map
func (h *httpHandler) pollersNetString() string {
if len(h.pollersNet) > 0 {
l := make([]string, 0, len(h.pollersNet))
for netw := range h.pollersNet {
l = append(l, netw.String())
}
return strings.Join(l, ",")
} else {
return ""
}
}
func main() {
promlogConfig := promlog.Config{}
flag.AddFlags(kingpin.CommandLine, &promlogConfig)
Expand All @@ -204,11 +346,34 @@ func main() {
level.Error(logger).Log("msg", "Coordinator initialization failed", "err", err)
os.Exit(1)
}
pollersNet := make(map[*net.IPNet]int, 10)
if *authorizedPollers != "" {
networks := strings.Split(*authorizedPollers, ",")
for _, network := range networks {
if !strings.Contains(network, "/") {
// detect ipv6
if strings.Contains(network, ":") {
network = fmt.Sprintf("%s/128", network)
} else {
network = fmt.Sprintf("%s/32", network)
}
}
if _, subnet, err := net.ParseCIDR(network); err != nil {
level.Error(logger).Log("msg", "network is invalid", "net", network, "err", err)
os.Exit(1)
} else {
pollersNet[subnet] = 1
}
}
}

mux := http.NewServeMux()
handler := newHTTPHandler(logger, coordinator, mux)
handler := newHTTPHandler(logger, coordinator, mux, pollersNet)

level.Info(logger).Log("msg", "Listening", "address", *listenAddress)
if len(pollersNet) > 0 {
level.Info(logger).Log("msg", "Polling restricted", "allowed", handler.pollersNetString())
}
if err := http.ListenAndServe(*listenAddress, handler); err != nil {
level.Error(logger).Log("msg", "Listening failed", "err", err)
os.Exit(1)
Expand Down

0 comments on commit 5302694

Please sign in to comment.