diff --git a/balancer/rls/control_channel.go b/balancer/rls/control_channel.go
index 60e6a021d133..e8e0c980525a 100644
--- a/balancer/rls/control_channel.go
+++ b/balancer/rls/control_channel.go
@@ -21,6 +21,7 @@ package rls
 import (
 	"context"
 	"fmt"
+	"sync"
 	"time"
 
 	"google.golang.org/grpc"
@@ -29,7 +30,6 @@ import (
 	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/credentials/insecure"
 	"google.golang.org/grpc/internal"
-	"google.golang.org/grpc/internal/buffer"
 	internalgrpclog "google.golang.org/grpc/internal/grpclog"
 	"google.golang.org/grpc/internal/grpcsync"
 	"google.golang.org/grpc/internal/pretty"
@@ -44,6 +44,16 @@ type adaptiveThrottler interface {
 	RegisterBackendResponse(throttled bool)
 }
 
+// newConnectivityStateSubscriber is a variable that can be overridden in tests
+// to wrap the connectivity state subscriber for testing purposes.
+var newConnectivityStateSubscriber = connStateSubscriber
+
+// connStateSubscriber returns the subscriber as-is. This function can be
+// overridden in tests to wrap the subscriber.
+func connStateSubscriber(sub grpcsync.Subscriber) grpcsync.Subscriber {
+	return sub
+}
+
 // controlChannel is a wrapper around the gRPC channel to the RLS server
 // specified in the service config.
 type controlChannel struct {
@@ -57,12 +67,14 @@ type controlChannel struct {
 	// hammering the RLS service while it is overloaded or down.
 	throttler adaptiveThrottler
 
-	cc                  *grpc.ClientConn
-	client              rlsgrpc.RouteLookupServiceClient
-	logger              *internalgrpclog.PrefixLogger
-	connectivityStateCh *buffer.Unbounded
-	unsubscribe         func()
-	monitorDoneCh       chan struct{}
+	cc          *grpc.ClientConn
+	client      rlsgrpc.RouteLookupServiceClient
+	logger      *internalgrpclog.PrefixLogger
+	unsubscribe func()
+
+	// All fields below are guarded by mu.
+	mu                   sync.Mutex
+	seenTransientFailure bool
 }
 
 // newControlChannel creates a controlChannel to rlsServerName and uses
@@ -70,11 +82,9 @@
 // gRPC channel.
 func newControlChannel(rlsServerName, serviceConfig string, rpcTimeout time.Duration, bOpts balancer.BuildOptions, backToReadyFunc func()) (*controlChannel, error) {
 	ctrlCh := &controlChannel{
-		rpcTimeout:          rpcTimeout,
-		backToReadyFunc:     backToReadyFunc,
-		throttler:           newAdaptiveThrottler(),
-		connectivityStateCh: buffer.NewUnbounded(),
-		monitorDoneCh:       make(chan struct{}),
+		rpcTimeout:      rpcTimeout,
+		backToReadyFunc: backToReadyFunc,
+		throttler:       newAdaptiveThrottler(),
 	}
 	ctrlCh.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[rls-control-channel %p] ", ctrlCh))
 
@@ -88,11 +98,10 @@ func newControlChannel(rlsServerName, serviceConfig string, rpcTimeout time.Dura
 	}
 	// Subscribe to connectivity state before connecting to avoid missing initial
 	// updates, which are only delivered to active subscribers.
-	ctrlCh.unsubscribe = internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(ctrlCh.cc, ctrlCh)
+	ctrlCh.unsubscribe = internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(ctrlCh.cc, newConnectivityStateSubscriber(ctrlCh))
 	ctrlCh.cc.Connect()
 	ctrlCh.client = rlsgrpc.NewRouteLookupServiceClient(ctrlCh.cc)
 	ctrlCh.logger.Infof("Control channel created to RLS server at: %v", rlsServerName)
-	go ctrlCh.monitorConnectivityState()
 	return ctrlCh, nil
 }
 
@@ -101,7 +110,40 @@ func (cc *controlChannel) OnMessage(msg any) {
 	if !ok {
 		panic(fmt.Sprintf("Unexpected message type %T , wanted connectectivity.State type", msg))
 	}
-	cc.connectivityStateCh.Put(st)
+
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+
+	switch st {
+	case connectivity.Ready:
+		// Only reset backoff when transitioning from TRANSIENT_FAILURE to READY.
+		// This indicates the RLS server has recovered from being unreachable, so
+		// we reset backoff state in all cache entries to allow pending RPCs to
+		// proceed immediately. We skip benign transitions like READY → IDLE → READY
+		// since those don't represent actual failures.
+		if cc.seenTransientFailure {
+			if cc.logger.V(2) {
+				cc.logger.Infof("Control channel back to READY after TRANSIENT_FAILURE")
+			}
+			cc.seenTransientFailure = false
+			if cc.backToReadyFunc != nil {
+				cc.backToReadyFunc()
+			}
+		} else {
+			if cc.logger.V(2) {
+				cc.logger.Infof("Control channel is READY")
+			}
+		}
+	case connectivity.TransientFailure:
+		// Track that we've entered TRANSIENT_FAILURE state so we know to reset
+		// backoffs when we recover to READY.
+		cc.logger.Warningf("Control channel is TRANSIENT_FAILURE")
+		cc.seenTransientFailure = true
+	default:
+		if cc.logger.V(2) {
+			cc.logger.Infof("Control channel connectivity state is %s", st)
+		}
+	}
 }
 
 // dialOpts constructs the dial options for the control plane channel.
@@ -148,68 +190,8 @@ func (cc *controlChannel) dialOpts(bOpts balancer.BuildOptions, serviceConfig st
 	return dopts, nil
 }
 
-func (cc *controlChannel) monitorConnectivityState() {
-	cc.logger.Infof("Starting connectivity state monitoring goroutine")
-	defer close(cc.monitorDoneCh)
-
-	// Since we use two mechanisms to deal with RLS server being down:
-	//   - adaptive throttling for the channel as a whole
-	//   - exponential backoff on a per-request basis
-	// we need a way to avoid double-penalizing requests by counting failures
-	// toward both mechanisms when the RLS server is unreachable.
-	//
-	// To accomplish this, we monitor the state of the control plane channel. If
-	// the state has been TRANSIENT_FAILURE since the last time it was in state
-	// READY, and it then transitions into state READY, we push on a channel
-	// which is being read by the LB policy.
-	//
-	// The LB the policy will iterate through the cache to reset the backoff
-	// timeouts in all cache entries. Specifically, this means that it will
-	// reset the backoff state and cancel the pending backoff timer. Note that
-	// when cancelling the backoff timer, just like when the backoff timer fires
-	// normally, a new picker is returned to the channel, to force it to
-	// re-process any wait-for-ready RPCs that may still be queued if we failed
-	// them while we were in backoff. However, we should optimize this case by
-	// returning only one new picker, regardless of how many backoff timers are
-	// cancelled.
-
-	// Wait for the control channel to become READY for the first time.
-	for s, ok := <-cc.connectivityStateCh.Get(); s != connectivity.Ready; s, ok = <-cc.connectivityStateCh.Get() {
-		if !ok {
-			return
-		}
-
-		cc.connectivityStateCh.Load()
-		if s == connectivity.Shutdown {
-			return
-		}
-	}
-	cc.connectivityStateCh.Load()
-	cc.logger.Infof("Connectivity state is READY")
-
-	for {
-		s, ok := <-cc.connectivityStateCh.Get()
-		if !ok {
-			return
-		}
-		cc.connectivityStateCh.Load()
-
-		if s == connectivity.Shutdown {
-			return
-		}
-		if s == connectivity.Ready {
-			cc.logger.Infof("Control channel back to READY")
-			cc.backToReadyFunc()
-		}
-
-		cc.logger.Infof("Connectivity state is %s", s)
-	}
-}
-
 func (cc *controlChannel) close() {
 	cc.unsubscribe()
-	cc.connectivityStateCh.Close()
-	<-cc.monitorDoneCh
 	cc.cc.Close()
 	cc.logger.Infof("Shutdown")
 }
diff --git a/balancer/rls/control_channel_test.go b/balancer/rls/control_channel_test.go
index 5a30820c3b47..acd278279d30 100644
--- a/balancer/rls/control_channel_test.go
+++ b/balancer/rls/control_channel_test.go
@@ -26,6 +26,7 @@ import (
 	"fmt"
 	"os"
 	"regexp"
+	"sync"
 	"testing"
 	"time"
 
@@ -33,9 +34,12 @@
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/balancer"
 	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/internal"
+	"google.golang.org/grpc/internal/grpcsync"
 	rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
+	"google.golang.org/grpc/internal/testutils"
 	rlstest "google.golang.org/grpc/internal/testutils/rls"
 	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/status"
@@ -463,3 +467,257 @@ func (s) TestNewControlChannelUnsupportedCredsBundle(t *testing.T) {
 		t.Fatal("newControlChannel succeeded when expected to fail")
 	}
 }
+
+// wrappingConnectivityStateSubscriber wraps a connectivity state subscriber
+// and exposes state changes to tests via a channel.
+type wrappingConnectivityStateSubscriber struct {
+	delegate    grpcsync.Subscriber
+	connStateCh chan connectivity.State
+}
+
+func (w *wrappingConnectivityStateSubscriber) OnMessage(msg any) {
+	w.delegate.OnMessage(msg)
+	w.connStateCh <- msg.(connectivity.State)
+}
+
+// TestControlChannelConnectivityStateTransitions_TransientFailure verifies that
+// the control channel resets backoff when recovering from TRANSIENT_FAILURE.
+// It stops the RLS server to trigger TRANSIENT_FAILURE, then restarts it and
+// verifies that backoff is reset when the channel becomes READY again.
+func (s) TestControlChannelConnectivityStateTransitions_TransientFailure(t *testing.T) {
+	// Create a restartable listener for the RLS server.
+	l, err := testutils.LocalTCPListener()
+	if err != nil {
+		t.Fatalf("net.Listen() failed: %v", err)
+	}
+	lis := testutils.NewRestartableListener(l)
+
+	// Start an RLS server with the restartable listener.
+	rlsServer, _ := rlstest.SetupFakeRLSServer(t, lis)
+
+	// Override the connectivity state subscriber to wrap it for testing.
+	wrappedSubscriber := &wrappingConnectivityStateSubscriber{connStateCh: make(chan connectivity.State, 10)}
+	origConnectivityStateSubscriber := newConnectivityStateSubscriber
+	newConnectivityStateSubscriber = func(delegate grpcsync.Subscriber) grpcsync.Subscriber {
+		wrappedSubscriber.delegate = delegate
+		return wrappedSubscriber
+	}
+	defer func() { newConnectivityStateSubscriber = origConnectivityStateSubscriber }()
+
+	// Setup callback to track invocations.
+	var mu sync.Mutex
+	var callbackCount int
+	callback := func() {
+		mu.Lock()
+		callbackCount++
+		mu.Unlock()
+	}
+
+	// Create control channel.
+	ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, balancer.BuildOptions{}, callback)
+	if err != nil {
+		t.Fatalf("Failed to create control channel: %v", err)
+	}
+	defer ctrlCh.close()
+
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	// Verify that the control channel moves to READY.
+	wantStates := []connectivity.State{
+		connectivity.Connecting,
+		connectivity.Ready,
+	}
+	for _, wantState := range wantStates {
+		select {
+		case gotState := <-wrappedSubscriber.connStateCh:
+			if gotState != wantState {
+				t.Fatalf("Unexpected connectivity state: got %v, want %v", gotState, wantState)
+			}
+		case <-ctx.Done():
+			t.Fatalf("Timeout waiting for RLS control channel to become %q", wantState)
+		}
+	}
+
+	// Verify no callbacks have been invoked yet (initial READY doesn't trigger callback).
+	mu.Lock()
+	if callbackCount != 0 {
+		mu.Unlock()
+		t.Fatalf("Got %d callback invocations for initial READY, want 0", callbackCount)
+	}
+	mu.Unlock()
+
+	// Stop the RLS server to trigger TRANSIENT_FAILURE.
+	lis.Stop()
+
+	// Verify that the control channel moves to IDLE.
+	wantStates = []connectivity.State{
+		connectivity.Idle,
+	}
+	for _, wantState := range wantStates {
+		select {
+		case gotState := <-wrappedSubscriber.connStateCh:
+			if gotState != wantState {
+				t.Fatalf("Unexpected connectivity state: got %v, want %v", gotState, wantState)
+			}
+		case <-ctx.Done():
+			t.Fatalf("Timeout waiting for RLS control channel to become %q", wantState)
+		}
+	}
+
+	// Trigger a reconnection attempt by making a lookup (which will fail).
+	// This should cause the channel to attempt to reconnect and move to TRANSIENT_FAILURE.
+	ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, "", func(_ []string, _ string, _ error) {})
+
+	// Verify that the control channel moves to TRANSIENT_FAILURE.
+	wantStates = []connectivity.State{
+		connectivity.Connecting,
+		connectivity.TransientFailure,
+	}
+	for _, wantState := range wantStates {
+		select {
+		case gotState := <-wrappedSubscriber.connStateCh:
+			if gotState != wantState {
+				t.Fatalf("Unexpected connectivity state: got %v, want %v", gotState, wantState)
+			}
+		case <-ctx.Done():
+			t.Fatalf("Timeout waiting for RLS control channel to become %q", wantState)
+		}
+	}
+
+	// Restart the RLS server.
+	lis.Restart()
+
+	// The control channel should eventually reconnect and move to READY.
+	// This transition from TRANSIENT_FAILURE → READY should trigger the callback.
+	// We drain states until we see READY, as the channel may go through intermediate
+	// states (CONNECTING) very quickly after restart.
+	for {
+		select {
+		case gotState := <-wrappedSubscriber.connStateCh:
+			if gotState == connectivity.Ready {
+				goto ready
+			}
+		case <-ctx.Done():
+			t.Fatalf("Timeout waiting for RLS control channel to become READY")
+		}
+	}
+ready:
+
+	// Verify that the callback was invoked exactly once (for TRANSIENT_FAILURE → READY).
+	mu.Lock()
+	got := callbackCount
+	mu.Unlock()
+	if got != 1 {
+		t.Fatalf("Got %d callback invocations, want 1", got)
+	}
+}
+
+// TestControlChannelConnectivityStateTransitions_IdleDoesNotTriggerCallback
+// verifies that IDLE → READY transitions do not trigger backoff reset callbacks.
+func (s) TestControlChannelConnectivityStateTransitions_IdleDoesNotTriggerCallback(t *testing.T) {
+	// Create a restartable listener for the RLS server.
+	l, err := testutils.LocalTCPListener()
+	if err != nil {
+		t.Fatalf("net.Listen() failed: %v", err)
+	}
+	lis := testutils.NewRestartableListener(l)
+
+	// Start an RLS server with the restartable listener.
+	rlsServer, _ := rlstest.SetupFakeRLSServer(t, lis)
+
+	// Override the connectivity state subscriber to wrap it for testing.
+	wrappedSubscriber := &wrappingConnectivityStateSubscriber{connStateCh: make(chan connectivity.State, 10)}
+	origConnectivityStateSubscriber := newConnectivityStateSubscriber
+	newConnectivityStateSubscriber = func(delegate grpcsync.Subscriber) grpcsync.Subscriber {
+		wrappedSubscriber.delegate = delegate
+		return wrappedSubscriber
+	}
+	defer func() { newConnectivityStateSubscriber = origConnectivityStateSubscriber }()
+
+	// Setup callback to track invocations.
+	var mu sync.Mutex
+	var callbackCount int
+	callback := func() {
+		mu.Lock()
+		callbackCount++
+		mu.Unlock()
+	}
+
+	// Create control channel.
+	ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, balancer.BuildOptions{}, callback)
+	if err != nil {
+		t.Fatalf("Failed to create control channel: %v", err)
+	}
+	defer ctrlCh.close()
+
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	// Verify that the control channel moves to READY.
+	wantStates := []connectivity.State{
+		connectivity.Connecting,
+		connectivity.Ready,
+	}
+	for _, wantState := range wantStates {
+		select {
+		case gotState := <-wrappedSubscriber.connStateCh:
+			if gotState != wantState {
+				t.Fatalf("Unexpected connectivity state: got %v, want %v", gotState, wantState)
+			}
+		case <-ctx.Done():
+			t.Fatalf("Timeout waiting for RLS control channel to become %q", wantState)
+		}
+	}
+
+	// Stop the RLS server (without triggering TRANSIENT_FAILURE first).
+	lis.Stop()
+
+	// Verify that the control channel moves to IDLE.
+	wantStates = []connectivity.State{
+		connectivity.Idle,
+	}
+	for _, wantState := range wantStates {
+		select {
+		case gotState := <-wrappedSubscriber.connStateCh:
+			if gotState != wantState {
+				t.Fatalf("Unexpected connectivity state: got %v, want %v", gotState, wantState)
+			}
+		case <-ctx.Done():
+			t.Fatalf("Timeout waiting for RLS control channel to become %q", wantState)
+		}
+	}
+
+	// Restart the RLS server before the channel goes to TRANSIENT_FAILURE.
+	lis.Restart()
+
+	// Trigger a reconnection by making a lookup.
+	ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, "", func(_ []string, _ string, _ error) {})
+
+	// The control channel should reconnect and move to READY.
+	// This transition from IDLE → READY should NOT trigger the callback.
+	// We drain states until we see READY, as the channel may go through intermediate
+	// states (CONNECTING) very quickly.
+	for {
+		select {
+		case gotState := <-wrappedSubscriber.connStateCh:
+			if gotState == connectivity.Ready {
+				goto idleready
+			}
+		case <-ctx.Done():
+			t.Fatalf("Timeout waiting for RLS control channel to become READY")
+		}
+	}
+idleready:
+
+	// Wait a bit to ensure no callback is triggered.
+	time.Sleep(100 * time.Millisecond)
+
+	// Verify that the callback was never invoked (IDLE → READY doesn't trigger callback).
+	mu.Lock()
+	got := callbackCount
+	mu.Unlock()
+	if got != 0 {
+		t.Fatalf("Got %d callback invocations for IDLE → READY, want 0", got)
+	}
+}