From a221e9daa254deb1c68cd552db4f9a2c436adaed Mon Sep 17 00:00:00 2001 From: pulak-opti <129880418+pulak-opti@users.noreply.github.com> Date: Wed, 1 Nov 2023 22:28:24 +0600 Subject: [PATCH] fix: fix context propagation in distributed tracing (#402) --- cmd/optimizely/main.go | 7 ++-- config.yaml | 7 ++-- config/config.go | 9 +++--- pkg/middleware/trace.go | 63 +++--------------------------------- pkg/middleware/trace_test.go | 47 +-------------------------- pkg/routers/api.go | 33 +++++++++---------- pkg/routers/api_test.go | 12 +++---- 7 files changed, 38 insertions(+), 140 deletions(-) diff --git a/cmd/optimizely/main.go b/cmd/optimizely/main.go index 7586e7dd..1e64396d 100644 --- a/cmd/optimizely/main.go +++ b/cmd/optimizely/main.go @@ -34,7 +34,6 @@ import ( "github.com/optimizely/agent/config" "github.com/optimizely/agent/pkg/metrics" - "github.com/optimizely/agent/pkg/middleware" "github.com/optimizely/agent/pkg/optimizely" "github.com/optimizely/agent/pkg/routers" "github.com/optimizely/agent/pkg/server" @@ -50,6 +49,7 @@ import ( "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" "go.opentelemetry.io/otel/exporters/stdout/stdouttrace" + "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/sdk/resource" sdktrace "go.opentelemetry.io/otel/sdk/trace" semconv "go.opentelemetry.io/otel/semconv/v1.4.0" @@ -151,7 +151,6 @@ func getStdOutTraceProvider(conf config.OTELTracingConfig) (*sdktrace.TracerProv return sdktrace.NewTracerProvider( sdktrace.WithBatcher(exp), sdktrace.WithResource(res), - sdktrace.WithIDGenerator(middleware.NewTraceIDGenerator(conf.TraceIDHeaderKey)), ), nil } @@ -199,7 +198,6 @@ func getRemoteTraceProvider(conf config.OTELTracingConfig) (*sdktrace.TracerProv sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(conf.Services.Remote.SampleRate))), sdktrace.WithResource(res), sdktrace.WithSpanProcessor(bsp), - 
sdktrace.WithIDGenerator(middleware.NewTraceIDGenerator(conf.TraceIDHeaderKey)), ), nil } @@ -246,6 +244,7 @@ func main() { } }() otel.SetTracerProvider(tp) + otel.SetTextMapPropagator(propagation.TraceContext{}) log.Info().Msg(fmt.Sprintf("Tracing enabled with service %q", conf.Tracing.OpenTelemetry.Default)) } else { log.Info().Msg("Tracing disabled") @@ -275,7 +274,7 @@ func main() { cancel() }() - apiRouter := routers.NewDefaultAPIRouter(optlyCache, *conf, agentMetricsRegistry) + apiRouter := routers.NewDefaultAPIRouter(optlyCache, conf.API, agentMetricsRegistry) adminRouter := routers.NewAdminRouter(*conf) log.Info().Str("version", conf.Version).Msg("Starting services.") diff --git a/config.yaml b/config.yaml index 56f8db7c..b8b85779 100644 --- a/config.yaml +++ b/config.yaml @@ -29,6 +29,9 @@ log: ## ## tracing: tracing configuration ## +## For distributed tracing, the trace context should be sent in the "traceparent" HTTP header. +## The header value must follow the W3C trace-context format: version-traceid-parentid-traceflags (lowercase hex fields separated by "-"). +## See more at https://www.w3.org/TR/trace-context/#traceparent-header tracing: ## bydefault tracing is disabled ## to enable tracing set enabled to true @@ -43,10 +46,6 @@ tracing: ## tracing environment name ## example: for production environment env can be set as "prod" env: "dev" - ## HTTP Header Key for TraceID in Distributed Tracing - ## The value set in HTTP Header must be a hex compliant with the W3C trace-context specification. 
- ## See more at https://www.w3.org/TR/trace-context/#trace-id - traceIDHeaderKey: "X-Optimizely-Trace-ID" ## tracing service configuration services: ## stdout exporter configuration diff --git a/config/config.go b/config/config.go index b0412dc5..4155b3ef 100644 --- a/config/config.go +++ b/config/config.go @@ -197,11 +197,10 @@ const ( ) type OTELTracingConfig struct { - Default TracingServiceType `json:"default"` - ServiceName string `json:"serviceName"` - Env string `json:"env"` - TraceIDHeaderKey string `json:"traceIDHeaderKey"` - Services TracingServiceConfig `json:"services"` + Default TracingServiceType `json:"default"` + ServiceName string `json:"serviceName"` + Env string `json:"env"` + Services TracingServiceConfig `json:"services"` } type TracingServiceConfig struct { diff --git a/pkg/middleware/trace.go b/pkg/middleware/trace.go index 6caf33d3..e5270230 100644 --- a/pkg/middleware/trace.go +++ b/pkg/middleware/trace.go @@ -18,75 +18,22 @@ package middleware import ( - "context" - crand "crypto/rand" - "encoding/binary" - "math/rand" "net/http" - "sync" "github.com/go-chi/chi/v5/middleware" - "github.com/optimizely/agent/config" - "github.com/rs/zerolog/log" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/propagation" semconv "go.opentelemetry.io/otel/semconv/v1.4.0" - "go.opentelemetry.io/otel/trace" ) -type traceIDGenerator struct { - sync.Mutex - randSource *rand.Rand - traceIDHeaderKey string -} - -func NewTraceIDGenerator(traceIDHeaderKey string) *traceIDGenerator { - var rngSeed int64 - _ = binary.Read(crand.Reader, binary.LittleEndian, &rngSeed) - return &traceIDGenerator{ - randSource: rand.New(rand.NewSource(rngSeed)), - traceIDHeaderKey: traceIDHeaderKey, - } -} - -func (gen *traceIDGenerator) NewSpanID(ctx context.Context, traceID trace.TraceID) trace.SpanID { - gen.Lock() - defer gen.Unlock() - sid := trace.SpanID{} - _, _ = gen.randSource.Read(sid[:]) - return sid -} - -func (gen 
*traceIDGenerator) NewIDs(ctx context.Context) (trace.TraceID, trace.SpanID) { - gen.Lock() - defer gen.Unlock() - tid := trace.TraceID{} - _, _ = gen.randSource.Read(tid[:]) - sid := trace.SpanID{} - _, _ = gen.randSource.Read(sid[:]) - - // read trace id from header if provided - traceIDHeader := ctx.Value(gen.traceIDHeaderKey) - if val, ok := traceIDHeader.(string); ok { - if val != "" { - headerTraceId, err := trace.TraceIDFromHex(val) - if err == nil { - tid = headerTraceId - } else { - log.Error().Err(err).Msg("failed to parse trace id from header, invalid trace id") - } - } - } - - return tid, sid -} - -func AddTracing(conf config.TracingConfig, tracerName, spanName string) func(http.Handler) http.Handler { +func AddTracing(tracerName, spanName string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { - pctx := context.WithValue(r.Context(), conf.OpenTelemetry.TraceIDHeaderKey, r.Header.Get(conf.OpenTelemetry.TraceIDHeaderKey)) + prop := otel.GetTextMapPropagator() + propCtx := prop.Extract(r.Context(), propagation.HeaderCarrier(r.Header)) - ctx, span := otel.Tracer(tracerName).Start(pctx, spanName) + ctx, span := otel.Tracer(tracerName).Start(propCtx, spanName) defer span.End() span.SetAttributes( diff --git a/pkg/middleware/trace_test.go b/pkg/middleware/trace_test.go index 7c747947..91e15e6d 100644 --- a/pkg/middleware/trace_test.go +++ b/pkg/middleware/trace_test.go @@ -18,14 +18,9 @@ package middleware import ( - "context" "net/http" "net/http/httptest" "testing" - - "github.com/optimizely/agent/config" - "github.com/stretchr/testify/assert" - "go.opentelemetry.io/otel/trace" ) func TestAddTracing(t *testing.T) { @@ -37,7 +32,7 @@ func TestAddTracing(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() - middleware := http.Handler(AddTracing(config.TracingConfig{}, "test-tracer", "test-span")(handler)) + middleware := 
http.Handler(AddTracing("test-tracer", "test-span")(handler)) // Serve the request through the middleware middleware.ServeHTTP(rr, req) @@ -54,43 +49,3 @@ func TestAddTracing(t *testing.T) { t.Errorf("Expected Content-Type header %v, but got %v", "application/text", typeHeader) } } - -func TestNewIDs(t *testing.T) { - gen := NewTraceIDGenerator("") - n := 1000 - - for i := 0; i < n; i++ { - traceID, spanID := gen.NewIDs(context.Background()) - assert.Truef(t, traceID.IsValid(), "trace id: %s", traceID.String()) - assert.Truef(t, spanID.IsValid(), "span id: %s", spanID.String()) - } -} - -func TestNewSpanID(t *testing.T) { - gen := NewTraceIDGenerator("") - testTraceID := [16]byte{123, 123} - n := 1000 - - for i := 0; i < n; i++ { - spanID := gen.NewSpanID(context.Background(), testTraceID) - assert.Truef(t, spanID.IsValid(), "span id: %s", spanID.String()) - } -} - -func TestNewSpanIDWithInvalidTraceID(t *testing.T) { - gen := NewTraceIDGenerator("") - spanID := gen.NewSpanID(context.Background(), trace.TraceID{}) - assert.Truef(t, spanID.IsValid(), "span id: %s", spanID.String()) -} - -func TestTraceIDWithGivenHeaderValue(t *testing.T) { - traceHeader := "X-Trace-ID" - traceID := "9b8eac67e332c6f8baf1e013de6891bb" - - gen := NewTraceIDGenerator(traceHeader) - - ctx := context.WithValue(context.Background(), traceHeader, traceID) - genTraceID, _ := gen.NewIDs(ctx) - assert.Truef(t, genTraceID.IsValid(), "trace id: %s", genTraceID.String()) - assert.Equal(t, traceID, genTraceID.String()) -} diff --git a/pkg/routers/api.go b/pkg/routers/api.go index 88897df6..b7d2aac5 100644 --- a/pkg/routers/api.go +++ b/pkg/routers/api.go @@ -63,8 +63,7 @@ func forbiddenHandler(message string) http.HandlerFunc { } // NewDefaultAPIRouter creates a new router with the default backing optimizely.Cache -func NewDefaultAPIRouter(optlyCache optimizely.Cache, agentConf config.AgentConfig, metricsRegistry *metrics.Registry) http.Handler { - conf := agentConf.API +func 
NewDefaultAPIRouter(optlyCache optimizely.Cache, conf config.APIConfig, metricsRegistry *metrics.Registry) http.Handler { authProvider := middleware.NewAuth(&conf.Auth) if authProvider == nil { log.Error().Msg("unable to initialize api auth middleware.") @@ -109,19 +108,19 @@ func NewDefaultAPIRouter(optlyCache optimizely.Cache, agentConf config.AgentConf corsHandler: corsHandler, } - return NewAPIRouter(spec, agentConf.Tracing) + return NewAPIRouter(spec) } // NewAPIRouter returns HTTP API router backed by an optimizely.Cache implementation -func NewAPIRouter(opt *APIOptions, traceConf config.TracingConfig) *chi.Mux { +func NewAPIRouter(opt *APIOptions) *chi.Mux { r := chi.NewRouter() - WithAPIRouter(opt, r, traceConf) + WithAPIRouter(opt, r) return r } // WithAPIRouter appends routes and middleware to the given router. // See https://godoc.org/github.com/go-chi/chi/v5#Mux.Group for usage -func WithAPIRouter(opt *APIOptions, r chi.Router, traceConf config.TracingConfig) { +func WithAPIRouter(opt *APIOptions, r chi.Router) { getConfigTimer := middleware.Metricize("get-config", opt.metricsRegistry) getDatafileTimer := middleware.Metricize("get-datafile", opt.metricsRegistry) activateTimer := middleware.Metricize("activate", opt.metricsRegistry) @@ -134,17 +133,17 @@ func WithAPIRouter(opt *APIOptions, r chi.Router, traceConf config.TracingConfig createAccesstokenTimer := middleware.Metricize("create-api-access-token", opt.metricsRegistry) contentTypeMiddleware := chimw.AllowContentType("application/json") - configTracer := middleware.AddTracing(traceConf, "configHandler", "OptimizelyConfig") - datafileTracer := middleware.AddTracing(traceConf, "datafileHandler", "OptimizelyDatafile") - activateTracer := middleware.AddTracing(traceConf, "activateHandler", "Activate") - decideTracer := middleware.AddTracing(traceConf, "decideHandler", "Decide") - trackTracer := middleware.AddTracing(traceConf, "trackHandler", "Track") - overrideTracer := 
middleware.AddTracing(traceConf, "overrideHandler", "Override") - lookupTracer := middleware.AddTracing(traceConf, "lookupHandler", "Lookup") - saveTracer := middleware.AddTracing(traceConf, "saveHandler", "Save") - sendOdpEventTracer := middleware.AddTracing(traceConf, "sendOdpEventHandler", "SendOdpEvent") - nStreamTracer := middleware.AddTracing(traceConf, "notificationHandler", "SendNotificationEvent") - authTracer := middleware.AddTracing(traceConf, "authHandler", "AuthToken") + configTracer := middleware.AddTracing("configHandler", "OptimizelyConfig") + datafileTracer := middleware.AddTracing("datafileHandler", "OptimizelyDatafile") + activateTracer := middleware.AddTracing("activateHandler", "Activate") + decideTracer := middleware.AddTracing("decideHandler", "Decide") + trackTracer := middleware.AddTracing("trackHandler", "Track") + overrideTracer := middleware.AddTracing("overrideHandler", "Override") + lookupTracer := middleware.AddTracing("lookupHandler", "Lookup") + saveTracer := middleware.AddTracing("saveHandler", "Save") + sendOdpEventTracer := middleware.AddTracing("sendOdpEventHandler", "SendOdpEvent") + nStreamTracer := middleware.AddTracing("notificationHandler", "SendNotificationEvent") + authTracer := middleware.AddTracing("authHandler", "AuthToken") if opt.maxConns > 0 { // Note this is NOT a rate limiter, but a concurrency threshold diff --git a/pkg/routers/api_test.go b/pkg/routers/api_test.go index 9196d880..bc7776e9 100644 --- a/pkg/routers/api_test.go +++ b/pkg/routers/api_test.go @@ -126,7 +126,7 @@ func (suite *APIV1TestSuite) SetupTest() { corsHandler: testCorsHandler, } - suite.mux = NewAPIRouter(opts, config.TracingConfig{}) + suite.mux = NewAPIRouter(opts) } func (suite *APIV1TestSuite) TestValidRoutes() { @@ -138,7 +138,7 @@ func (suite *APIV1TestSuite) TestValidRoutes() { } return http.HandlerFunc(fn) } - suite.mux = NewAPIRouter(opts, config.TracingConfig{}) + suite.mux = NewAPIRouter(opts) routes := []struct { method string @@ 
-328,7 +328,7 @@ func TestAPIV1TestSuite(t *testing.T) { } func TestNewDefaultAPIV1Router(t *testing.T) { - client := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{}, metricsRegistry) + client := NewDefaultAPIRouter(MockCache{}, config.APIConfig{}, metricsRegistry) assert.NotNil(t, client) } @@ -353,7 +353,7 @@ func TestNewDefaultAPIV1RouterInvalidHandlerConfig(t *testing.T) { EnableNotifications: false, EnableOverrides: false, } - client := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{API: invalidAPIConfig}, metricsRegistry) + client := NewDefaultAPIRouter(MockCache{}, invalidAPIConfig, metricsRegistry) assert.Nil(t, client) } @@ -368,12 +368,12 @@ func TestNewDefaultClientRouterInvalidMiddlewareConfig(t *testing.T) { EnableNotifications: false, EnableOverrides: false, } - client := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{API: invalidAPIConfig}, metricsRegistry) + client := NewDefaultAPIRouter(MockCache{}, invalidAPIConfig, metricsRegistry) assert.Nil(t, client) } func TestForbiddenRoutes(t *testing.T) { - mux := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{}, metricsRegistry) + mux := NewDefaultAPIRouter(MockCache{}, config.APIConfig{}, metricsRegistry) routes := []struct { method string