From 04dd776e628ee32f160c651ba15d5cc69a6b96f0 Mon Sep 17 00:00:00 2001 From: Sean Marciniak Date: Wed, 3 Jul 2024 14:43:39 +0930 Subject: [PATCH] Ensure HTTP Forwarder v2 settings are set If you're only configuration gostatsd to be configured from in memory values, the existing code will override what is currently set in memory. This will create a default set of values and allow the existing viper configuration to apply its own overrides for existing values. --- pkg/statsd/handler_http_forwarder_v2.go | 61 ++++++++------ pkg/statsd/handler_http_forwarder_v2_test.go | 83 +++++++++++++------- 2 files changed, 92 insertions(+), 52 deletions(-) diff --git a/pkg/statsd/handler_http_forwarder_v2.go b/pkg/statsd/handler_http_forwarder_v2.go index 6dea16ff..f794b643 100644 --- a/pkg/statsd/handler_http_forwarder_v2.go +++ b/pkg/statsd/handler_http_forwarder_v2.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/ioutil" + "maps" "net/http" "strings" "sync" @@ -75,34 +76,48 @@ var ( _ healthcheck.DeepCheckProvider = &HttpForwarderHandlerV2{} ) +// newHTTPForwarderHandlerViperConfig defines the set of required values in order +// to create a HTTPForwarderHandlerV2 and applies the existing overrides present +// as part of the configuration passed through. +func newHTTPForwarderHandlerViperConfig(overrides *viper.Viper) *viper.Viper { + values := map[string]any{ + "transport": defaultTransport, + "compress": defaultCompress, + "compression-type": defaultCompressionType, + "compression-level": defaultCompressionLevel, + "api-endpoint": defaultApiEndpoint, + "max-requests": defaultMaxRequests, + "max-request-elapsed-time": defaultMaxRequestElapsedTime, + "consolidator-slots": gostatsd.DefaultMaxParsers, + "flush-interval": defaultConsolidatorFlushInterval, + "concurrent-merge": defaultConcurrentMerge, + } + maps.Copy(values, util.GetSubViper(overrides, "http-transport").AllSettings()) + + v := viper.New() + _ = v.MergeConfigMap(values) + + return v +} + // NewHttpForwarderHandlerV2FromViper returns a new http API client. func NewHttpForwarderHandlerV2FromViper(logger logrus.FieldLogger, v *viper.Viper, pool *transport.TransportPool, fc flush.Coordinator) (*HttpForwarderHandlerV2, error) { - subViper := util.GetSubViper(v, "http-transport") - subViper.SetDefault("transport", defaultTransport) - subViper.SetDefault("compress", defaultCompress) - subViper.SetDefault("compression-type", defaultCompressionType) - subViper.SetDefault("compression-level", defaultCompressionLevel) - subViper.SetDefault("api-endpoint", defaultApiEndpoint) - subViper.SetDefault("max-requests", defaultMaxRequests) - subViper.SetDefault("max-request-elapsed-time", defaultMaxRequestElapsedTime) - subViper.SetDefault("consolidator-slots", v.GetInt(gostatsd.ParamMaxParsers)) - subViper.SetDefault("flush-interval", defaultConsolidatorFlushInterval) - subViper.SetDefault("concurrent-merge", defaultConcurrentMerge) + values := newHTTPForwarderHandlerViperConfig(v) return NewHttpForwarderHandlerV2( logger, - subViper.GetString("transport"), - subViper.GetString("api-endpoint"), - subViper.GetInt("consolidator-slots"), - subViper.GetInt("max-requests"), - subViper.GetInt("concurrent-merge"), - subViper.GetBool("compress"), - subViper.GetString("compression-type"), - subViper.GetInt("compression-level"), - subViper.GetDuration("max-request-elapsed-time"), - subViper.GetDuration("flush-interval"), - subViper.GetStringMapString("custom-headers"), - subViper.GetStringSlice("dynamic-headers"), + values.GetString("transport"), + values.GetString("api-endpoint"), + values.GetInt("consolidator-slots"), + values.GetInt("max-requests"), + values.GetInt("concurrent-merge"), + values.GetBool("compress"), + values.GetString("compression-type"), + values.GetInt("compression-level"), + values.GetDuration("max-request-elapsed-time"), + values.GetDuration("flush-interval"), + values.GetStringMapString("custom-headers"), + values.GetStringSlice("dynamic-headers"), pool, fc, ) diff --git a/pkg/statsd/handler_http_forwarder_v2_test.go b/pkg/statsd/handler_http_forwarder_v2_test.go index aef3d43b..e69a3ad7 100644 --- a/pkg/statsd/handler_http_forwarder_v2_test.go +++ b/pkg/statsd/handler_http_forwarder_v2_test.go @@ -28,6 +28,38 @@ import ( "github.com/atlassian/gostatsd/pkg/web" ) +type testServer struct { + s *httptest.Server + called uint64 + pineappleCount uint64 + derpCount uint64 + derpValue int64 +} + +func newTestServer(tb *testing.T) *testServer { + t := &testServer{} + t.s = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddUint64(&t.called, 1) + + buf, err := io.ReadAll(r.Body) + require.NoError(tb, err, "Must not error when reading request") + + var data pb.RawMessageV2 + require.NoError(tb, proto.Unmarshal(buf, &data)) + + counters, ok := data.GetCounters()["pineapples"] + if ok { + atomic.AddUint64(&t.pineappleCount, 1) + val, ok := counters.GetTagMap()["derpinton"] + if ok { + atomic.AddUint64(&t.derpCount, 1) + t.derpValue = val.GetValue() + } + } + })) + return t +} + func TestHttpForwarderDeepCheck(t *testing.T) { t.Parallel() @@ -468,34 +500,27 @@ func TestManualFlush(t *testing.T) { assert.EqualValues(t, 10, atomic.LoadInt64(&ts.derpValue)) } -type testServer struct { - s *httptest.Server - called uint64 - pineappleCount uint64 - derpCount uint64 - derpValue int64 -} - -func newTestServer(tb *testing.T) *testServer { - t := &testServer{} - t.s = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddUint64(&t.called, 1) - - buf, err := io.ReadAll(r.Body) - require.NoError(tb, err, "Must not error when reading request") - - var data pb.RawMessageV2 - require.NoError(tb, proto.Unmarshal(buf, &data)) +func TestViperMerges(t *testing.T) { + t.Parallel() - counters, ok := data.GetCounters()["pineapples"] - if ok { - atomic.AddUint64(&t.pineappleCount, 1) - val, ok := counters.GetTagMap()["derpinton"] - if ok { - atomic.AddUint64(&t.derpCount, 1) - t.derpValue = val.GetValue() - } - } - })) - return t + overrides := viper.New() + overrides.SetDefault("http-transport.api-endpoint", "localhost") + + values := newHTTPForwarderHandlerViperConfig(overrides) + assert.Equal( + t, + map[string]any{ + "transport": defaultTransport, + "compress": defaultCompress, + "compression-type": defaultCompressionType, + "compression-level": defaultCompressionLevel, + "api-endpoint": "localhost", + "max-requests": defaultMaxRequests, + "max-request-elapsed-time": defaultMaxRequestElapsedTime, + "consolidator-slots": gostatsd.DefaultMaxParsers, + "flush-interval": defaultConsolidatorFlushInterval, + "concurrent-merge": defaultConcurrentMerge, + }, + values.AllSettings(), + ) }