Skip to content

Commit

Permalink
Fix(GPX-696): Add support for labels endpoint (#56)
Browse files Browse the repository at this point in the history
Improve query enforcement and error handling in proxy server

1. Added 'log_tokens' attribute in the config file, defaulting to false. This controls whether request tokens should be logged.
2. Improved the query enforcement in 'promqlEnforcer' function. Now, it also checks if the query is empty and sets it according to the 'allowedTenantLabels'.
3. Numerous changes were made in the 'reverseProxy' function to better handle requests and errors:
   - Removed the requirement for 'X-Plugin-Id' header, instead using URL path to differentiate between Thanos and Loki.
   + Added checks for '/api/v1/label' and '/api/v1/series' in the URL path.
   + Refactored error handling into two separate functions, one for standard error messages and another for custom messages.
   + Improved request logging to include method, URL, header, and body information. If 'log_tokens' is set to false, sensitive information is redacted.
   + Added support for handling POST requests and their bodies. The body query is also enforced using the 'enforceFunc' function.
4. Updated unit tests to reflect these changes.
  • Loading branch information
Lucostus authored Jun 20, 2023
1 parent a94f57d commit 8d34024
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 59 deletions.
1 change: 1 addition & 0 deletions configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ proxy:
jwks_cert_url: https://sso.example.com/realms/internal/protocol/openid-connect/certs
admin_group: gepardec-run-admins
insecure_skip_verify: false
log_tokens: false
port: 8080
host: localhost
tenant_labels:
Expand Down
12 changes: 12 additions & 0 deletions enforcer_promql.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ import (
// the error and returns it.
func promqlEnforcer(query string, allowedTenantLabels map[string]bool) (string, error) {
currentTime := time.Now()
if query == "" {
operator := "="
if len(allowedTenantLabels) > 1 {
operator = "=~"
}
query = fmt.Sprintf("{%s%s\"%s\"}",
Cfg.Proxy.TenantLabels.Thanos,
operator,
strings.Join(MapKeysToArray(allowedTenantLabels),
"|"))
}
Logger.Debug("Start promqlEnforcer", zap.String("query", query), zap.Time("time", currentTime))
expr, err := parser.ParseExpr(query)
if err != nil {
Logger.Error("error",
Expand Down
152 changes: 110 additions & 42 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
package main

import (
"bytes"
"encoding/json"
"fmt"
"github.com/golang-jwt/jwt/v5"
"go.uber.org/zap"
"io"
"net/http"
"net/http/httputil"
"net/http/pprof"
"net/url"
"strings"
)

// main is the entry point of the application. It initializes necessary components, sets up HTTP routes, and starts the HTTP server.
Expand Down Expand Up @@ -35,7 +39,6 @@ func main() {

// healthz is an HTTP handler that always returns an HTTP status of 200 and a response body of "Ok". It's commonly used for health checks.
func healthz(w http.ResponseWriter, _ *http.Request) {
Logger.Debug("Healthz")
w.WriteHeader(http.StatusOK)
_, _ = fmt.Fprint(w, "Ok")
}
Expand All @@ -53,13 +56,34 @@ func reverseProxy(rw http.ResponseWriter, req *http.Request) {
var upstreamUrl *url.URL
var enforceFunc func(string, map[string]bool) (string, error)
var tenantLabels map[string]bool
query := req.URL.Query().Get("query")
var err error

urlKey := "query"
if containsApiV1Labels(req.URL.Path) {
urlKey = "match[]"
}
query := req.URL.Query().Get(urlKey)

upstreamUrl, err = url.Parse(Cfg.Proxy.ThanosUrl)
enforceFunc = promqlEnforcer
Logger.Debug("Parsed Thanos URL")

if containsLoki(req.URL.Path) {
upstreamUrl, err = url.Parse(Cfg.Proxy.LokiUrl)
enforceFunc = logqlEnforcer
Logger.Debug("Parsed Loki URL")
}

if err != nil {
logAndWriteErrorMsg(rw, "Error parsing upstream url", http.StatusInternalServerError, err)
return
}

logRequest(req)
Logger.Debug("url request", zap.String("url", req.URL.String()))

if !hasAuthorizationHeader(req) {
logAndWriteError(rw, "No Authorization header found", http.StatusForbidden, nil)
logAndWriteErrorMsg(rw, "No Authorization header found", http.StatusForbidden, nil)
return
}

Expand All @@ -68,45 +92,19 @@ func reverseProxy(rw http.ResponseWriter, req *http.Request) {
tokenString := getBearerToken(req)
keycloakToken, token, err := parseJwtToken(tokenString)
if err != nil && !Cfg.Dev.Enabled {
logAndWriteError(rw, "Error parsing Keycloak token", http.StatusForbidden, err)
logAndWriteErrorMsg(rw, "Error parsing Keycloak token", http.StatusForbidden, err)
return
}

Logger.Debug("Parsed JWT token")

if !isValidToken(token) {
logAndWriteError(rw, "Invalid token", http.StatusForbidden, nil)
logAndWriteErrorMsg(rw, "Invalid token", http.StatusForbidden, nil)
return
}

Logger.Debug("Token is valid")

if req.Header.Get("X-Plugin-Id") != "thanos" && req.Header.Get("X-Plugin-Id") != "loki" {
logAndWriteError(rw, "No X-Plugin-Id header found", http.StatusForbidden, nil)
return
}

Logger.Debug("Has X-Plugin-Id")

if req.Header.Get("X-Plugin-Id") == "thanos" {
upstreamUrl, err = url.Parse(Cfg.Proxy.ThanosUrl)
enforceFunc = promqlEnforcer
Logger.Debug("Parsed Thanos URL")
}

if req.Header.Get("X-Plugin-Id") == "loki" {
upstreamUrl, err = url.Parse(Cfg.Proxy.LokiUrl)
enforceFunc = logqlEnforcer
Logger.Debug("Parsed Loki URL")
}

if err != nil {
logAndWriteError(rw, "Error parsing upstream url", http.StatusForbidden, err)
return
}

Logger.Debug("No error in parsing URLs")

if isAdminSkip(keycloakToken) {
goto DoRequest
}
Expand All @@ -126,37 +124,57 @@ func reverseProxy(rw http.ResponseWriter, req *http.Request) {
tenantLabels = GetLabelsCM(keycloakToken.PreferredUsername, keycloakToken.Groups)
Logger.Debug("Fetched labels from ConfigMap")
default:
logAndWriteError(rw, "No provider set", http.StatusForbidden, nil)
logAndWriteErrorMsg(rw, "No provider set", http.StatusForbidden, nil)
return
}

Logger.Debug("username", zap.String("username", keycloakToken.PreferredUsername))

if len(tenantLabels) <= 0 {
logAndWriteError(rw, "No tenant labels found", http.StatusForbidden, nil)
logAndWriteErrorMsg(rw, "No tenant labels found", http.StatusForbidden, nil)
return
}
Logger.Debug("Labels", zap.Any("tenantLabels", tenantLabels))

query, err = enforceFunc(query, tenantLabels)
if err != nil {
logAndWriteError(rw, "Error modifying query", http.StatusForbidden, err)
logAndWriteError(rw, http.StatusForbidden, err)
return
}
if req.Method == http.MethodPost {
if err := req.ParseForm(); err != nil {
logAndWriteErrorMsg(rw, "Error processing Post request", http.StatusForbidden, err)
return
}
Logger.Debug("Parsed form", zap.Any("form", req.PostForm))
body := req.PostForm
query, err = enforceFunc(body.Get(urlKey), tenantLabels)
if err != nil {
logAndWriteError(rw, http.StatusForbidden, err)
return
}
body.Set(urlKey, query)

// We are replacing request body, close previous one (ParseForm ensures it is read fully and not nil).
_ = req.Body.Close()
newBody := body.Encode()
req.Body = io.NopCloser(strings.NewReader(newBody))
req.ContentLength = int64(len(newBody))
}

Logger.Debug("Modified query successfully")

DoRequest:

Logger.Debug("Doing request")

values := req.URL.Query()
values.Set("query", query)
values.Set(urlKey, query)
req.URL.RawQuery = values.Encode()
Logger.Debug("Set query")

Logger.Debug("Doing request")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ServiceAccountToken))
Logger.Debug("Set Authorization header")
logRequest(req)

proxy := httputil.NewSingleHostReverseProxy(upstreamUrl)
proxy.ServeHTTP(rw, req)
Expand Down Expand Up @@ -184,20 +202,70 @@ func isAdminSkip(token KeycloakToken) bool {
return ContainsIgnoreCase(token.Groups, Cfg.Proxy.AdminGroup) || ContainsIgnoreCase(token.ApaGroupsOrg, Cfg.Proxy.AdminGroup)
}

// logAndWriteError logs an error and sends an error message as the HTTP response.
func logAndWriteError(rw http.ResponseWriter, message string, statusCode int, err error) {
func containsApiV1Labels(s string) bool {
return strings.Contains(s, "/api/v1/label") || strings.Contains(s, "/api/v1/series")
}

func containsLoki(s string) bool {
return strings.Contains(s, "/loki")
}

// logAndWriteErrorMsg logs an error and sends an error message as the HTTP response.
func logAndWriteErrorMsg(rw http.ResponseWriter, message string, statusCode int, err error) {
Logger.Error(message, zap.Error(err))
rw.WriteHeader(statusCode)
_, _ = fmt.Fprint(rw, message+"\n")
}

// logAndWriteError logs an error and sends an error message as the HTTP response.
func logAndWriteError(rw http.ResponseWriter, statusCode int, err error) {
Logger.Error(err.Error(), zap.Error(err))
rw.WriteHeader(statusCode)
_, _ = fmt.Fprint(rw, err.Error()+"\n")
}

// logRequest logs the details of an incoming HTTP request.
func logRequest(req *http.Request) {
dump, err := httputil.DumpRequest(req, true)
var bodyBytes []byte
if req.Body != nil {
bodyBytes, _ = io.ReadAll(req.Body)
}

// Restore the io.ReadCloser to its original state
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
if !Cfg.Proxy.LogTokens {
bodyBytes = []byte("[REDACTED]")
}

requestData := struct {
Method string `json:"method"`
URL string `json:"url"`
Header http.Header `json:"header"`
Body string `json:"body"`
}{
Method: req.Method,
URL: req.URL.String(),
Header: req.Header,
Body: string(bodyBytes),
}

if !Cfg.Proxy.LogTokens {
// Make a copy of the header map so we're not modifying the original
copyHeader := make(http.Header)
for k, v := range requestData.Header {
copyHeader[k] = v
}
copyHeader.Del("Authorization")
copyHeader.Del("X-Plugin-Id")
requestData.Header = copyHeader
}

jsonData, err := json.Marshal(requestData)
if err != nil {
Logger.Error("Error while dumping request", zap.Error(err))
Logger.Error("Error while marshalling request", zap.Error(err))
return
}
Logger.Debug("Request", zap.String("request", string(dump)))
Logger.Debug("Request", zap.String("request", string(jsonData)))
}

// parseJwtToken parses a JWT token string into a Keycloak token and a JWT token. It returns an error if parsing fails.
Expand Down
21 changes: 4 additions & 17 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,6 @@ func Test_reverseProxy(t *testing.T) {
authorization: "Bearer " + "skk",
expectedBody: "Error parsing Keycloak token\n",
},
{
name: "Missing x-plugin-id header",
expectedStatus: http.StatusForbidden,
setAuthorization: true,
authorization: "Bearer " + tokens["noTenant"],
expectedBody: "No X-Plugin-Id header found\n",
},
{
name: "Missing tenant labels for user",
expectedStatus: http.StatusForbidden,
Expand All @@ -202,8 +195,8 @@ func Test_reverseProxy(t *testing.T) {
pluginID: "thanos",
setAuthorization: true,
setPluginID: true,
expectedStatus: http.StatusForbidden,
expectedBody: "Error modifying query\n",
expectedStatus: http.StatusOK,
expectedBody: "Upstream server response\n",
},
{
name: "User belongs to multiple groups, accessing forbidden tenant",
Expand All @@ -213,7 +206,7 @@ func Test_reverseProxy(t *testing.T) {
setPluginID: true,
URL: "/api/v1/query?query=up{tenant_id=\"forbidden_tenant\"}",
expectedStatus: http.StatusForbidden,
expectedBody: "Error modifying query\n",
expectedBody: "user not allowed with namespace forbidden_tenant\n",
},
{
name: "User belongs to no groups, accessing forbidden tenant",
Expand Down Expand Up @@ -342,13 +335,7 @@ func TestLogAndWriteError(t *testing.T) {
assert := assert.New(t)

rw := httptest.NewRecorder()
logAndWriteError(rw, "test error", http.StatusInternalServerError, nil)
logAndWriteErrorMsg(rw, "test error", http.StatusInternalServerError, nil)
assert.Equal(http.StatusInternalServerError, rw.Code)
assert.Equal("test error\n", rw.Body.String())
}

func TestParseJwtToken(t *testing.T) {
// Here you would typically set up a mock of Jwks.Keyfunc and then verify that
// jwt.ParseWithClaims is called with the appropriate arguments.
// However, due to complexity and without exact structure, skipping this test case.
}
1 change: 1 addition & 0 deletions structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Config struct {
JwksCertURL string `mapstructure:"jwks_cert_url"`
AdminGroup string `mapstructure:"admin_group"`
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"`
LogTokens bool `mapstructure:"log_tokens"`
Port int `mapstructure:"port"`
Host string `mapstructure:"host"`
TenantLabels struct {
Expand Down

0 comments on commit 8d34024

Please sign in to comment.