add proxy polling restriction, allow to query "/clients/host" #178

Open · wants to merge 1 commit into master
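Since the PR has no description, here is a minimal, hypothetical sketch of how the new per-host lookup could be consumed once this change lands. The proxy address and hostname are placeholders; the response shape (a JSON array of target groups, or a 404 when the client is unknown or its registration has expired) follows from the diff below.

```go
package main

import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

// targetGroup mirrors the proxy's response type in cmd/proxy/main.go.
type targetGroup struct {
	Targets []string          `json:"targets"`
	Labels  map[string]string `json:"labels"`
}

func main() {
	// Hypothetical proxy address and client FQDN, for illustration only.
	resp, err := http.Get("http://pushprox-proxy:8080/clients/node1.example.com")
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusNotFound {
		log.Fatal("client is unknown to the proxy, or its registration expired")
	}

	var groups []targetGroup
	if err := json.NewDecoder(resp.Body).Decode(&groups); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("target groups: %+v\n", groups)
}
```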
27 changes: 21 additions & 6 deletions cmd/proxy/coordinator.go
@@ -87,6 +87,13 @@ func (c *Coordinator) getRequestChannel(fqdn string) chan *http.Request {
return ch
}

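// checkRequestChannel reports whether a request channel already exists for fqdn; unlike getRequestChannel it never creates one.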
func (c *Coordinator) checkRequestChannel(fqdn string) bool {
c.mu.Lock()
defer c.mu.Unlock()
_, ok := c.waiting[fqdn]
return ok
}

func (c *Coordinator) getResponseChannel(id string) chan *http.Response {
c.mu.Lock()
defer c.mu.Unlock()
@@ -115,7 +122,7 @@ func (c *Coordinator) DoScrape(ctx context.Context, r *http.Request) (*http.Resp
r.Header.Add("Id", id)
select {
case <-ctx.Done():
return nil, fmt.Errorf("Timeout reached for %q: %s", r.URL.String(), ctx.Err())
return nil, fmt.Errorf("timeout reached for %q: %s", r.URL.String(), ctx.Err())
case c.getRequestChannel(r.URL.Hostname()) <- r:
}

@@ -188,15 +195,23 @@ func (c *Coordinator) addKnownClient(fqdn string) {
}

// KnownClients returns a list of alive clients
func (c *Coordinator) KnownClients() []string {
func (c *Coordinator) KnownClients(client string) []string {
c.mu.Lock()
defer c.mu.Unlock()

var known []string
limit := time.Now().Add(-*registrationTimeout)
known := make([]string, 0, len(c.known))
for k, t := range c.known {
if limit.Before(t) {
known = append(known, k)
if client != "" {
known = make([]string, 0, 1)
if t, ok := c.known[client]; ok && limit.Before(t) {
known = append(known, client)
}
} else {
known = make([]string, 0, len(c.known))
for k, t := range c.known {
if limit.Before(t) {
known = append(known, k)
}
}
}
return known
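As a review aid, here is a standalone sketch of the filtering behaviour the reworked KnownClients introduces (the sample data below is made up, not taken from the repository's tests): an empty argument keeps the old "list every alive client" semantics, while a non-empty argument returns at most that one client.

```go
package main

import (
	"fmt"
	"time"
)

// knownClients mimics Coordinator.KnownClients: with an empty filter it returns
// every client seen within the registration timeout; with a non-empty filter it
// returns at most that client, or an empty slice if it is unknown or expired.
func knownClients(known map[string]time.Time, registrationTimeout time.Duration, client string) []string {
	limit := time.Now().Add(-registrationTimeout)
	if client != "" {
		if t, ok := known[client]; ok && limit.Before(t) {
			return []string{client}
		}
		return []string{}
	}
	out := make([]string, 0, len(known))
	for k, t := range known {
		if limit.Before(t) {
			out = append(out, k)
		}
	}
	return out
}

func main() {
	known := map[string]time.Time{
		"node1.example.com": time.Now(),
		"node2.example.com": time.Now().Add(-10 * time.Minute), // expired
	}
	fmt.Println(knownClients(known, 5*time.Minute, ""))                  // [node1.example.com]
	fmt.Println(knownClients(known, 5*time.Minute, "node1.example.com")) // [node1.example.com]
	fmt.Println(knownClients(known, 5*time.Minute, "node2.example.com")) // [] -> the proxy answers 404
}
```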
236 changes: 199 additions & 37 deletions cmd/proxy/main.go
@@ -21,8 +21,10 @@ import (
"fmt"
"io"
"log/slog"
"net"
"net/http"
"os"
"regexp"
"strings"

"github.com/alecthomas/kingpin/v2"
@@ -42,6 +44,7 @@ var (
listenAddress = kingpin.Flag("web.listen-address", "Address to listen on for proxy and client requests.").Default(":8080").String()
maxScrapeTimeout = kingpin.Flag("scrape.max-timeout", "Any scrape with a timeout higher than this will have to be clamped to this.").Default("5m").Duration()
defaultScrapeTimeout = kingpin.Flag("scrape.default-timeout", "If a scrape lacks a timeout, use this value.").Default("15s").Duration()
authorizedPollers = kingpin.Flag("scrape.pollers-ip", "Comma-separated list of IP addresses or networks authorized to scrape via the proxy.").Default("").String()
)

var (
@@ -62,7 +65,8 @@ var (
prometheus.HistogramOpts{
Name: "pushprox_http_duration_seconds",
Help: "Time taken by path",
}, []string{"path"})
}, []string{"path"},
)
)

func init() {
@@ -82,38 +86,86 @@ type targetGroup struct {
Labels map[string]string `json:"labels"`
}

const (
OpEgals = 1
OpMatch = 2
)

type route struct {
path string
regex *regexp.Regexp
handler http.HandlerFunc
}

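// newRoute builds a route matched either by exact path (OpEgals) or by an anchored regular expression (OpMatch); any other op yields nil.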
func newRoute(op int, path string, handler http.HandlerFunc) *route {
if op == OpEgals {
return &route{path, nil, handler}
} else if op == OpMatch {
return &route{"", regexp.MustCompile("^" + path + "$"), handler}
} else {
return nil
}
}

type httpHandler struct {
logger *slog.Logger
coordinator *Coordinator
mux http.Handler
proxy http.Handler
pollersNet map[*net.IPNet]int
}

func newHTTPHandler(logger *slog.Logger, coordinator *Coordinator, mux *http.ServeMux) *httpHandler {
h := &httpHandler{logger: logger, coordinator: coordinator, mux: mux}

// api handlers
handlers := map[string]http.HandlerFunc{
"/push": h.handlePush,
"/poll": h.handlePoll,
"/clients": h.handleListClients,
"/metrics": promhttp.Handler().ServeHTTP,
}
for path, handlerFunc := range handlers {
counter := httpAPICounter.MustCurryWith(prometheus.Labels{"path": path})
handler := promhttp.InstrumentHandlerCounter(counter, http.HandlerFunc(handlerFunc))
histogram := httpPathHistogram.MustCurryWith(prometheus.Labels{"path": path})
handler = promhttp.InstrumentHandlerDuration(histogram, handler)
mux.Handle(path, handler)
counter.WithLabelValues("200")
if path == "/push" {
counter.WithLabelValues("500")
}
if path == "/poll" {
counter.WithLabelValues("408")
}
func newHTTPHandler(logger *slog.Logger, coordinator *Coordinator, mux *http.ServeMux, pollers map[*net.IPNet]int) *httpHandler {
h := &httpHandler{logger: logger, coordinator: coordinator, mux: mux, pollersNet: pollers}

var routes = []*route{
newRoute(OpEgals, "/push", h.handlePush),
newRoute(OpEgals, "/poll", h.handlePoll),
newRoute(OpMatch, "/clients(/.*)?", h.handleListClients),
newRoute(OpEgals, "/metrics", promhttp.Handler().ServeHTTP),
}
hf := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
for _, route := range routes {
var path string

if route == nil {
continue
}
if route.regex != nil {
// the only regex route here is "/clients(/.*)?"; report it under a single "/clients" metric label
if strings.HasPrefix(req.URL.Path, "/clients") {
path = "/clients"
}
} else if req.URL.Path == route.path {
path = route.path
}
counter := httpAPICounter.MustCurryWith(prometheus.Labels{"path": path})
handler := promhttp.InstrumentHandlerCounter(counter, route.handler)
histogram := httpPathHistogram.MustCurryWith(prometheus.Labels{"path": path})
route.handler = promhttp.InstrumentHandlerDuration(histogram, handler)
// mux.Handle(route.path, handler)
counter.WithLabelValues("200")
if route.path == "/push" {
counter.WithLabelValues("500")
}
if route.path == "/poll" {
counter.WithLabelValues("408")
}
if route.regex != nil {
if route.regex.MatchString(req.URL.Path) {
route.handler(w, req)
return
}
} else if req.URL.Path == route.path {
route.handler(w, req)
return
}
}
})
h.mux = hf
// proxy handler
h.proxy = promhttp.InstrumentHandlerCounter(httpProxyCounter, http.HandlerFunc(h.handleProxy))

@@ -127,15 +179,15 @@ func (h *httpHandler) handlePush(w http.ResponseWriter, r *http.Request) {
scrapeResult, err := http.ReadResponse(bufio.NewReader(buf), nil)
if err != nil {
h.logger.Error("Error reading pushed response:", "err", err)
http.Error(w, fmt.Sprintf("Error pushing: %s", err.Error()), 500)
http.Error(w, fmt.Sprintf("Error pushing: %s", err.Error()), http.StatusInternalServerError)
return
}
scrapeId := scrapeResult.Header.Get("Id")
h.logger.Info("Got /push", "scrape_id", scrapeId)
err = h.coordinator.ScrapeResult(scrapeResult)
if err != nil {
h.logger.Error("Error pushing:", "err", err, "scrape_id", scrapeId)
http.Error(w, fmt.Sprintf("Error pushing: %s", err.Error()), 500)
http.Error(w, fmt.Sprintf("Error pushing: %s", err.Error()), http.StatusInternalServerError)
}
}

@@ -145,29 +197,104 @@ func (h *httpHandler) handlePoll(w http.ResponseWriter, r *http.Request) {
request, err := h.coordinator.WaitForScrapeInstruction(strings.TrimSpace(string(fqdn)))
if err != nil {
h.logger.Info("Error WaitForScrapeInstruction:", "err", err)
http.Error(w, fmt.Sprintf("Error WaitForScrapeInstruction: %s", err.Error()), 408)
http.Error(w, fmt.Sprintf("Error WaitForScrapeInstruction: %s", err.Error()), http.StatusRequestTimeout)
return
}
//nolint:errcheck // https://github.com/prometheus-community/PushProx/issues/111
request.WriteProxy(w) // Send full request as the body of the response.
h.logger.Info("Responded to /poll", "url", request.URL.String(), "scrape_id", request.Header.Get("Id"))
}

// isPoller checks whether the caller's IP address belongs to one of the authorized networks (if any are defined).
// It uses the RemoteAddr field of http.Request.
// RETURNS:
// - true and "" if no restriction is defined
// - true and the client IP if the address taken from RemoteAddr is found in the allowed networks
// - false and "" otherwise
func (h *httpHandler) isPoller(r *http.Request) (bool, string) {
var (
ispoller = false
clientip string
)

if len(h.pollersNet) > 0 {
if i := strings.Index(r.RemoteAddr, ":"); i != -1 {
clientip = r.RemoteAddr[0:i]
}
for key := range h.pollersNet {
ip := net.ParseIP(clientip)
if key.Contains(ip) {
ispoller = true
break
}
}
} else {
ispoller = true
}
return ispoller, clientip
}

// handleListClients handles requests to list available clients as a JSON array.
func (h *httpHandler) handleListClients(w http.ResponseWriter, r *http.Request) {
known := h.coordinator.KnownClients()
targets := make([]*targetGroup, 0, len(known))
for _, k := range known {
targets = append(targets, &targetGroup{Targets: []string{k}})
var (
targets []*targetGroup
lknown int
client string
)

ispoller, clientip := h.isPoller(r)
// the caller is not a poller, so it is not authorized to list all clients; restrict the query to its own hostname
if !ispoller {
hosts, err := net.LookupAddr(clientip)
if err != nil {
h.logger.Error("can't reverse client address", "err", err.Error())
}
if len(hosts) > 0 {
client = strings.ToLower(strings.TrimSuffix(hosts[0], "."))
} else {
client = "_not_found_hostname_"
}
} else {
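// the request path is "/clients" or "/clients/<fqdn>"; skip the 9-character "/clients/" prefix to get the requested hostname, if any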
if len(r.URL.Path) > 9 {
client = r.URL.Path[9:]
}
}
w.Header().Set("Content-Type", "application/json")
//nolint:errcheck // https://github.com/prometheus-community/PushProx/issues/111
json.NewEncoder(w).Encode(targets)
h.logger.Info("Responded to /clients", "client_count", len(known))
known := h.coordinator.KnownClients(client)
lknown = len(known)
if client != "" && lknown == 0 {
http.Error(w, "", http.StatusNotFound)
} else {
targets = make([]*targetGroup, 0, lknown)
for _, k := range known {
targets = append(targets, &targetGroup{Targets: []string{k}})
}
w.Header().Set("Content-Type", "application/json")
//nolint:errcheck // https://github.com/prometheus-community/PushProx/issues/111
json.NewEncoder(w).Encode(targets)
}
h.logger.Info("Responded to /clients", "client_count", lknown)
}

// handleProxy handles proxied scrapes from Prometheus.
func (h *httpHandler) handleProxy(w http.ResponseWriter, r *http.Request) {
if ok, clientip := h.isPoller(r); !ok {
var clientfqdn string
hosts, err := net.LookupAddr(clientip)
if err != nil {
h.logger.Error("can't reverse client address", "err", err.Error())
}
if len(hosts) > 0 {
clientfqdn = strings.ToLower(strings.TrimSuffix(hosts[0], "."))
} else {
clientfqdn = "_not_found_hostname_"
}
if !h.coordinator.checkRequestChannel(clientfqdn) {
http.Error(w, "Not an authorized poller", http.StatusForbidden)
return
}
}

ctx, cancel := context.WithTimeout(r.Context(), util.GetScrapeTimeout(maxScrapeTimeout, defaultScrapeTimeout, r.Header))
defer cancel()
request := r.WithContext(ctx)
@@ -176,7 +303,7 @@ func (h *httpHandler) handleProxy(w http.ResponseWriter, r *http.Request) {
resp, err := h.coordinator.DoScrape(ctx, request)
if err != nil {
h.logger.Error("Error scraping:", "err", err, "url", request.URL.String())
http.Error(w, fmt.Sprintf("Error scraping %q: %s", request.URL.String(), err.Error()), 500)
http.Error(w, fmt.Sprintf("Error scraping %q: %s", request.URL.String(), err.Error()), http.StatusInternalServerError)
return
}
defer resp.Body.Close()
@@ -192,6 +319,18 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

// pollersNetString returns the network addresses from the httpHandler.pollersNet map as a comma-separated string.
func (h *httpHandler) pollersNetString() string {
if len(h.pollersNet) > 0 {
l := make([]string, 0, len(h.pollersNet))
for netw := range h.pollersNet {
l = append(l, netw.String())
}
return strings.Join(l, ",")
} else {
return ""
}
}
func main() {
promslogConfig := promslog.Config{}
flag.AddFlags(kingpin.CommandLine, &promslogConfig)
@@ -203,11 +342,34 @@ func main() {
logger.Error("Coordinator initialization failed", "err", err)
os.Exit(1)
}
pollersNet := make(map[*net.IPNet]int, 10)
if *authorizedPollers != "" {
networks := strings.Split(*authorizedPollers, ",")
for _, network := range networks {
if !strings.Contains(network, "/") {
// detect ipv6
if strings.Contains(network, ":") {
network = fmt.Sprintf("%s/128", network)
} else {
network = fmt.Sprintf("%s/32", network)
}
}
if _, subnet, err := net.ParseCIDR(network); err != nil {
logger.Error("network is invalid", "net", network, "err", err)
os.Exit(1)
} else {
pollersNet[subnet] = 1
}
}
}

mux := http.NewServeMux()
handler := newHTTPHandler(logger, coordinator, mux)
handler := newHTTPHandler(logger, coordinator, mux, pollersNet)

logger.Info("Listening", "address", *listenAddress)
if len(pollersNet) > 0 {
logger.Info("Polling restricted", "allowed", handler.pollersNetString())
}
if err := http.ListenAndServe(*listenAddress, handler); err != nil {
logger.Error("Listening failed", "err", err)
os.Exit(1)
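To close, a self-contained sketch of the address-restriction logic for anyone who wants to exercise it outside the proxy. It mirrors the flag parsing in main() and the isPoller check, but it is not the PR code: the hypothetical allowed helper uses net.SplitHostPort instead of slicing RemoteAddr at the first colon as isPoller does, since the latter misparses bracketed IPv6 remote addresses such as "[::1]:9100".

```go
package main

import (
	"fmt"
	"net"
	"strings"
)

// parsePollers mirrors the flag handling in main(): bare addresses are widened
// to /32 (IPv4) or /128 (IPv6) before being handed to net.ParseCIDR.
func parsePollers(spec string) ([]*net.IPNet, error) {
	var nets []*net.IPNet
	for _, network := range strings.Split(spec, ",") {
		if !strings.Contains(network, "/") {
			if strings.Contains(network, ":") {
				network += "/128"
			} else {
				network += "/32"
			}
		}
		_, subnet, err := net.ParseCIDR(network)
		if err != nil {
			return nil, fmt.Errorf("network %q is invalid: %w", network, err)
		}
		nets = append(nets, subnet)
	}
	return nets, nil
}

// allowed mirrors isPoller: an empty list means no restriction at all.
func allowed(nets []*net.IPNet, remoteAddr string) bool {
	if len(nets) == 0 {
		return true
	}
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		host = remoteAddr
	}
	ip := net.ParseIP(host)
	for _, n := range nets {
		if n.Contains(ip) {
			return true
		}
	}
	return false
}

func main() {
	nets, err := parsePollers("10.0.0.0/24,192.168.1.5")
	if err != nil {
		panic(err)
	}
	fmt.Println(allowed(nets, "10.0.0.42:51234"))    // true
	fmt.Println(allowed(nets, "172.16.0.1:51234"))   // false
	fmt.Println(allowed(nets, "[2001:db8::1]:9100")) // false, but parsed correctly
}
```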