diff --git a/service/sharddistributor/client/spectatorclient/peer_chooser.go b/service/sharddistributor/client/spectatorclient/peer_chooser.go new file mode 100644 index 00000000000..a46a60ac8cf --- /dev/null +++ b/service/sharddistributor/client/spectatorclient/peer_chooser.go @@ -0,0 +1,178 @@ +package spectatorclient + +import ( + "context" + "fmt" + "sync" + + "go.uber.org/fx" + "go.uber.org/yarpc/api/peer" + "go.uber.org/yarpc/api/transport" + "go.uber.org/yarpc/peer/hostport" + "go.uber.org/yarpc/yarpcerrors" + + "github.com/uber/cadence/common/log" + "github.com/uber/cadence/common/log/tag" +) + +const ( + NamespaceHeader = "x-shard-distributor-namespace" + grpcAddressMetadataKey = "grpc_address" +) + +// SpectatorPeerChooserInterface extends peer.Chooser with SetSpectators method +type SpectatorPeerChooserInterface interface { + peer.Chooser + SetSpectators(spectators *Spectators) +} + +// SpectatorPeerChooser is a peer.Chooser that uses the Spectator to route requests +// to the correct executor based on shard ownership. +// This is the shard distributor equivalent of Cadence's RingpopPeerChooser. +// +// Flow: +// 1. Client calls RPC with yarpc.WithShardKey("shard-key") +// 2. Choose() is called with req.ShardKey = "shard-key" +// 3. Query Spectator for shard owner +// 4. Extract grpc_address from owner metadata +// 5. Create/reuse peer for that address +// 6. Return peer to YARPC for connection +type SpectatorPeerChooser struct { + spectators *Spectators + transport peer.Transport + logger log.Logger + namespace string + + peersMutex sync.RWMutex + peers map[string]peer.Peer // grpc_address -> peer +} + +type SpectatorPeerChooserParams struct { + fx.In + Transport peer.Transport + Logger log.Logger +} + +// NewSpectatorPeerChooser creates a new peer chooser that routes based on shard distributor ownership +func NewSpectatorPeerChooser( + params SpectatorPeerChooserParams, +) SpectatorPeerChooserInterface { + return &SpectatorPeerChooser{ + transport: params.Transport, + logger: params.Logger, + peers: make(map[string]peer.Peer), + } +} + +// Start satisfies the peer.Chooser interface +func (c *SpectatorPeerChooser) Start() error { + c.logger.Info("Starting shard distributor peer chooser", tag.ShardNamespace(c.namespace)) + return nil +} + +// Stop satisfies the peer.Chooser interface +func (c *SpectatorPeerChooser) Stop() error { + c.logger.Info("Stopping shard distributor peer chooser", tag.ShardNamespace(c.namespace)) + + // Release all peers + c.peersMutex.Lock() + defer c.peersMutex.Unlock() + + for addr, p := range c.peers { + if err := c.transport.ReleasePeer(p, &noOpSubscriber{}); err != nil { + c.logger.Error("Failed to release peer", tag.Error(err), tag.Address(addr)) + } + } + c.peers = make(map[string]peer.Peer) + + return nil +} + +// IsRunning satisfies the peer.Chooser interface +func (c *SpectatorPeerChooser) IsRunning() bool { + return true +} + +// Choose returns a peer for the given shard key by: +// 0. Looking up the spectator for the namespace using the x-shard-distributor-namespace header +// 1. Looking up the shard owner via the Spectator +// 2. Extracting the grpc_address from the owner's metadata +// 3. Creating/reusing a peer for that address +// +// The ShardKey in the request is the shard key (e.g., shard ID) +// The function returns +// peer: the peer to use for the request +// onFinish: a function to call when the request is finished (currently no-op) +// err: the error if the request failed +func (c *SpectatorPeerChooser) Choose(ctx context.Context, req *transport.Request) (peer peer.Peer, onFinish func(error), err error) { + if req.ShardKey == "" { + return nil, nil, yarpcerrors.InvalidArgumentErrorf("chooser requires ShardKey to be non-empty") + } + + // Get the spectator for the namespace + namespace, ok := req.Headers.Get(NamespaceHeader) + if !ok || namespace == "" { + return nil, nil, yarpcerrors.InvalidArgumentErrorf("chooser requires x-shard-distributor-namespace header to be non-empty") + } + + spectator, err := c.spectators.ForNamespace(namespace) + if err != nil { + return nil, nil, yarpcerrors.InvalidArgumentErrorf("get spectator for namespace %s: %w", namespace, err) + } + + // Query spectator for shard owner + owner, err := spectator.GetShardOwner(ctx, req.ShardKey) + if err != nil { + return nil, nil, yarpcerrors.UnavailableErrorf("get shard owner for key %s: %v", req.ShardKey, err) + } + + // Extract GRPC address from owner metadata + grpcAddress, ok := owner.Metadata[grpcAddressMetadataKey] + if !ok || grpcAddress == "" { + return nil, nil, yarpcerrors.InternalErrorf("no grpc_address in metadata for executor %s owning shard %s", owner.ExecutorID, req.ShardKey) + } + + // Get peer for this address + peer, err = c.getOrCreatePeer(grpcAddress) + if err != nil { + return nil, nil, yarpcerrors.InternalErrorf("get or create peer for address %s: %v", grpcAddress, err) + } + + return peer, func(error) {}, nil +} + +func (c *SpectatorPeerChooser) SetSpectators(spectators *Spectators) { + c.spectators = spectators +} + +func (c *SpectatorPeerChooser) getOrCreatePeer(grpcAddress string) (peer.Peer, error) { + c.peersMutex.RLock() + peer, ok := c.peers[grpcAddress] + c.peersMutex.RUnlock() + + if ok { + return peer, nil + } + + // Create new peer for this address + c.peersMutex.Lock() + defer c.peersMutex.Unlock() + + // Check again in case another goroutine added it + if peer, ok := c.peers[grpcAddress]; ok { + return peer, nil + } + + peer, err := c.transport.RetainPeer(hostport.Identify(grpcAddress), &noOpSubscriber{}) + if err != nil { + return nil, fmt.Errorf("retain peer: %w", err) + } + + c.peers[grpcAddress] = peer + return peer, nil +} + +// noOpSubscriber is a no-op implementation of peer.Subscriber +type noOpSubscriber struct{} + +func (*noOpSubscriber) NotifyStatusChanged(peer.Identifier) {} diff --git a/service/sharddistributor/client/spectatorclient/peer_chooser_test.go b/service/sharddistributor/client/spectatorclient/peer_chooser_test.go new file mode 100644 index 00000000000..569c354d01f --- /dev/null +++ b/service/sharddistributor/client/spectatorclient/peer_chooser_test.go @@ -0,0 +1,189 @@ +package spectatorclient + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "go.uber.org/yarpc/api/peer" + "go.uber.org/yarpc/api/transport" + "go.uber.org/yarpc/transport/grpc" + + "github.com/uber/cadence/common/log/testlogger" +) + +func TestSpectatorPeerChooser_Choose_MissingShardKey(t *testing.T) { + chooser := &SpectatorPeerChooser{ + logger: testlogger.New(t), + peers: make(map[string]peer.Peer), + } + + req := &transport.Request{ + ShardKey: "", + Headers: transport.NewHeaders(), + } + + p, onFinish, err := chooser.Choose(context.Background(), req) + + assert.Error(t, err) + assert.Nil(t, p) + assert.Nil(t, onFinish) + assert.Contains(t, err.Error(), "ShardKey") +} + +func TestSpectatorPeerChooser_Choose_MissingNamespaceHeader(t *testing.T) { + chooser := &SpectatorPeerChooser{ + logger: testlogger.New(t), + peers: make(map[string]peer.Peer), + } + + req := &transport.Request{ + ShardKey: "shard-1", + Headers: transport.NewHeaders(), + } + + p, onFinish, err := chooser.Choose(context.Background(), req) + + assert.Error(t, err) + assert.Nil(t, p) + assert.Nil(t, onFinish) + assert.Contains(t, err.Error(), "x-shard-distributor-namespace") +} + +func TestSpectatorPeerChooser_Choose_SpectatorNotFound(t *testing.T) { + chooser := &SpectatorPeerChooser{ + logger: testlogger.New(t), + peers: make(map[string]peer.Peer), + spectators: &Spectators{spectators: make(map[string]Spectator)}, + } + + req := &transport.Request{ + ShardKey: "shard-1", + Headers: transport.NewHeaders().With(NamespaceHeader, "unknown-namespace"), + } + + p, onFinish, err := chooser.Choose(context.Background(), req) + + assert.Error(t, err) + assert.Nil(t, p) + assert.Nil(t, onFinish) + assert.Contains(t, err.Error(), "spectator not found") +} + +func TestSpectatorPeerChooser_StartStop(t *testing.T) { + chooser := &SpectatorPeerChooser{ + logger: testlogger.New(t), + peers: make(map[string]peer.Peer), + } + + err := chooser.Start() + require.NoError(t, err) + + assert.True(t, chooser.IsRunning()) + + err = chooser.Stop() + assert.NoError(t, err) +} + +func TestSpectatorPeerChooser_SetSpectators(t *testing.T) { + chooser := &SpectatorPeerChooser{ + logger: testlogger.New(t), + } + + spectators := &Spectators{spectators: make(map[string]Spectator)} + chooser.SetSpectators(spectators) + + assert.Equal(t, spectators, chooser.spectators) +} + +func TestSpectatorPeerChooser_Choose_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockSpectator := NewMockSpectator(ctrl) + peerTransport := grpc.NewTransport() + + chooser := &SpectatorPeerChooser{ + transport: peerTransport, + logger: testlogger.New(t), + peers: make(map[string]peer.Peer), + spectators: &Spectators{ + spectators: map[string]Spectator{ + "test-namespace": mockSpectator, + }, + }, + } + + ctx := context.Background() + req := &transport.Request{ + ShardKey: "shard-1", + Headers: transport.NewHeaders().With(NamespaceHeader, "test-namespace"), + } + + // Mock spectator to return shard owner with grpc_address + mockSpectator.EXPECT(). + GetShardOwner(ctx, "shard-1"). + Return(&ShardOwner{ + ExecutorID: "executor-1", + Metadata: map[string]string{ + grpcAddressMetadataKey: "127.0.0.1:7953", + }, + }, nil) + + // Execute + p, onFinish, err := chooser.Choose(ctx, req) + + // Assert + assert.NoError(t, err) + assert.NotNil(t, p) + assert.NotNil(t, onFinish) + assert.Equal(t, "127.0.0.1:7953", p.Identifier()) + assert.Len(t, chooser.peers, 1) +} + +func TestSpectatorPeerChooser_Choose_ReusesPeer(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockSpectator := NewMockSpectator(ctrl) + peerTransport := grpc.NewTransport() + + chooser := &SpectatorPeerChooser{ + transport: peerTransport, + logger: testlogger.New(t), + peers: make(map[string]peer.Peer), + spectators: &Spectators{ + spectators: map[string]Spectator{ + "test-namespace": mockSpectator, + }, + }, + } + + req := &transport.Request{ + ShardKey: "shard-1", + Headers: transport.NewHeaders().With(NamespaceHeader, "test-namespace"), + } + + // First call creates the peer + mockSpectator.EXPECT(). + GetShardOwner(gomock.Any(), "shard-1"). + Return(&ShardOwner{ + ExecutorID: "executor-1", + Metadata: map[string]string{ + grpcAddressMetadataKey: "127.0.0.1:7953", + }, + }, nil).Times(2) + + firstPeer, _, err := chooser.Choose(context.Background(), req) + require.NoError(t, err) + + // Second call should reuse the same peer + secondPeer, _, err := chooser.Choose(context.Background(), req) + + // Assert - should reuse existing peer + assert.NoError(t, err) + assert.Equal(t, firstPeer, secondPeer) + assert.Len(t, chooser.peers, 1) +}