Skip to content

Commit

Permalink
fix: fix context propagation in distributed tracing (#402)
Browse files Browse the repository at this point in the history
  • Loading branch information
pulak-opti committed Nov 2, 2023
1 parent ddb2e15 commit a221e9d
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 140 deletions.
7 changes: 3 additions & 4 deletions cmd/optimizely/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (

"github.com/optimizely/agent/config"
"github.com/optimizely/agent/pkg/metrics"
"github.com/optimizely/agent/pkg/middleware"
"github.com/optimizely/agent/pkg/optimizely"
"github.com/optimizely/agent/pkg/routers"
"github.com/optimizely/agent/pkg/server"
Expand All @@ -50,6 +49,7 @@ import (
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
Expand Down Expand Up @@ -151,7 +151,6 @@ func getStdOutTraceProvider(conf config.OTELTracingConfig) (*sdktrace.TracerProv
return sdktrace.NewTracerProvider(
sdktrace.WithBatcher(exp),
sdktrace.WithResource(res),
sdktrace.WithIDGenerator(middleware.NewTraceIDGenerator(conf.TraceIDHeaderKey)),
), nil
}

Expand Down Expand Up @@ -199,7 +198,6 @@ func getRemoteTraceProvider(conf config.OTELTracingConfig) (*sdktrace.TracerProv
sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(conf.Services.Remote.SampleRate))),
sdktrace.WithResource(res),
sdktrace.WithSpanProcessor(bsp),
sdktrace.WithIDGenerator(middleware.NewTraceIDGenerator(conf.TraceIDHeaderKey)),
), nil
}

Expand Down Expand Up @@ -246,6 +244,7 @@ func main() {
}
}()
otel.SetTracerProvider(tp)
otel.SetTextMapPropagator(propagation.TraceContext{})
log.Info().Msg(fmt.Sprintf("Tracing enabled with service %q", conf.Tracing.OpenTelemetry.Default))
} else {
log.Info().Msg("Tracing disabled")
Expand Down Expand Up @@ -275,7 +274,7 @@ func main() {
cancel()
}()

apiRouter := routers.NewDefaultAPIRouter(optlyCache, *conf, agentMetricsRegistry)
apiRouter := routers.NewDefaultAPIRouter(optlyCache, conf.API, agentMetricsRegistry)
adminRouter := routers.NewAdminRouter(*conf)

log.Info().Str("version", conf.Version).Msg("Starting services.")
Expand Down
7 changes: 3 additions & 4 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ log:
##
## tracing: tracing configuration
##
## For distributed tracing, trace context should be sent on "traceparent" header
## The value set in HTTP Header must be a hex compliant with the W3C trace-context specification.
## See more at https://www.w3.org/TR/trace-context/#trace-id
tracing:
## bydefault tracing is disabled
## to enable tracing set enabled to true
Expand All @@ -43,10 +46,6 @@ tracing:
## tracing environment name
## example: for production environment env can be set as "prod"
env: "dev"
## HTTP Header Key for TraceID in Distributed Tracing
## The value set in HTTP Header must be a hex compliant with the W3C trace-context specification.
## See more at https://www.w3.org/TR/trace-context/#trace-id
traceIDHeaderKey: "X-Optimizely-Trace-ID"
## tracing service configuration
services:
## stdout exporter configuration
Expand Down
9 changes: 4 additions & 5 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,10 @@ const (
)

type OTELTracingConfig struct {
Default TracingServiceType `json:"default"`
ServiceName string `json:"serviceName"`
Env string `json:"env"`
TraceIDHeaderKey string `json:"traceIDHeaderKey"`
Services TracingServiceConfig `json:"services"`
Default TracingServiceType `json:"default"`
ServiceName string `json:"serviceName"`
Env string `json:"env"`
Services TracingServiceConfig `json:"services"`
}

type TracingServiceConfig struct {
Expand Down
63 changes: 5 additions & 58 deletions pkg/middleware/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,75 +18,22 @@
package middleware

import (
"context"
crand "crypto/rand"
"encoding/binary"
"math/rand"
"net/http"
"sync"

"github.com/go-chi/chi/v5/middleware"
"github.com/optimizely/agent/config"
"github.com/rs/zerolog/log"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/propagation"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
"go.opentelemetry.io/otel/trace"
)

type traceIDGenerator struct {
sync.Mutex
randSource *rand.Rand
traceIDHeaderKey string
}

func NewTraceIDGenerator(traceIDHeaderKey string) *traceIDGenerator {
var rngSeed int64
_ = binary.Read(crand.Reader, binary.LittleEndian, &rngSeed)
return &traceIDGenerator{
randSource: rand.New(rand.NewSource(rngSeed)),
traceIDHeaderKey: traceIDHeaderKey,
}
}

func (gen *traceIDGenerator) NewSpanID(ctx context.Context, traceID trace.TraceID) trace.SpanID {
gen.Lock()
defer gen.Unlock()
sid := trace.SpanID{}
_, _ = gen.randSource.Read(sid[:])
return sid
}

func (gen *traceIDGenerator) NewIDs(ctx context.Context) (trace.TraceID, trace.SpanID) {
gen.Lock()
defer gen.Unlock()
tid := trace.TraceID{}
_, _ = gen.randSource.Read(tid[:])
sid := trace.SpanID{}
_, _ = gen.randSource.Read(sid[:])

// read trace id from header if provided
traceIDHeader := ctx.Value(gen.traceIDHeaderKey)
if val, ok := traceIDHeader.(string); ok {
if val != "" {
headerTraceId, err := trace.TraceIDFromHex(val)
if err == nil {
tid = headerTraceId
} else {
log.Error().Err(err).Msg("failed to parse trace id from header, invalid trace id")
}
}
}

return tid, sid
}

func AddTracing(conf config.TracingConfig, tracerName, spanName string) func(http.Handler) http.Handler {
func AddTracing(tracerName, spanName string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
pctx := context.WithValue(r.Context(), conf.OpenTelemetry.TraceIDHeaderKey, r.Header.Get(conf.OpenTelemetry.TraceIDHeaderKey))
prop := otel.GetTextMapPropagator()
propCtx := prop.Extract(r.Context(), propagation.HeaderCarrier(r.Header))

ctx, span := otel.Tracer(tracerName).Start(pctx, spanName)
ctx, span := otel.Tracer(tracerName).Start(propCtx, spanName)
defer span.End()

span.SetAttributes(
Expand Down
47 changes: 1 addition & 46 deletions pkg/middleware/trace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,9 @@
package middleware

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/optimizely/agent/config"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel/trace"
)

func TestAddTracing(t *testing.T) {
Expand All @@ -37,7 +32,7 @@ func TestAddTracing(t *testing.T) {

req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
middleware := http.Handler(AddTracing(config.TracingConfig{}, "test-tracer", "test-span")(handler))
middleware := http.Handler(AddTracing("test-tracer", "test-span")(handler))

// Serve the request through the middleware
middleware.ServeHTTP(rr, req)
Expand All @@ -54,43 +49,3 @@ func TestAddTracing(t *testing.T) {
t.Errorf("Expected Content-Type header %v, but got %v", "application/text", typeHeader)
}
}

func TestNewIDs(t *testing.T) {
gen := NewTraceIDGenerator("")
n := 1000

for i := 0; i < n; i++ {
traceID, spanID := gen.NewIDs(context.Background())
assert.Truef(t, traceID.IsValid(), "trace id: %s", traceID.String())
assert.Truef(t, spanID.IsValid(), "span id: %s", spanID.String())
}
}

func TestNewSpanID(t *testing.T) {
gen := NewTraceIDGenerator("")
testTraceID := [16]byte{123, 123}
n := 1000

for i := 0; i < n; i++ {
spanID := gen.NewSpanID(context.Background(), testTraceID)
assert.Truef(t, spanID.IsValid(), "span id: %s", spanID.String())
}
}

func TestNewSpanIDWithInvalidTraceID(t *testing.T) {
gen := NewTraceIDGenerator("")
spanID := gen.NewSpanID(context.Background(), trace.TraceID{})
assert.Truef(t, spanID.IsValid(), "span id: %s", spanID.String())
}

func TestTraceIDWithGivenHeaderValue(t *testing.T) {
traceHeader := "X-Trace-ID"
traceID := "9b8eac67e332c6f8baf1e013de6891bb"

gen := NewTraceIDGenerator(traceHeader)

ctx := context.WithValue(context.Background(), traceHeader, traceID)
genTraceID, _ := gen.NewIDs(ctx)
assert.Truef(t, genTraceID.IsValid(), "trace id: %s", genTraceID.String())
assert.Equal(t, traceID, genTraceID.String())
}
33 changes: 16 additions & 17 deletions pkg/routers/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ func forbiddenHandler(message string) http.HandlerFunc {
}

// NewDefaultAPIRouter creates a new router with the default backing optimizely.Cache
func NewDefaultAPIRouter(optlyCache optimizely.Cache, agentConf config.AgentConfig, metricsRegistry *metrics.Registry) http.Handler {
conf := agentConf.API
func NewDefaultAPIRouter(optlyCache optimizely.Cache, conf config.APIConfig, metricsRegistry *metrics.Registry) http.Handler {
authProvider := middleware.NewAuth(&conf.Auth)
if authProvider == nil {
log.Error().Msg("unable to initialize api auth middleware.")
Expand Down Expand Up @@ -109,19 +108,19 @@ func NewDefaultAPIRouter(optlyCache optimizely.Cache, agentConf config.AgentConf
corsHandler: corsHandler,
}

return NewAPIRouter(spec, agentConf.Tracing)
return NewAPIRouter(spec)
}

// NewAPIRouter returns HTTP API router backed by an optimizely.Cache implementation
func NewAPIRouter(opt *APIOptions, traceConf config.TracingConfig) *chi.Mux {
func NewAPIRouter(opt *APIOptions) *chi.Mux {
r := chi.NewRouter()
WithAPIRouter(opt, r, traceConf)
WithAPIRouter(opt, r)
return r
}

// WithAPIRouter appends routes and middleware to the given router.
// See https://godoc.org/github.com/go-chi/chi/v5#Mux.Group for usage
func WithAPIRouter(opt *APIOptions, r chi.Router, traceConf config.TracingConfig) {
func WithAPIRouter(opt *APIOptions, r chi.Router) {
getConfigTimer := middleware.Metricize("get-config", opt.metricsRegistry)
getDatafileTimer := middleware.Metricize("get-datafile", opt.metricsRegistry)
activateTimer := middleware.Metricize("activate", opt.metricsRegistry)
Expand All @@ -134,17 +133,17 @@ func WithAPIRouter(opt *APIOptions, r chi.Router, traceConf config.TracingConfig
createAccesstokenTimer := middleware.Metricize("create-api-access-token", opt.metricsRegistry)
contentTypeMiddleware := chimw.AllowContentType("application/json")

configTracer := middleware.AddTracing(traceConf, "configHandler", "OptimizelyConfig")
datafileTracer := middleware.AddTracing(traceConf, "datafileHandler", "OptimizelyDatafile")
activateTracer := middleware.AddTracing(traceConf, "activateHandler", "Activate")
decideTracer := middleware.AddTracing(traceConf, "decideHandler", "Decide")
trackTracer := middleware.AddTracing(traceConf, "trackHandler", "Track")
overrideTracer := middleware.AddTracing(traceConf, "overrideHandler", "Override")
lookupTracer := middleware.AddTracing(traceConf, "lookupHandler", "Lookup")
saveTracer := middleware.AddTracing(traceConf, "saveHandler", "Save")
sendOdpEventTracer := middleware.AddTracing(traceConf, "sendOdpEventHandler", "SendOdpEvent")
nStreamTracer := middleware.AddTracing(traceConf, "notificationHandler", "SendNotificationEvent")
authTracer := middleware.AddTracing(traceConf, "authHandler", "AuthToken")
configTracer := middleware.AddTracing("configHandler", "OptimizelyConfig")
datafileTracer := middleware.AddTracing("datafileHandler", "OptimizelyDatafile")
activateTracer := middleware.AddTracing("activateHandler", "Activate")
decideTracer := middleware.AddTracing("decideHandler", "Decide")
trackTracer := middleware.AddTracing("trackHandler", "Track")
overrideTracer := middleware.AddTracing("overrideHandler", "Override")
lookupTracer := middleware.AddTracing("lookupHandler", "Lookup")
saveTracer := middleware.AddTracing("saveHandler", "Save")
sendOdpEventTracer := middleware.AddTracing("sendOdpEventHandler", "SendOdpEvent")
nStreamTracer := middleware.AddTracing("notificationHandler", "SendNotificationEvent")
authTracer := middleware.AddTracing("authHandler", "AuthToken")

if opt.maxConns > 0 {
// Note this is NOT a rate limiter, but a concurrency threshold
Expand Down
12 changes: 6 additions & 6 deletions pkg/routers/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (suite *APIV1TestSuite) SetupTest() {
corsHandler: testCorsHandler,
}

suite.mux = NewAPIRouter(opts, config.TracingConfig{})
suite.mux = NewAPIRouter(opts)
}

func (suite *APIV1TestSuite) TestValidRoutes() {
Expand All @@ -138,7 +138,7 @@ func (suite *APIV1TestSuite) TestValidRoutes() {
}
return http.HandlerFunc(fn)
}
suite.mux = NewAPIRouter(opts, config.TracingConfig{})
suite.mux = NewAPIRouter(opts)

routes := []struct {
method string
Expand Down Expand Up @@ -328,7 +328,7 @@ func TestAPIV1TestSuite(t *testing.T) {
}

func TestNewDefaultAPIV1Router(t *testing.T) {
client := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{}, metricsRegistry)
client := NewDefaultAPIRouter(MockCache{}, config.APIConfig{}, metricsRegistry)
assert.NotNil(t, client)
}

Expand All @@ -353,7 +353,7 @@ func TestNewDefaultAPIV1RouterInvalidHandlerConfig(t *testing.T) {
EnableNotifications: false,
EnableOverrides: false,
}
client := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{API: invalidAPIConfig}, metricsRegistry)
client := NewDefaultAPIRouter(MockCache{}, invalidAPIConfig, metricsRegistry)
assert.Nil(t, client)
}

Expand All @@ -368,12 +368,12 @@ func TestNewDefaultClientRouterInvalidMiddlewareConfig(t *testing.T) {
EnableNotifications: false,
EnableOverrides: false,
}
client := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{API: invalidAPIConfig}, metricsRegistry)
client := NewDefaultAPIRouter(MockCache{}, invalidAPIConfig, metricsRegistry)
assert.Nil(t, client)
}

func TestForbiddenRoutes(t *testing.T) {
mux := NewDefaultAPIRouter(MockCache{}, config.AgentConfig{}, metricsRegistry)
mux := NewDefaultAPIRouter(MockCache{}, config.APIConfig{}, metricsRegistry)

routes := []struct {
method string
Expand Down

0 comments on commit a221e9d

Please sign in to comment.