diff --git a/collect/collect.go b/collect/collect.go index 2a499dafda..06123faf8b 100644 --- a/collect/collect.go +++ b/collect/collect.go @@ -1283,6 +1283,19 @@ func (i *InMemCollector) sendTraces() { for t := range i.outgoingTraces { i.Metrics.Histogram("collector_outgoing_queue", float64(len(i.outgoingTraces))) _, span := otelutil.StartSpanMulti(context.Background(), i.Tracer, "sendTrace", map[string]interface{}{"num_spans": t.DescendantCount(), "outgoingTraces_size": len(i.outgoingTraces)}) + + // if we have a key replacement rule, we should + // replace the key with the new key + keycfg := i.Config.GetAccessKeyConfig() + overwriteWith, err := keycfg.GetReplaceKey(t.APIKey) + if err != nil { + i.Logger.Warn().Logf("error replacing key: %s", err.Error()) + continue + } + if overwriteWith != t.APIKey { + t.APIKey = overwriteWith + } + for _, sp := range t.GetSpans() { if sp.IsDecisionSpan() { continue @@ -1318,6 +1331,8 @@ func (i *InMemCollector) sendTraces() { } mergeTraceAndSpanSampleRates(sp, t.SampleRate(), isDryRun) i.addAdditionalAttributes(sp) + + sp.APIKey = t.APIKey i.Transmission.EnqueueSpan(sp) } span.End() diff --git a/collect/collect_test.go b/collect/collect_test.go index e339821347..da96fcdefa 100644 --- a/collect/collect_test.go +++ b/collect/collect_test.go @@ -120,6 +120,10 @@ func TestAddRootSpan(t *testing.T) { GetCollectionConfigVal: config.CollectionConfig{ ShutdownDelay: config.Duration(1 * time.Millisecond), }, + GetAccessKeyConfigVal: config.AccessKeyConfig{ + SendKey: "another-key", + SendKeyMode: "all", + }, } transmission := &transmit.MockTransmission{} transmission.Start() @@ -166,6 +170,7 @@ func TestAddRootSpan(t *testing.T) { events := transmission.GetBlock(1) require.Equal(t, 1, len(events), "adding a root span should send the span") assert.Equal(t, "aoeu", events[0].Dataset, "sending a root span should immediately send that span via transmission") + assert.Equal(t, "another-key", events[0].APIKey, "api key should be replaced with the send key") assert.Nil(t, coll.getFromCache(traceID1), "after sending the span, it should be removed from the cache") @@ -185,6 +190,7 @@ func TestAddRootSpan(t *testing.T) { events = transmission.GetBlock(1) require.Equal(t, 1, len(events), "adding another root span should send the span") assert.Equal(t, "aoeu", events[0].Dataset, "sending a root span should immediately send that span via transmission") + assert.Equal(t, "another-key", events[0].APIKey, "api key should be replaced with the send key") assert.Nil(t, coll.getFromCache(traceID1), "after sending the span, it should be removed from the cache") diff --git a/config/file_config.go b/config/file_config.go index 052f9fe70d..cd270d4381 100644 --- a/config/file_config.go +++ b/config/file_config.go @@ -91,21 +91,23 @@ type AccessKeyConfig struct { AcceptOnlyListedKeys bool `yaml:"AcceptOnlyListedKeys"` } -// truncate the key to 8 characters for logging -func (a *AccessKeyConfig) sanitize(key string) string { - return fmt.Sprintf("%.8s...", key) +// IsAccepted checks if the given key is in the list of accepted keys. +// if the key is not in the list, it returns an error with the key truncated to 8 characters for logging. +func (a *AccessKeyConfig) IsAccepted(key string) error { + if a.AcceptOnlyListedKeys { + if slices.Contains(a.ReceiveKeys, key) { + return nil + } + + return fmt.Errorf("api key %.8s... not found in list of authorized keys", key) + } + return nil } -// CheckAndMaybeReplaceKey checks the given API key against the configuration +// GetReplaceKey checks the given API key against the configuration // and possibly replaces it with the configured SendKey, if the settings so indicate. // It returns the key to use, or an error if the key is invalid given the settings. -func (a *AccessKeyConfig) CheckAndMaybeReplaceKey(apiKey string) (string, error) { - // Apply AcceptOnlyListedKeys logic BEFORE we consider replacement - if a.AcceptOnlyListedKeys && !slices.Contains(a.ReceiveKeys, apiKey) { - err := fmt.Errorf("api key %s not found in list of authorized keys", a.sanitize(apiKey)) - return "", err - } - +func (a *AccessKeyConfig) GetReplaceKey(apiKey string) (string, error) { if a.SendKey != "" { overwriteWith := "" switch a.SendKeyMode { diff --git a/config/file_config_test.go b/config/file_config_test.go index 3c1b868fd5..1ea4dda965 100644 --- a/config/file_config_test.go +++ b/config/file_config_test.go @@ -1,8 +1,14 @@ package config -import "testing" +import ( + "errors" + "testing" -func TestAccessKeyConfig_CheckAndMaybeReplaceKey(t *testing.T) { + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAccessKeyConfig_GetReplaceKey(t *testing.T) { type fields struct { ReceiveKeys []string SendKey string @@ -10,11 +16,6 @@ func TestAccessKeyConfig_CheckAndMaybeReplaceKey(t *testing.T) { AcceptOnlyListedKeys bool } - fNone := fields{} - fRcvAccept := fields{ - ReceiveKeys: []string{"key1", "key2"}, - AcceptOnlyListedKeys: true, - } fSendAll := fields{ ReceiveKeys: []string{"key1", "key2"}, SendKey: "sendkey", @@ -43,10 +44,6 @@ func TestAccessKeyConfig_CheckAndMaybeReplaceKey(t *testing.T) { want string wantErr bool }{ - {"empty", fNone, "userkey", "userkey", false}, - {"acceptonly known key", fRcvAccept, "key1", "key1", false}, - {"acceptonly unknown key", fRcvAccept, "badkey", "", true}, - {"acceptonly missing key", fRcvAccept, "", "", true}, {"send all known", fSendAll, "key1", "sendkey", false}, {"send all unknown", fSendAll, "userkey", "sendkey", false}, {"send all missing", fSendAll, "", "sendkey", false}, @@ -68,7 +65,7 @@ func TestAccessKeyConfig_CheckAndMaybeReplaceKey(t *testing.T) { SendKeyMode: tt.fields.SendKeyMode, AcceptOnlyListedKeys: tt.fields.AcceptOnlyListedKeys, } - got, err := a.CheckAndMaybeReplaceKey(tt.apiKey) + got, err := a.GetReplaceKey(tt.apiKey) if (err != nil) != tt.wantErr { t.Errorf("AccessKeyConfig.CheckAndMaybeReplaceKey() error = %v, wantErr %v", err, tt.wantErr) return @@ -79,3 +76,39 @@ func TestAccessKeyConfig_CheckAndMaybeReplaceKey(t *testing.T) { }) } } + +func TestAccessKeyConfig_IsAccepted(t *testing.T) { + type fields struct { + ReceiveKeys []string + SendKey string + SendKeyMode string + AcceptOnlyListedKeys bool + } + tests := []struct { + name string + fields fields + key string + want error + }{ + {"no keys", fields{}, "key1", nil}, + {"known key", fields{ReceiveKeys: []string{"key1"}, AcceptOnlyListedKeys: true}, "key1", nil}, + {"unknown key", fields{ReceiveKeys: []string{"key1"}, AcceptOnlyListedKeys: true}, "key2", errors.New("api key key2... not found in list of authorized keys")}, + {"accept missing key", fields{ReceiveKeys: []string{"key1"}, AcceptOnlyListedKeys: true}, "", errors.New("api key ... not found in list of authorized keys")}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &AccessKeyConfig{ + ReceiveKeys: tt.fields.ReceiveKeys, + SendKey: tt.fields.SendKey, + SendKeyMode: tt.fields.SendKeyMode, + AcceptOnlyListedKeys: tt.fields.AcceptOnlyListedKeys, + } + err := a.IsAccepted(tt.key) + if tt.want == nil { + require.NoError(t, err) + return + } + assert.Equal(t, tt.want.Error(), err.Error()) + }) + } +} diff --git a/route/errors.go b/route/errors.go index f619d311fd..71cd0f3b34 100644 --- a/route/errors.go +++ b/route/errors.go @@ -31,6 +31,7 @@ var ( ErrJSONBuildFailed = handlerError{nil, "failed to build JSON response", http.StatusInternalServerError, false, true} ErrPostBody = handlerError{nil, "failed to read request body", http.StatusInternalServerError, false, false} ErrAuthNeeded = handlerError{nil, "unknown API key - check your credentials", http.StatusBadRequest, true, true} + ErrAuthInvalid = handlerError{nil, "invalid API key - check your credentials", http.StatusUnauthorized, true, true} ErrConfigReadFailed = handlerError{nil, "failed to read config", http.StatusBadRequest, false, false} ErrUpstreamFailed = handlerError{nil, "failed to create upstream request", http.StatusServiceUnavailable, true, true} ErrUpstreamUnavailable = handlerError{nil, "upstream target unavailable", http.StatusServiceUnavailable, true, true} diff --git a/route/middleware.go b/route/middleware.go index 0c8897e6c0..aad135cbaf 100644 --- a/route/middleware.go +++ b/route/middleware.go @@ -13,7 +13,7 @@ import ( // for generating request IDs func init() { - rand.Seed(time.Now().UnixNano()) + rand.New(rand.NewSource(time.Now().UnixNano())) } func (r *Router) queryTokenChecker(next http.Handler) http.Handler { @@ -45,15 +45,11 @@ func (r *Router) apiKeyProcessor(next http.Handler) http.Handler { } keycfg := r.Config.GetAccessKeyConfig() - - overwriteWith, err := keycfg.CheckAndMaybeReplaceKey(apiKey) - if err != nil { - r.handlerReturnWithError(w, ErrAuthNeeded, err) + if err := keycfg.IsAccepted(apiKey); err != nil { + r.handlerReturnWithError(w, ErrAuthInvalid, err) return } - if overwriteWith != apiKey { - req.Header.Set(types.APIKeyHeader, overwriteWith) - } + next.ServeHTTP(w, req) }) } diff --git a/route/otlp_logs.go b/route/otlp_logs.go index 3b2c2b2889..14ff31ac88 100644 --- a/route/otlp_logs.go +++ b/route/otlp_logs.go @@ -25,8 +25,7 @@ func (r *Router) postOTLPLogs(w http.ResponseWriter, req *http.Request) { } apicfg := r.Config.GetAccessKeyConfig() - keyToUse, err := apicfg.CheckAndMaybeReplaceKey(ri.ApiKey) - + keyToUse, err := apicfg.GetReplaceKey(ri.ApiKey) if err != nil { r.handleOTLPFailureResponse(w, req, huskyotlp.OTLPError{Message: err.Error(), HTTPStatusCode: http.StatusUnauthorized}) return @@ -63,7 +62,10 @@ func (l *LogsServer) Export(ctx context.Context, req *collectorlogs.ExportLogsSe } apicfg := l.router.Config.GetAccessKeyConfig() - keyToUse, err := apicfg.CheckAndMaybeReplaceKey(ri.ApiKey) + if err := apicfg.IsAccepted(ri.ApiKey); err != nil { + return nil, status.Error(codes.Unauthenticated, err.Error()) + } + keyToUse, err := apicfg.GetReplaceKey(ri.ApiKey) if err != nil { return nil, status.Error(codes.Unauthenticated, err.Error()) diff --git a/route/otlp_logs_test.go b/route/otlp_logs_test.go index 1b01713067..d5bfa34c28 100644 --- a/route/otlp_logs_test.go +++ b/route/otlp_logs_test.go @@ -5,12 +5,14 @@ import ( "compress/gzip" "context" "fmt" + "io" "net/http" "net/http/httptest" "strings" "testing" "time" + "github.com/gorilla/mux" huskyotlp "github.com/honeycombio/husky/otlp" "github.com/honeycombio/refinery/collect" "github.com/honeycombio/refinery/config" @@ -122,16 +124,25 @@ func TestLogsOTLPHandler(t *testing.T) { for _, tC := range testCases { t.Run(tC.name, func(t *testing.T) { - request, err := http.NewRequest("POST", "/v1/traces", anEmptyRequestBody) + muxxer := mux.NewRouter() + muxxer.Use(router.apiKeyProcessor) + router.AddOTLPMuxxer(muxxer) + server := httptest.NewServer(muxxer) + defer server.Close() + + request, err := http.NewRequest("POST", server.URL+"/v1/traces", anEmptyRequestBody) require.NoError(t, err) request.Header = http.Header{} request.Header.Set("content-type", tC.requestContentType) - response := httptest.NewRecorder() - router.postOTLPTrace(response, request) - assert.Equal(t, tC.expectedResponseStatus, response.Code) - assert.Equal(t, tC.expectedResponseContentType, response.Header().Get("content-type")) - assert.Equal(t, tC.expectedResponseBody, response.Body.String()) + resp, err := http.DefaultClient.Do(request) + require.NoError(t, err) + assert.Equal(t, tC.expectedResponseStatus, resp.StatusCode) + assert.Equal(t, tC.expectedResponseContentType, resp.Header.Get("content-type")) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, tC.expectedResponseBody, string(body)) }) } }) @@ -149,15 +160,21 @@ func TestLogsOTLPHandler(t *testing.T) { t.Error(err) } - request, _ := http.NewRequest("POST", "/v1/logs", strings.NewReader(string(body))) + muxxer := mux.NewRouter() + muxxer.Use(router.apiKeyProcessor) + router.AddOTLPMuxxer(muxxer) + server := httptest.NewServer(muxxer) + defer server.Close() + + request, _ := http.NewRequest("POST", server.URL+"/v1/logs", strings.NewReader(string(body))) request.Header = http.Header{} request.Header.Set("content-type", "application/protobuf") request.Header.Set("x-honeycomb-team", legacyAPIKey) request.Header.Set("x-honeycomb-dataset", "dataset") - w := httptest.NewRecorder() - router.postOTLPLogs(w, request) - assert.Equal(t, w.Code, http.StatusOK) + resp, err := http.DefaultClient.Do(request) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) events := mockTransmission.GetBlock(1) assert.Equal(t, 1, len(events)) @@ -289,16 +306,25 @@ func TestLogsOTLPHandler(t *testing.T) { t.Error(err) } - request, _ := http.NewRequest("POST", "/v1/logs", bytes.NewReader(body)) + muxxer := mux.NewRouter() + muxxer.Use(router.apiKeyProcessor) + router.AddOTLPMuxxer(muxxer) + server := httptest.NewServer(muxxer) + defer server.Close() + + request, _ := http.NewRequest("POST", server.URL+"/v1/logs", bytes.NewReader(body)) request.Header = http.Header{} request.Header.Set("content-type", "application/json") request.Header.Set("x-honeycomb-team", legacyAPIKey) request.Header.Set("x-honeycomb-dataset", "dataset") - w := httptest.NewRecorder() - router.postOTLPLogs(w, request) - assert.Equal(t, http.StatusUnauthorized, w.Code) - assert.Contains(t, w.Body.String(), "not found in list of authorized keys") + resp, err := http.DefaultClient.Do(request) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Contains(t, string(respBody), "not found in list of authorized keys") events := mockTransmission.GetBlock(0) assert.Equal(t, 0, len(events)) diff --git a/route/otlp_trace.go b/route/otlp_trace.go index 8fcbdaae7c..f430c9622b 100644 --- a/route/otlp_trace.go +++ b/route/otlp_trace.go @@ -25,7 +25,7 @@ func (r *Router) postOTLPTrace(w http.ResponseWriter, req *http.Request) { } apicfg := r.Config.GetAccessKeyConfig() - keyToUse, err := apicfg.CheckAndMaybeReplaceKey(ri.ApiKey) + keyToUse, err := apicfg.GetReplaceKey(ri.ApiKey) if err != nil { r.handleOTLPFailureResponse(w, req, huskyotlp.OTLPError{Message: err.Error(), HTTPStatusCode: http.StatusUnauthorized}) @@ -63,7 +63,11 @@ func (t *TraceServer) Export(ctx context.Context, req *collectortrace.ExportTrac } apicfg := t.router.Config.GetAccessKeyConfig() - keyToUse, err := apicfg.CheckAndMaybeReplaceKey(ri.ApiKey) + if err := apicfg.IsAccepted(ri.ApiKey); err != nil { + return nil, status.Error(codes.Unauthenticated, err.Error()) + } + + keyToUse, err := apicfg.GetReplaceKey(ri.ApiKey) if err != nil { return nil, status.Error(codes.Unauthenticated, err.Error()) diff --git a/route/otlp_trace_test.go b/route/otlp_trace_test.go index 7f9382108c..abcea1c7e5 100644 --- a/route/otlp_trace_test.go +++ b/route/otlp_trace_test.go @@ -6,12 +6,14 @@ import ( "context" "encoding/hex" "fmt" + "io" "net/http" "net/http/httptest" "strings" "testing" "time" + "github.com/gorilla/mux" huskyotlp "github.com/honeycombio/husky/otlp" "github.com/honeycombio/refinery/config" "github.com/honeycombio/refinery/logger" @@ -140,8 +142,6 @@ func TestOTLPHandler(t *testing.T) { t.Errorf(`Unexpected error: %s`, err) } - time.Sleep(conf.GetTracesConfigVal.GetSendTickerValue() * 2) - events := mockTransmission.GetBlock(2) assert.Equal(t, 2, len(events)) @@ -188,8 +188,6 @@ func TestOTLPHandler(t *testing.T) { t.Errorf(`Unexpected error: %s`, err) } - time.Sleep(conf.GetTracesConfigVal.GetSendTickerValue() * 2) - events := mockTransmission.GetBlock(2) assert.Equal(t, 2, len(events)) @@ -240,16 +238,26 @@ func TestOTLPHandler(t *testing.T) { for _, tC := range testCases { t.Run(tC.name, func(t *testing.T) { - request, err := http.NewRequest("POST", "/v1/traces", anEmptyRequestBody) + muxxer := mux.NewRouter() + muxxer.Use(router.apiKeyProcessor) + router.AddOTLPMuxxer(muxxer) + server := httptest.NewServer(muxxer) + defer server.Close() + + request, err := http.NewRequest("POST", server.URL+"/v1/traces", anEmptyRequestBody) require.NoError(t, err) request.Header = http.Header{} request.Header.Set("content-type", tC.requestContentType) - response := httptest.NewRecorder() - router.postOTLPTrace(response, request) - assert.Equal(t, tC.expectedResponseStatus, response.Code) - assert.Equal(t, tC.expectedResponseContentType, response.Header().Get("content-type")) - assert.Equal(t, tC.expectedResponseBody, response.Body.String()) + resp, err := http.DefaultClient.Do(request) + require.NoError(t, err) + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, tC.expectedResponseStatus, resp.StatusCode) + assert.Equal(t, tC.expectedResponseContentType, resp.Header.Get("content-type")) + assert.Equal(t, tC.expectedResponseBody, string(respBody)) }) } }) @@ -469,16 +477,25 @@ func TestOTLPHandler(t *testing.T) { t.Error(err) } - request, _ := http.NewRequest("POST", "/v1/traces", bytes.NewReader(body)) + muxxer := mux.NewRouter() + muxxer.Use(router.apiKeyProcessor) + router.AddOTLPMuxxer(muxxer) + server := httptest.NewServer(muxxer) + defer server.Close() + + request, _ := http.NewRequest("POST", server.URL+"/v1/traces", bytes.NewReader(body)) request.Header = http.Header{} request.Header.Set("content-type", "application/json") request.Header.Set("x-honeycomb-team", legacyAPIKey) request.Header.Set("x-honeycomb-dataset", "dataset") - w := httptest.NewRecorder() - router.postOTLPTrace(w, request) - assert.Equal(t, http.StatusUnauthorized, w.Code) - assert.Contains(t, w.Body.String(), "not found in list of authorized keys") + resp, err := http.DefaultClient.Do(request) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Contains(t, string(respBody), "not found in list of authorized keys") events := mockTransmission.GetBlock(0) assert.Equal(t, 0, len(events)) diff --git a/route/route.go b/route/route.go index e812c6f696..d1a31f36d9 100644 --- a/route/route.go +++ b/route/route.go @@ -996,6 +996,7 @@ func (r *Router) startGRPCHealthMonitor() { func (r *Router) AddOTLPMuxxer(muxxer *mux.Router) { // require an auth header for OTLP requests otlpMuxxer := muxxer.PathPrefix("/v1/").Methods("POST").Subrouter() + otlpMuxxer.Use(r.apiKeyProcessor) // handle OTLP trace requests otlpMuxxer.HandleFunc("/traces", r.postOTLPTrace).Name("otlp_traces") diff --git a/route/route_test.go b/route/route_test.go index 1a2de4bd6d..f6ab3e6c35 100644 --- a/route/route_test.go +++ b/route/route_test.go @@ -353,6 +353,7 @@ func TestOTLPRequest(t *testing.T) { } muxxer := mux.NewRouter() + muxxer.Use(router.apiKeyProcessor) router.AddOTLPMuxxer(muxxer) server := httptest.NewServer(muxxer) defer server.Close()