Skip to content

Commit

Permalink
Merge branch 'main' into mike/instance-id
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeGoldsmith authored Nov 8, 2024
2 parents 7d4d19d + 909ef50 commit 80f1ffc
Show file tree
Hide file tree
Showing 12 changed files with 170 additions and 66 deletions.
15 changes: 15 additions & 0 deletions collect/collect.go
Original file line number Diff line number Diff line change
Expand Up @@ -1283,6 +1283,19 @@ func (i *InMemCollector) sendTraces() {
for t := range i.outgoingTraces {
i.Metrics.Histogram("collector_outgoing_queue", float64(len(i.outgoingTraces)))
_, span := otelutil.StartSpanMulti(context.Background(), i.Tracer, "sendTrace", map[string]interface{}{"num_spans": t.DescendantCount(), "outgoingTraces_size": len(i.outgoingTraces)})

// if we have a key replacement rule, we should
// replace the key with the new key
keycfg := i.Config.GetAccessKeyConfig()
overwriteWith, err := keycfg.GetReplaceKey(t.APIKey)
if err != nil {
i.Logger.Warn().Logf("error replacing key: %s", err.Error())
continue
}
if overwriteWith != t.APIKey {
t.APIKey = overwriteWith
}

for _, sp := range t.GetSpans() {
if sp.IsDecisionSpan() {
continue
Expand Down Expand Up @@ -1318,6 +1331,8 @@ func (i *InMemCollector) sendTraces() {
}
mergeTraceAndSpanSampleRates(sp, t.SampleRate(), isDryRun)
i.addAdditionalAttributes(sp)

sp.APIKey = t.APIKey
i.Transmission.EnqueueSpan(sp)
}
span.End()
Expand Down
6 changes: 6 additions & 0 deletions collect/collect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ func TestAddRootSpan(t *testing.T) {
GetCollectionConfigVal: config.CollectionConfig{
ShutdownDelay: config.Duration(1 * time.Millisecond),
},
GetAccessKeyConfigVal: config.AccessKeyConfig{
SendKey: "another-key",
SendKeyMode: "all",
},
}
transmission := &transmit.MockTransmission{}
transmission.Start()
Expand Down Expand Up @@ -166,6 +170,7 @@ func TestAddRootSpan(t *testing.T) {
events := transmission.GetBlock(1)
require.Equal(t, 1, len(events), "adding a root span should send the span")
assert.Equal(t, "aoeu", events[0].Dataset, "sending a root span should immediately send that span via transmission")
assert.Equal(t, "another-key", events[0].APIKey, "api key should be replaced with the send key")

assert.Nil(t, coll.getFromCache(traceID1), "after sending the span, it should be removed from the cache")

Expand All @@ -185,6 +190,7 @@ func TestAddRootSpan(t *testing.T) {
events = transmission.GetBlock(1)
require.Equal(t, 1, len(events), "adding another root span should send the span")
assert.Equal(t, "aoeu", events[0].Dataset, "sending a root span should immediately send that span via transmission")
assert.Equal(t, "another-key", events[0].APIKey, "api key should be replaced with the send key")

assert.Nil(t, coll.getFromCache(traceID1), "after sending the span, it should be removed from the cache")

Expand Down
24 changes: 13 additions & 11 deletions config/file_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,21 +91,23 @@ type AccessKeyConfig struct {
AcceptOnlyListedKeys bool `yaml:"AcceptOnlyListedKeys"`
}

// truncate the key to 8 characters for logging
func (a *AccessKeyConfig) sanitize(key string) string {
return fmt.Sprintf("%.8s...", key)
// IsAccepted checks if the given key is in the list of accepted keys.
// if the key is not in the list, it returns an error with the key truncated to 8 characters for logging.
func (a *AccessKeyConfig) IsAccepted(key string) error {
if a.AcceptOnlyListedKeys {
if slices.Contains(a.ReceiveKeys, key) {
return nil
}

return fmt.Errorf("api key %.8s... not found in list of authorized keys", key)
}
return nil
}

// CheckAndMaybeReplaceKey checks the given API key against the configuration
// GetReplaceKey checks the given API key against the configuration
// and possibly replaces it with the configured SendKey, if the settings so indicate.
// It returns the key to use, or an error if the key is invalid given the settings.
func (a *AccessKeyConfig) CheckAndMaybeReplaceKey(apiKey string) (string, error) {
// Apply AcceptOnlyListedKeys logic BEFORE we consider replacement
if a.AcceptOnlyListedKeys && !slices.Contains(a.ReceiveKeys, apiKey) {
err := fmt.Errorf("api key %s not found in list of authorized keys", a.sanitize(apiKey))
return "", err
}

func (a *AccessKeyConfig) GetReplaceKey(apiKey string) (string, error) {
if a.SendKey != "" {
overwriteWith := ""
switch a.SendKeyMode {
Expand Down
57 changes: 45 additions & 12 deletions config/file_config_test.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
package config

import "testing"
import (
"errors"
"testing"

func TestAccessKeyConfig_CheckAndMaybeReplaceKey(t *testing.T) {
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestAccessKeyConfig_GetReplaceKey(t *testing.T) {
type fields struct {
ReceiveKeys []string
SendKey string
SendKeyMode string
AcceptOnlyListedKeys bool
}

fNone := fields{}
fRcvAccept := fields{
ReceiveKeys: []string{"key1", "key2"},
AcceptOnlyListedKeys: true,
}
fSendAll := fields{
ReceiveKeys: []string{"key1", "key2"},
SendKey: "sendkey",
Expand Down Expand Up @@ -43,10 +44,6 @@ func TestAccessKeyConfig_CheckAndMaybeReplaceKey(t *testing.T) {
want string
wantErr bool
}{
{"empty", fNone, "userkey", "userkey", false},
{"acceptonly known key", fRcvAccept, "key1", "key1", false},
{"acceptonly unknown key", fRcvAccept, "badkey", "", true},
{"acceptonly missing key", fRcvAccept, "", "", true},
{"send all known", fSendAll, "key1", "sendkey", false},
{"send all unknown", fSendAll, "userkey", "sendkey", false},
{"send all missing", fSendAll, "", "sendkey", false},
Expand All @@ -68,7 +65,7 @@ func TestAccessKeyConfig_CheckAndMaybeReplaceKey(t *testing.T) {
SendKeyMode: tt.fields.SendKeyMode,
AcceptOnlyListedKeys: tt.fields.AcceptOnlyListedKeys,
}
got, err := a.CheckAndMaybeReplaceKey(tt.apiKey)
got, err := a.GetReplaceKey(tt.apiKey)
if (err != nil) != tt.wantErr {
t.Errorf("AccessKeyConfig.CheckAndMaybeReplaceKey() error = %v, wantErr %v", err, tt.wantErr)
return
Expand All @@ -79,3 +76,39 @@ func TestAccessKeyConfig_CheckAndMaybeReplaceKey(t *testing.T) {
})
}
}

func TestAccessKeyConfig_IsAccepted(t *testing.T) {
type fields struct {
ReceiveKeys []string
SendKey string
SendKeyMode string
AcceptOnlyListedKeys bool
}
tests := []struct {
name string
fields fields
key string
want error
}{
{"no keys", fields{}, "key1", nil},
{"known key", fields{ReceiveKeys: []string{"key1"}, AcceptOnlyListedKeys: true}, "key1", nil},
{"unknown key", fields{ReceiveKeys: []string{"key1"}, AcceptOnlyListedKeys: true}, "key2", errors.New("api key key2... not found in list of authorized keys")},
{"accept missing key", fields{ReceiveKeys: []string{"key1"}, AcceptOnlyListedKeys: true}, "", errors.New("api key ... not found in list of authorized keys")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &AccessKeyConfig{
ReceiveKeys: tt.fields.ReceiveKeys,
SendKey: tt.fields.SendKey,
SendKeyMode: tt.fields.SendKeyMode,
AcceptOnlyListedKeys: tt.fields.AcceptOnlyListedKeys,
}
err := a.IsAccepted(tt.key)
if tt.want == nil {
require.NoError(t, err)
return
}
assert.Equal(t, tt.want.Error(), err.Error())
})
}
}
1 change: 1 addition & 0 deletions route/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ var (
ErrJSONBuildFailed = handlerError{nil, "failed to build JSON response", http.StatusInternalServerError, false, true}
ErrPostBody = handlerError{nil, "failed to read request body", http.StatusInternalServerError, false, false}
ErrAuthNeeded = handlerError{nil, "unknown API key - check your credentials", http.StatusBadRequest, true, true}
ErrAuthInvalid = handlerError{nil, "invalid API key - check your credentials", http.StatusUnauthorized, true, true}
ErrConfigReadFailed = handlerError{nil, "failed to read config", http.StatusBadRequest, false, false}
ErrUpstreamFailed = handlerError{nil, "failed to create upstream request", http.StatusServiceUnavailable, true, true}
ErrUpstreamUnavailable = handlerError{nil, "upstream target unavailable", http.StatusServiceUnavailable, true, true}
Expand Down
12 changes: 4 additions & 8 deletions route/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

// for generating request IDs
func init() {
rand.Seed(time.Now().UnixNano())
rand.New(rand.NewSource(time.Now().UnixNano()))
}

func (r *Router) queryTokenChecker(next http.Handler) http.Handler {
Expand Down Expand Up @@ -45,15 +45,11 @@ func (r *Router) apiKeyProcessor(next http.Handler) http.Handler {
}

keycfg := r.Config.GetAccessKeyConfig()

overwriteWith, err := keycfg.CheckAndMaybeReplaceKey(apiKey)
if err != nil {
r.handlerReturnWithError(w, ErrAuthNeeded, err)
if err := keycfg.IsAccepted(apiKey); err != nil {
r.handlerReturnWithError(w, ErrAuthInvalid, err)
return
}
if overwriteWith != apiKey {
req.Header.Set(types.APIKeyHeader, overwriteWith)
}

next.ServeHTTP(w, req)
})
}
Expand Down
8 changes: 5 additions & 3 deletions route/otlp_logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ func (r *Router) postOTLPLogs(w http.ResponseWriter, req *http.Request) {
}

apicfg := r.Config.GetAccessKeyConfig()
keyToUse, err := apicfg.CheckAndMaybeReplaceKey(ri.ApiKey)

keyToUse, err := apicfg.GetReplaceKey(ri.ApiKey)
if err != nil {
r.handleOTLPFailureResponse(w, req, huskyotlp.OTLPError{Message: err.Error(), HTTPStatusCode: http.StatusUnauthorized})
return
Expand Down Expand Up @@ -63,7 +62,10 @@ func (l *LogsServer) Export(ctx context.Context, req *collectorlogs.ExportLogsSe
}

apicfg := l.router.Config.GetAccessKeyConfig()
keyToUse, err := apicfg.CheckAndMaybeReplaceKey(ri.ApiKey)
if err := apicfg.IsAccepted(ri.ApiKey); err != nil {
return nil, status.Error(codes.Unauthenticated, err.Error())
}
keyToUse, err := apicfg.GetReplaceKey(ri.ApiKey)

if err != nil {
return nil, status.Error(codes.Unauthenticated, err.Error())
Expand Down
56 changes: 41 additions & 15 deletions route/otlp_logs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import (
"compress/gzip"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/gorilla/mux"
huskyotlp "github.com/honeycombio/husky/otlp"
"github.com/honeycombio/refinery/collect"
"github.com/honeycombio/refinery/config"
Expand Down Expand Up @@ -122,16 +124,25 @@ func TestLogsOTLPHandler(t *testing.T) {

for _, tC := range testCases {
t.Run(tC.name, func(t *testing.T) {
request, err := http.NewRequest("POST", "/v1/traces", anEmptyRequestBody)
muxxer := mux.NewRouter()
muxxer.Use(router.apiKeyProcessor)
router.AddOTLPMuxxer(muxxer)
server := httptest.NewServer(muxxer)
defer server.Close()

request, err := http.NewRequest("POST", server.URL+"/v1/traces", anEmptyRequestBody)
require.NoError(t, err)
request.Header = http.Header{}
request.Header.Set("content-type", tC.requestContentType)
response := httptest.NewRecorder()
router.postOTLPTrace(response, request)

assert.Equal(t, tC.expectedResponseStatus, response.Code)
assert.Equal(t, tC.expectedResponseContentType, response.Header().Get("content-type"))
assert.Equal(t, tC.expectedResponseBody, response.Body.String())
resp, err := http.DefaultClient.Do(request)
require.NoError(t, err)
assert.Equal(t, tC.expectedResponseStatus, resp.StatusCode)
assert.Equal(t, tC.expectedResponseContentType, resp.Header.Get("content-type"))
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)

assert.Equal(t, tC.expectedResponseBody, string(body))
})
}
})
Expand All @@ -149,15 +160,21 @@ func TestLogsOTLPHandler(t *testing.T) {
t.Error(err)
}

request, _ := http.NewRequest("POST", "/v1/logs", strings.NewReader(string(body)))
muxxer := mux.NewRouter()
muxxer.Use(router.apiKeyProcessor)
router.AddOTLPMuxxer(muxxer)
server := httptest.NewServer(muxxer)
defer server.Close()

request, _ := http.NewRequest("POST", server.URL+"/v1/logs", strings.NewReader(string(body)))
request.Header = http.Header{}
request.Header.Set("content-type", "application/protobuf")
request.Header.Set("x-honeycomb-team", legacyAPIKey)
request.Header.Set("x-honeycomb-dataset", "dataset")

w := httptest.NewRecorder()
router.postOTLPLogs(w, request)
assert.Equal(t, w.Code, http.StatusOK)
resp, err := http.DefaultClient.Do(request)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)

events := mockTransmission.GetBlock(1)
assert.Equal(t, 1, len(events))
Expand Down Expand Up @@ -289,16 +306,25 @@ func TestLogsOTLPHandler(t *testing.T) {
t.Error(err)
}

request, _ := http.NewRequest("POST", "/v1/logs", bytes.NewReader(body))
muxxer := mux.NewRouter()
muxxer.Use(router.apiKeyProcessor)
router.AddOTLPMuxxer(muxxer)
server := httptest.NewServer(muxxer)
defer server.Close()

request, _ := http.NewRequest("POST", server.URL+"/v1/logs", bytes.NewReader(body))
request.Header = http.Header{}
request.Header.Set("content-type", "application/json")
request.Header.Set("x-honeycomb-team", legacyAPIKey)
request.Header.Set("x-honeycomb-dataset", "dataset")

w := httptest.NewRecorder()
router.postOTLPLogs(w, request)
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Contains(t, w.Body.String(), "not found in list of authorized keys")
resp, err := http.DefaultClient.Do(request)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)

assert.Contains(t, string(respBody), "not found in list of authorized keys")

events := mockTransmission.GetBlock(0)
assert.Equal(t, 0, len(events))
Expand Down
8 changes: 6 additions & 2 deletions route/otlp_trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (r *Router) postOTLPTrace(w http.ResponseWriter, req *http.Request) {
}

apicfg := r.Config.GetAccessKeyConfig()
keyToUse, err := apicfg.CheckAndMaybeReplaceKey(ri.ApiKey)
keyToUse, err := apicfg.GetReplaceKey(ri.ApiKey)

if err != nil {
r.handleOTLPFailureResponse(w, req, huskyotlp.OTLPError{Message: err.Error(), HTTPStatusCode: http.StatusUnauthorized})
Expand Down Expand Up @@ -63,7 +63,11 @@ func (t *TraceServer) Export(ctx context.Context, req *collectortrace.ExportTrac
}

apicfg := t.router.Config.GetAccessKeyConfig()
keyToUse, err := apicfg.CheckAndMaybeReplaceKey(ri.ApiKey)
if err := apicfg.IsAccepted(ri.ApiKey); err != nil {
return nil, status.Error(codes.Unauthenticated, err.Error())
}

keyToUse, err := apicfg.GetReplaceKey(ri.ApiKey)

if err != nil {
return nil, status.Error(codes.Unauthenticated, err.Error())
Expand Down
Loading

0 comments on commit 80f1ffc

Please sign in to comment.